ベイズ最適化って何がベイズ?

最終更新日: 2020年12月8日

この記事は fMRILab Advent Calendar 2020 の 12/8 分の記事です。


ふと、ベイズ最適化って何がベイズなんだ?と思ったので調べました。

参考文献

Shahriari et al., (2015), ガウス過程と機械学習 (2019) が大変参考になりました。emoji-bow

ベイズ最適化とは

ベイズ最適化の目的:未知の目的関数 ff の global maximizer (minimizer) を発見すること

x=argmaxxXf(x)\boldsymbol{x}^* = \underset{\boldsymbol{x}\in\mathcal{X}}{\text{argmax}} f(\boldsymbol{x})

ベイズ最適化を利用すると、効率的に1入力の空間 X\mathcal{X} を探索して目的関数 ff の最適値を求めることができます。

具体的には、ベイズ的な更新を利用して獲得関数 αn:XR\alpha_n: \mathcal{X} \to\R を導出し、探索を効率化します。αn\alpha_n を最大化するように xn+1\boldsymbol{x}_{n+1} を選ぶことで、ff を最大化するのに有用な候補点を選択できるということです。

アルゴリズム 1: ベイズ最適化

  1. 獲得関数 αn\alpha_n を最適化するように新しい点 xn+1\boldsymbol{x}_{n+1} を選ぶ。
xn+1=argmaxx αn(x;Dn)\boldsymbol{x}_{n+1} = \underset{\boldsymbol{x}}{\text{argmax}} \ \alpha_n(\boldsymbol{x}; \mathcal{D}_n)
  1. xn+1\boldsymbol{x}_{n+1} に対応する出力 yn+1=f(xn+1)y_{n+1}=f(\boldsymbol{x}_{n+1}) を取得する。
  2. データを拡張する。 Dn+1={Dn,(xn+1,yn+1)}\mathcal{D}_{n+1}=\{\mathcal{D}_n,(\boldsymbol{x}_{n+1},y_{n+1})\}
  3. 獲得関数を定める統計モデル p(wDn)p(\boldsymbol{w}|\mathcal{D}_n) を更新する。
p(wDn+1)=p((xn+1,yn+1)w)p((xn+1,yn+1))×p(wDn)p(\boldsymbol{w}|\mathcal{D}_{n+1}) = \frac{p((\boldsymbol{x}_{n+1},y_{n+1})|\boldsymbol{w})}{p((\boldsymbol{x}_{n+1},y_{n+1}))} \times p(\boldsymbol{w}|\mathcal{D}_{n})
  1. 手順 1 から手順 4 を繰り返す。

ベイズ更新は、ベイズの定理から導かれます。

p(AB,C)=p(CA)p(C)p(BA)p(A)p(B)=p(CA)p(C)×p(AB)\begin{aligned} p(A|B,C) &= \frac{p(C|A)}{p(C)} \frac{p(B|A)p(A)}{p(B)} = \frac{p(C|A)}{p(C)} \times p(A|B) \end{aligned}
なるほど。"ベイズ的に獲得関数を更新していく"からベイズ最適化というのか。

以下では、パラメトリックな方法とノンパラメトリックな方法によるベイズ最適化を見ていきます

パラメトリックなベイズ最適化

ここでは、線形モデルを利用したパラメトリックなベイズ最適化について考えます。この場合、目的関数 ff を線形モデルとしてモデリングします

fw(x)=xwf_{\boldsymbol{w}}(\boldsymbol{x}) = \boldsymbol{x}^\top \boldsymbol{w}

線形モデルなので、観測値 yy は平均 xw\boldsymbol{x}^\top \boldsymbol{w}、分散 σ2\sigma^2正規分布に従うとします。

yN(xw, σ2)y \sim \mathcal{N}(\boldsymbol{x}^\top \boldsymbol{w}, \ \sigma^2)

正規分布の共役事前分布として、正規逆ガンマ分布 (NIG)2 を利用します。

