期待情報利得計算の変分ベイズ法の適用について

こんにちは、InsightEdgeでデータ分析をしている新見です。 今日はUberAIからNeurIPS2019に出た 論文 を紹介します。

この論文では、ベイズ最適実験計画(Bayesian Optimal Experimental Design:BOED)における期待情報利得(Expected Information Gain:EIG)の推定を高速かつ正確に行う新しい手法を提案しています。 BOEDは、限られた実験リソースを効率的に活用するための枠組みですが、EIGの正確な推定が難しいという課題があります。今回紹介する論文では、変分推論の手法を用いることで、高速化を達成しています。 本記事では提案されたEIGの近似式を紹介し、その使用例をコードベースで紹介します。

もくじ

ベイズ最適実験計画(BOED)とは

一般的に、何かしらの仮説を立証したいときは実験をし、その結果をデータとして取得します。ただし、実験には、時間的金銭的なコストが発生する場合があるので、そのようなケースではなるべく最小の実施回数で仮説を立証できることが望ましいです。このように、効率的な実験条件を設計すること実験計画法と呼びます。たとえば、ある化合物の収量を最大化するための条件の調整や心理学実験における適切な質問の作成等が代表的な例として挙げられます。また、機械学習の枠組みにおいても、最適なハイパーパラメータの探索にこの考え方が応用されています。

ベイズ最適実験計画(BOED)は、立証したい仮説をデータ生成プロセスとしてモデリングを行い、シミュレーションした結果得られる情報が最大となるような実験を実施する方法になります。 具体的には、次式で与えられるような期待情報利得(EIG)を最大化するような実験を設計します。変数$\theta$を目標変数とした時、EIGは実験$d$の結果$y$として $$ \text{EIG}(d)=\mathbb{E}_{p(y|d)}[\text{H}[p(\theta)] - \text{H}[p(\theta|y,d)]] $$のように書けます。$\text{H}[\cdot]$はエントロピーを表し、$p(\theta|y,d) \propto p(y|\theta,d)p(\theta)$は実験$d$の結果$y$を得た時の$\theta$の事後分布を表します。 式の意味は、実験とその結果によってどの程度知りたい変数の不確定性が減少したか、その得られた情報の期待値となります。これを最大とするような実験$d^{*}$を選択することがBOEDの目的になります。

しかしながら、一般的にEIGの計算は困難です。事後分布$p(\theta|y,d)$の評価に入れ子となったモンテカルロ積分(NMC)が必要で(この後すぐ説明)、高い計算コストになります。NMC計算の収束速度は、計算量を$T$とした時、$O(T^{-1/3})$となります。 今回提案する近似手法の収束速度は$O(T^{-1/2})$で、NMCよりも良い性質を持つことが証明されています。

EIGの近似指標

EIGの式をさらに書き変えると、 $$ \text{EIG}(d) = \mathbb{E}_{p(y,\theta|d)}\left[\log \frac{p(\theta|y,d)}{p(\theta)}\right] = \mathbb{E}_{p(y,\theta|d)}\left[\log \frac{p(y|\theta,d)}{p(y|d)}\right] $$となります。 EIGを先ほど述べたモンテカルロ(NMC)で計算する場合の推定式$\hat{\mu}_{\text{NMC}}$は次のように書けます。 $$ \hat{\mu}_{\text{NMC}} = \frac{1}{N} \sum_{n=1}^{N} \log \frac{p(y_{n}|\theta_{n},d)}{\frac{1}{M}\sum_{m=1}^{M}p(y_{n}|\theta_{n,m}, d)} \\ \text{where} \space \theta_{n,m} \sim p(\theta) \space y_{n} \sim p(y|\theta=\theta_{n,0},d) $$上式の計算量は$T=\mathcal{O}(NM)$で、収束速度は$O(T^{-1/3})$です。提案手法では、期待値中の分布関数を近似式で変分的に評価します。

Variational Posterior $\hat{\mu}_{\text{post}}$

事後分布$p(\theta|y,d)$を変分分布$q_{p}(\theta|y,d)$で近似することを考えます。この時、近似EIGは真のEIGの下限となり、その推定量は$\hat{\mu}_{\text{post}}$は $$ \text{EIG}(d) \ge \mathcal{L}_{\text{post}}(d) = \mathbb{E}_{p(y,\theta|d)}\left[\log \frac{q_{p}(\theta|y,d)}{p(\theta)}\right] \\ \approx \hat{\mu}_{\text{post}} = \frac{1}{N} \sum_{n=1}^{N} \log \frac{q_{p}(\theta_{n}|y,d)}{p(\theta_{n})} $$と表せます。近似分布をパラメトリックに書けると仮定し、以下のSGDを用いて、$\mathcal{L}_{\text{post}}(d) $の最大値を求めます。 $$ \nabla_{\phi}\mathcal{L}_{\text{post}}(d;\phi) \approx \frac{1}{S} \sum_{i=1}^{S} \nabla_{\phi}\log q_{p}(\theta_{i}|y,d;\phi) \space \text{where} \space y_{i}, \theta_{i} \sim p(y,\theta|d) $$ここで、パラメータの更新回数を$K$回とすると、$\hat{\mu}_{\text{post}}$の計算量は$T=\mathcal{O}(SK+N)$で、収束速度は$O(T^{-1/2})$となります。収束の証明は論文を参照してください。