事前分布 尤度 事後分布
正規逆ガンマ分布 正規分布 正規逆ガンマ分布
p(y  x,w,σ2)=N(y  xw, σ2)p(w,σ2)=NIG(w,σ2w0,V0,α0,β0)\begin{aligned} p(y \ | \ \boldsymbol{x}, \boldsymbol{w},\sigma^2) &= \mathcal{N}(y \ | \ \boldsymbol{x}^\top \boldsymbol{w}, \ \sigma^2) \\ p(\boldsymbol{w},\sigma^2) &= \text{NIG}(\boldsymbol{w},\sigma^2|\boldsymbol{w}_0,\boldsymbol{V}_0, \alpha_0,\beta_0) \end{aligned}

このとき、共役性から事後分布も正規逆ガンマ分布となります。

p(w,σ2  y,x)=NIG(w,σ2y,x,wn,Vn,αn,βn)Vn=(V01+xx)1wn=Vn(V0w0+xy)αn=α0+n/2βn=β0+12(w0V01w0+y2wnVnwn)p(\boldsymbol{w},\sigma^2 \ | \ y, \boldsymbol{x}) = \text{NIG}(\boldsymbol{w},\sigma^2|y, \boldsymbol{x}, \boldsymbol{w}_n,\boldsymbol{V}_n, \alpha_n,\beta_n) \\[0.5em] \begin{aligned} \boldsymbol{V}_n &= (\boldsymbol{V}_0^{-1}+\boldsymbol{x}^\top\boldsymbol{x})^{-1} \\ \boldsymbol{w}_n &= \boldsymbol{V}_n (\boldsymbol{V}_0\boldsymbol{w}_0+\boldsymbol{x}^\top {y}) \\ \alpha_n &= \alpha_0 + n/2 \\ \beta_n &= \beta_0 + \frac{1}{2} (\boldsymbol{w}_0^\top \boldsymbol{V}_0^{-1} \boldsymbol{w}_0 + y^2 - \boldsymbol{w}_n\boldsymbol{V}_n \boldsymbol{w}_n) \end{aligned}

したがって、事後分布からサンプリングした w~\tilde{\boldsymbol{w}} を利用し、次の点を決めることができます。(サンプリングして獲得関数を生成します。)

w~p(wDn),xnew=argmax xXxw~\tilde{\boldsymbol{w}} \sim p(\boldsymbol{w}|\mathcal{D}_n), \quad \boldsymbol{x}_{new} = \underset{\boldsymbol{x}\in\mathcal{X}}{\text{argmax }} \boldsymbol{x}^\top \tilde{\boldsymbol{w}}

引き続き xnew\boldsymbol{x}_{new} に対応する出力を取得してベイズ更新を繰り返せば、効率的に local maximizer に辿り着くことができます。これが線形モデルを用いたベイズ最適化です3

アルゴリズム 2: ベイズ最適化(線形モデル)

  1. 線形モデルを用意する。
p(y  x,w,σ2)=N(y  xw, σ2)p(w,σ2)=NIG(w,σ2w0,V0,α0,β0)p(w,σ2  y,x)=NIG(w,σ2y,x,wn,Vn,αn,βn)\begin{aligned} p(y \ | \ \boldsymbol{x}, \boldsymbol{w},\sigma^2) &= \mathcal{N}(y \ | \ \boldsymbol{x}^\top \boldsymbol{w}, \ \sigma^2) \\ p(\boldsymbol{w},\sigma^2) &= \text{NIG}(\boldsymbol{w},\sigma^2|\boldsymbol{w}_0,\boldsymbol{V}_0, \alpha_0,\beta_0)\\ p(\boldsymbol{w},\sigma^2 \ | \ y, \boldsymbol{x}) &= \text{NIG}(\boldsymbol{w},\sigma^2|y, \boldsymbol{x}, \boldsymbol{w}_n,\boldsymbol{V}_n, \alpha_n,\beta_n) \end{aligned}
  1. 獲得関数 αn\alpha_n を最適化するように新しい点 xn+1\boldsymbol{x}_{n+1} を選ぶ。