Variational marginal $\hat{\mu}_{\text{marg}}$

次は、周辺分布$p(y|d)$を$q_{m}(y|d)$で近似することを考えます。この時、近似式はEIGの上限となります。 $$ \text{EIG}(d) \le \mathcal{L}_{\text{marg}}(d) = \mathbb{E}_{p(y,\theta|d)}\left[\log \frac{p(y|\theta,d)}{q_{m}(y|d)}\right] \\ \approx \hat{\mu}_{\text{marg}} = \frac{1}{N} \sum_{n=1}^{N} \log \frac{p(y_{n}|\theta_{n},d)}{q_{m}(y_{n}|d)} $$ $\hat{\mu}_{\text{post}}$の計算時と同様、$q_{m}(y|d,\phi)$の最適なパラメータは$\mathcal{L}_{\text{marg}}(d,\phi)$を最小化することで、計算できます。

Variational NMC $\hat{\mu}_{\text{VNMC}}$

$\hat{\mu}_{\text{post}}$と$\hat{\mu}_{\text{marg}}$はともにNMCよりも高速に収束しますが、近似分布が真の分布に対してバイアスが残る可能性があります。 ここでは、バイアスを低減する為にNMCの計算をベースに、提案分布に近似分布を用いることで高速化を考えます。この時、EIGの変分上限として $$ \text{EIG}(d) \le \mathcal{L}_{\text{VNMC}}(d;M) = \mathbb{E} \left[ \log p(y|\theta_{0},d) - \log \frac{1}{M} \sum_{m=1}^{M} \frac{p(y,\theta_{m}|d)}{q_{v}(\theta_{m}|y,d)} \right] \\ \approx \hat{\mu}_{\text{VNMC}} (d)= \frac{1}{N} \sum_{n=1}^{N} \left( \log p(y_{n}|\theta_{n,0},d) - \log \frac{1}{M} \sum_{m=1}^{M} \frac{p(y_{n},\theta_{n,m}|d)}{q_{v}(\theta_{n,l}|y_{n},d)} \right) $$が得られます。 サンプリングは$\theta_{n,0} \sim p(\theta)$,$y_{n} \sim p(y|\theta=\theta_{n,0},d)$, $\theta_{n,m} \sim q_{v}(\theta_{n,m}|y=y_{n},d)$で行います。1サンプル$(\theta_{n,0}, y_{n})$ごとに、$\{\theta_{n,m}\}_{m\ge1}$のサンプルを発生させているイメージです。 $\mathcal{L}_{\text{VNMC}}(d)$の第二項は$p(y|d)=\mathbb{E}_{p(\theta|d)}\left[p(y|\theta,d)\right]$を提案分布$q_{v}(\theta|y,d)$で重み付けした重点サンプリングで計算しています。 たとえ$q_{v}$が真の分布に収束しなくても、$L\rightarrow \infty$とした時、$\mathcal{L}_{\text{VNMC}}(d)$はEIGに収束することが証明されており、バイアスが生じないことが保証されています。 計算は、学習と評価の2ステップに分かれており、学習時の$y$のサンプル数を$S$、$\theta$のサンプル数を$L$、学習ステップを$K$回とした時、計算量は$T=\mathcal{O}(SKL+MN)$となります。NMCの計算量と同じのように見えますが、学習フェーズがあることにより、$M$の値を小さくでき、実際は早く収束します。

Implicit Likelihood $\hat{\mu}_{m+l}$