w~p(wDn),αn(x)=xw~xn+1=argmaxx αn(x;Dn)\tilde{\boldsymbol{w}} \sim p(\boldsymbol{w}|\mathcal{D}_n), \quad \alpha_n(\boldsymbol{x}) = \boldsymbol{x}^\top \tilde{\boldsymbol{w}} \\[0.5em] \boldsymbol{x}_{n+1} = \underset{\boldsymbol{x}}{\text{argmax}} \ \alpha_n(\boldsymbol{x}; \mathcal{D}_n)
  1. xn+1\boldsymbol{x}_{n+1} に対応する出力 yn+1=f(xn+1)y_{n+1}=f(\boldsymbol{x}_{n+1}) を取得する。
  2. データを拡張する。 Dn+1={Dn,(xn+1,yn+1)}\mathcal{D}_{n+1}=\{\mathcal{D}_n,(\boldsymbol{x}_{n+1},y_{n+1})\}
  3. 線形モデルを更新する。
p(w,σ2 Dn+1)=NIG(w,σ2Dn+1,wn+1,Vn+1,αn+1,βn+1)p(\boldsymbol{w},\sigma^2 \ | \mathcal{D}_{n+1} ) = \text{NIG}(\boldsymbol{w},\sigma^2|\mathcal{D}_{n+1}, \boldsymbol{w}_{n+1},\boldsymbol{V}_{n+1}, \alpha_{n+1},\beta_{n+1})
  1. 手順 1 から手順 4 を繰り返す。

ノンパラメトリックなベイズ最適化

目的関数に明示的なモデルを設定しないノンパラメトリックなベイズ最適化も利用されています。そのうち最も良く利用されているのがガウス過程です

ガウス過程 GP\text{GP} による回帰では、次のような生成モデル(観測誤差ありのガウス過程回帰)を考えます。

p(y  f,σ2)=N(yf,σ2I)p(f  X)=N(fm,K)\begin{aligned} p(\boldsymbol{y} \ | \ \boldsymbol{f}, \sigma^2) &= \mathcal{N}(\boldsymbol{y} | \boldsymbol{f},\sigma^2 \bold{I}) \\ p(\boldsymbol{f} \ | \ \bold{X}) &= \mathcal{N}(\boldsymbol{f} |\boldsymbol{m},\bold{K}) \end{aligned}

ただし、X=(x1,,xn)\bold{X}=(\boldsymbol{x}_1, \dots, \boldsymbol{x}_n), y=(y1,,yn)\boldsymbol{y}=(y_1, \dots, y_n), f=(f(x1),,f(xn))\boldsymbol{f}=(f(\boldsymbol{x}_1), \dots, f(\boldsymbol{x}_n)) と表記します。また、m\boldsymbol{m} が平均ベクトル、K\bold{K} はカーネル行列で、その要素はカーネル関数 Kij=k(xi,xj)K_{ij}=k(\boldsymbol{x}_i,\boldsymbol{x}_j) になっています。

X\bold{X} が与えられたときの y\boldsymbol{y} の分布は f\boldsymbol{f} に関して周辺化することで求められます4

p(yX)=p(y,fX) df=p(yf) p(fX) df=N(yf,σ2I) N(fm,K) df=N(ym,K+σ2I)\begin{aligned} p(\boldsymbol{y}|\bold{X}) &= \int p(\boldsymbol{y}, \boldsymbol{f}|\bold{X}) \ d\boldsymbol{f} \\ &= \int p(\boldsymbol{y}|\boldsymbol{f}) \ p(\boldsymbol{f}|\bold{X}) \ d\boldsymbol{f} \\ &= \int \mathcal{N}(\boldsymbol{y}|\boldsymbol{f},\sigma^2 \bold{I}) \ \mathcal{N}(\boldsymbol{f}|\boldsymbol{m},\bold{K}) \ d\boldsymbol{f} \\ &= \mathcal{N}(\boldsymbol{y}|\boldsymbol{m}, \bold{K} + \sigma^2 \bold{I}) \end{aligned}