(なかなかわかりにくい名前ですが)、$y$のサンプリングはできるけど、明示的に尤度$p(y|\theta,d)$を計算できない場合に使える手法になります。 例えば、潜在変数$\psi$を有するモデル$p(y|\theta, d)=\mathbb{E}_{p(\psi|\theta)}\left[p(y|\theta, \psi, d)\right]$はよく見かけますが、これはサンプルは簡単に計算できるものの、尤度の計算は困難なことが多いです。 ここでは、周辺分布$q_{m}(y|d)$と尤度$q_{l}(y|\theta,d)$の2つを変分分布で近似した上で、EIGを計算します。 $$ \text{EIG}(d) \approx \mathcal{I}_{m+l}(d) = \mathbb{E}_{p(y,\theta|d)}\left[\log \frac{q_{l}(y|\theta,d)}{q_{m}(y|d)}\right] \\ \approx \hat{\mu}_{m+l} = \frac{1}{N} \sum_{n=1}^{N} \log \frac{q_{l}(y_{n}|\theta_{n},d)}{q_{m}(y_{n}|d)} $$ これまでの手法と違って、$\mathcal{I}_{m+l}(d)$は上限、下限ではないので、この値をターゲットとして、$q_{m}(y|d)$と$q_{l}(y|\theta,d)$を学習できません。 幸い、以下の不等式が成り立つのとで、次式の右辺を最小化することで学習可能です。 $$ |\mathcal{I}_{m+l}(d) - \text{EIG}(d)| \le - \mathbb{E}_{p(y,\theta|d)}\left[\log q_{l}(y|\theta,d) + \log q_{m}(y|d)\right] $$ $q_{l}$と$q_{m}$は別々のパラメータとしても、$q_{m}(y|d)=\mathbb{E}_{p(\theta|d)}\left[q_{l}(y|\theta,d)\right]$を利用してパラメータを共有しても学習可能です。

近似式の使い分けについて

$\hat{\mu}_{\text{VNMC}}$は唯一バイアスが生じない手法なので、計算コストが許容できる場合は、優先して利用すべき選択となります。その上で、$y$と$\theta$それぞれの次元数の大小と尤度$p(y|\theta,d)$を陽的に計算できるかを基準に使い分けることができます。 $\hat{\mu}_{\text{post}}$と$\hat{\mu}_{\text{VNMC}}$は$\theta$に対する近似分布を、$\hat{\mu}_{\text{marg}}$と$\hat{\mu}_{m+l}$は$y$に対する近似分布を利用しています。よって、$\dim(y) \ll \dim(\theta)$の場合は、より推定が容易になる$\hat{\mu}_{\text{marg}}$と$\hat{\mu}_{m+l}$を利用することが望ましいです。その上で、陽的に尤度$p(y|\theta,d)$を計算できる場合は$\hat{\mu}_{\text{marg}}$を、できない場合は$\hat{\mu}_{m+l}$を利用することが推奨されます。$\hat{\mu}_{\text{post}}$は陽的な尤度評価が必要ないので、そのような時に利用できます。

プログラム例:短期記憶モデルの推定