したがって、y\boldsymbol{y}X\bold{X} から計算できる正規分布に従うことがわかります。これは、カーネル関数を

k(xi,xj)=k(xi,xj)+σ2δijk'(\boldsymbol{x}_i,\boldsymbol{x}_j) = k(\boldsymbol{x}_i,\boldsymbol{x}_j) + \sigma^2 \delta_{ij}

おきかえたガウス過程になっています

fGP( m(x), k(x,x) )f \sim \text{GP}( \ m(\boldsymbol{x}), \ k'(\boldsymbol{x},\boldsymbol{x}') \ )

ff を明示的にモデル化しなくても、その同時分布が正規分布になっていて、ff がブラックボックスのままで X\bold{X} から y\boldsymbol{y} の分布が計算できます。さらに、新しい入力 xnew\boldsymbol{x}_{new} に対する ynewy_{new} の予測分布は、

p(ynewxnew,y,X)=N( ynew  m,K )m=m(xnew)+k(K)1(ym)K=k(xnew,xnew)k(K)1kk=(k(xnew,x1),,k(xnew,xn))\begin{aligned} p(y_{new}|\boldsymbol{x}_{new},\boldsymbol{y}, \bold{X}) &= \mathcal{N}( \ y_{new} \ | \ \boldsymbol{m}^*, \bold{K}^* \ ) \\ \boldsymbol{m}^* &= m(\boldsymbol{x}_{new}) + \boldsymbol{k}^* (\bold{K}')^{-1} (\boldsymbol{y}-\boldsymbol{m}) \\ \bold{K}^* &= k'(\boldsymbol{x}_{new},\boldsymbol{x}_{new}) - \boldsymbol{k}^{*\top} (\bold{K}')^{-1} \boldsymbol{k}^* \\ \boldsymbol{k}^*&= (k'(\boldsymbol{x}_{new},\boldsymbol{x}_1), \dots, k'(\boldsymbol{x}_{new},\boldsymbol{x}_n))^\top \end{aligned}

と計算できます。

獲得関数には色々なものが提案されていて、以下の Expected Improvement (EI) という方法がよく用いられるようです。ここで、Φ,ϕ\Phi,\phi はそれぞれ標準正規分布の分布関数と密度関数です。

α(x)=(m(x)τ)Φ(z)+k(x,x)ϕ(z)( τ=max1in(yi), z=m(x)τk(x,x) )\alpha(\boldsymbol{x}) = (m(\boldsymbol{x}) - \tau) \Phi(z) + \sqrt{k'(\boldsymbol{x}, \boldsymbol{x})} \phi(z) \\[0.5em] \left( \ \tau = \max_{1\le i \le n}(y_i), \ z = \frac{m(\boldsymbol{x}) - \tau}{\sqrt{k'(\boldsymbol{x}, \boldsymbol{x})}} \ \right)

アルゴリズム 3: ベイズ最適化(ガウス過程)

  1. ガウス過程モデルを用意する。
p(y  X)=N(ym,K)fGP( m(x), k(x,x) )\begin{aligned} p(\boldsymbol{y} \ | \ \bold{X}) &= \mathcal{N}(\boldsymbol{y}|\boldsymbol{m}, \bold{K}) \\ f &\sim \text{GP}( \ m(\boldsymbol{x}), \ k(\boldsymbol{x},\boldsymbol{x}) \ ) \end{aligned}
  1. 獲得関数 αn\alpha_n を最適化するように新しい点 xn+1\boldsymbol{x}_{n+1} を選ぶ。