今紹介した手法を、確率プログラムPyroで実装した例が 公式 にありましたので、紹介します。 人は最大7つのことまでは短期記憶で覚えることができると言われています。ここでは、ある集団の短期記憶を測定する実験を考えたいと思います。 まず、短期記憶をモデル化し、そのパラメータを推定するための実験計画を考えてみます。 短期記憶のモデルを簡単に以下のように設定します。 $$ p(\theta) = \text{N}(\mu_{pre}, \sigma_{pre}^{2}) \\ p(y|\theta,d) = \text{Bernoulli}(p=\text{sigmoid}(k(\theta - d)) $$$d$は実験で提示する数列のサイズを、$y$は答えられたかどうかの結果を表します。たとえば、$d=5$の場合、$(1,2,3,2,5)$のような5つの数列を提示された時、全て正しく答えられた場合は$y=1$となります。$\theta$はこの実験における短期記憶の容量を表します。モデル上では50/50の確率で正解できる最大の数列数を意味します。

以下のコードは、上記のモデルをPyroで実装したものです。

sensitivity = 1.0
prior_mean = torch.tensor(7.0)
prior_sd = torch.tensor(2.0)


# theta ~ p(theta), y ~ p(y|theta, d)
def model(d):
    # Dimension -1 of `l` represents the number of rounds
    # Other dimensions are batch dimensions: we indicate this with a plate_stack
    with pyro.plate_stack("plate", d.shape[:-1]):
        theta = pyro.sample("theta", dist.Normal(prior_mean, prior_sd))
        # Share theta across the number of rounds of the experiment
        # This represents repeatedly testing the same participant
        theta = theta.unsqueeze(-1)
        # This define a *logistic regression* model for y
        logit_p = sensitivity * (theta - l)
        # The event shape represents responses from the same participant
        y = pyro.sample("y", dist.Bernoulli(logits=logit_p).to_event(1))
        return y

今回は、$\hat{\mu}_{\text{marg}}$を計算することでEIGを推定します。その為には、$q_{m}(y|d)$を変分近似する必要がありました。以下のコードでは、その実装です。候補となる実験それぞれにパラメータを割り振っています。

# q(y|d)
def marginal_guide(design, observation_labels, target_labels):
    # This shape allows us to learn a different parameter for each candidate design l
    q_logit = pyro.param("q_logit", torch.zeros(design.shape[-2:]))
    pyro.sample("y", dist.Bernoulli(logits=q_logit).to_event(1))

以下は、EIGを計算するコードになります。 marginal_eig 関数の引数として上で定義した$p(y,\theta|d)$と近似分布$q_{m}(y|d)$を与えています。 また、陽に観測値と潜在変数のラベルを指定する必要があります。これは、modelmarginal_guide 関数の中で指定したラベルと一致する必要があります。加えて、実験候補を candidate_designs として与えます。ここでは次元を [N, 1] として、N 人に対して1回ずつ実験をを行っているイメージです。

from pyro.contrib.oed.eig import marginal_eig

# The shape of `candidate_designs` is (number designs, 1)
# This represents a batch of candidate designs, each design is for one round of experiment
candidate_designs = torch.arange(1, 15, dtype=torch.float).unsqueeze(-1)
pyro.clear_param_store()
num_steps, start_lr, end_lr = 1000, 0.1, 0.001
optimizer = pyro.optim.ExponentialLR({'optimizer': torch.optim.Adam,
                                      'optim_args': {'lr': start_lr},
                                      'gamma': (end_lr / start_lr) ** (1 / num_steps)})

eig = marginal_eig(model,
    candidate_designs,  # design, or in this case, tensor of possible designs
    observation_labels = ["y"], # site label of observations, could be a list
    target_labels = ["theta"], # site label of 'targets' (latent variables)
    num_samples=100,  # number of samples to draw per step in the expectation
    num_steps=num_steps,     # number of gradient steps
    guide=marginal_guide,    # guide q(y)
    optim=optimizer,         # optimizer with learning rate decay
    final_num_samples=10000  # at the last step, we draw more samples
)

ここでは、marginal_eig 関数内でのlossの計算部分の実装を見てみます.$\hat{\mu}_{\text{marg}}$の計算に相当します。

def _marginal_loss(model, guide, observation_labels, target_labels):
    """Marginal loss: to evaluate directly use `marginal_eig` setting `num_steps=0`."""

    def loss_fn(design, num_particles, evaluation=False, **kwargs):
        expanded_design = lexpand(design, num_particles)

        # Sample from p(y | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}

        # Run through q(y | d)
        conditional_guide = pyro.condition(guide, data=y_dict)
        cond_trace = poutine.trace(conditional_guide).get_trace(
            expanded_design, observation_labels, target_labels
        )
        cond_trace.compute_log_prob()

        terms = -sum(cond_trace.nodes[l]["log_prob"] for l in observation_labels)

        # At eval time, add p(y | theta, d) terms
        if evaluation:
            trace.compute_log_prob()
            terms += sum(trace.nodes[l]["log_prob"] for l in observation_labels)

        return _safe_mean_terms(terms)

    return loss_fn

Pyroでは、poutine.trace(model) を使って、モデルの内部状態をトレースすることができます。このオブジェクトに model と同じ引数を与えた上で get_trace を呼ぶことで、サンプリングを行い、内部変数の情報を取得できます。 上の例では、まず model の定義に従い$y$のサンプリングを行い、その結果を y_dict に格納しています。

次に、guide で定義された$q(y|d)$を結果 y_dict で条件付けを行い、対数尤度を計算しています。得られた対数尤度の平均をlossとして$q(y|d)$のパラメータを最適化しています。 評価時では、学習時には不要だった$p(y|\theta,d)$の対数尤度を加算し、その値を最終的なEIGの推定値としています。

上の marginal_eig の結果をplotしたものが以下になります。

plt.figure(figsize=(10,5))
matplotlib.rcParams.update({'font.size': 22})
plt.plot(candidate_designs.numpy(), eig.detach().numpy(), marker='o', linewidth=2)
plt.xlabel("$d$")
plt.ylabel("EIG($d$)")
plt.show()

eig

ここでは、EIGの最大値は$d=7$となりました。初回なので、事前分布の平均に引っ張られる形となりました。 紹介したリンクでは、その後上の操作を繰り返し、$p(\theta)$の更新とEIGの計算を繰り返し、最適な事後分布の算出しています。興味がある方は参考にしてみてください。

まとめ

今回は、実験計画法における期待情報利得の計算時の変分ベイズ法の適用について紹介しました。データの発生プロセスをモデル化するというハードルはあるので、使いどころは限定される印象です。しかし、実験のコストが高い場合において、より効率的に仮説を立証するためにどのような実験計画を組むべきかを定量的に評価できるのはメリットといえるでしょう。