αn(x)=(m(x)τ)Φ(z)+k(x,x)ϕ(z)xn+1=argmaxx αn(x;Dn)\alpha_n(\boldsymbol{x}) = (m(\boldsymbol{x}) - \tau) \Phi(z) + \sqrt{k(\boldsymbol{x}, \boldsymbol{x})} \phi(z) \\[0.5em] \boldsymbol{x}_{n+1} = \underset{\boldsymbol{x}}{\text{argmax}} \ \alpha_n(\boldsymbol{x}; \mathcal{D}_n)
  1. xn+1\boldsymbol{x}_{n+1} に対応する出力 yn+1=f(xn+1)y_{n+1}=f(\boldsymbol{x}_{n+1}) を取得する。
  2. データを拡張する。 Dn+1={Dn,(xn+1,yn+1)}\mathcal{D}_{n+1}=\{\mathcal{D}_n,(\boldsymbol{x}_{n+1},y_{n+1})\}
  3. ガウス過程モデルを更新する。
[yyn+1]N([mm(xn+1)],[Kkkk(xn+1,xn+1)])\begin{bmatrix} \boldsymbol{y} \\ y_{n+1} \end{bmatrix} \sim \mathcal{N} \left( \begin{bmatrix} \boldsymbol{m} \\ m(\boldsymbol{x}_{n+1}) \end{bmatrix}, \begin{bmatrix} \bold{K} & \boldsymbol{k}^{*} \\ \boldsymbol{k}^{*\top} & k(\boldsymbol{x}_{n+1}, \boldsymbol{x}_{n+1}) \end{bmatrix} \right)
  1. 手順 1 から手順 4 を繰り返す。

Reference


  1. 効率的というのは、探索(未知の行動を選択する)と利用(経験的な最善の行動を選択する)のトレードオフについてバランスが良いということである。

  2. 正規逆ガンマ分布の定義は以下。

    NormInv-Gamma(w,σ2w0,V0,α0,β0)=2πσ2V01/2exp[12σ2(ww0)V01(ww0)]×β0α0Γ(α0)(σ2)α0+1exp[β0σ2]\begin{aligned} &\text{NormInv-Gamma}(\boldsymbol{w},\sigma^2| \boldsymbol{w}_0,\boldsymbol{V}_0, \alpha_0,\beta_0) \\ &\quad = |2 \pi \sigma^2 \boldsymbol{V}_0 |^{-1/2} \exp \left[ -\frac{1}{2\sigma^2} (\boldsymbol{w}-\boldsymbol{w}_0)^\top \boldsymbol{V}_0^{-1} (\boldsymbol{w}-\boldsymbol{w}_0)\right] \\ &\qquad \times \frac{\beta_0^{\alpha_0}}{\Gamma(\alpha_0)(\sigma^2)^{\alpha_0+1}} \exp \left[ -\frac{\beta_0}{\sigma^2} \right] \end{aligned}

  3. 離散的な探索空間を考えればトンプソンサンプリングと同等になると思います。

  4. これは、日本語版 PRML p.90 の「ガウス分布の周辺分布と条件付き分布」を利用すればよいです。

    p(yx)=N(yAx+b,L1)p(x)=N(xμ,Λ1) p(y)=N(Aμ+b,L1+AΛ1A)\begin{aligned} p(\boldsymbol{y}|\boldsymbol{x}) &= \mathcal{N}(\boldsymbol{y}|\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b},\boldsymbol{L}^{-1}) \\ p(\boldsymbol{x}) &= \mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu},\boldsymbol{\Lambda}^{-1}) \\[0.5em] \therefore \ p(\boldsymbol{y}) &= \mathcal{N}(\boldsymbol{A}\boldsymbol{\mu}+\boldsymbol{b},\boldsymbol{L}^{-1}+\boldsymbol{A}\boldsymbol{\Lambda}^{-1}\boldsymbol{A}^\top) \end{aligned}