オムライスの備忘録

数学・統計学・機械学習・プログラミングに関することを記す

【深層学習】スコアベースモデル / Score Base Model / SBM

Index

スコア

任意の入力について微分可能な確率分布で、スコアを定義する.

対数尤度  \log\ p(x) の入力  x についての勾配をスコアと呼ぶ.

スコアを与える関数  s(x) をスコア関数と呼ぶ.

 s(x)\ =\ \nabla_{x}\ \log\ p(x)\ :\ R^{d}\ \rightarrow\ R^{d}



スコアは、入力  x と同じ次元数  d をもつベクトル.

スコアは、入力空間でのベクトル場を表し、各点のベクトルはその位置での対数尤度が 最も急激に大きくなる方向とその大きさを表す.



また、スコアは微分の公式より、

 \nabla_{x}\ \log\ p(x)\ =\ \displaystyle \frac{\nabla_{x}\ p(x)}{p(x)}



と表されるため、確率が最も急激に上昇するベクトルを、確率で割った値となる.

そのため、スコアは確率が小さい領域で大きくなりやすく、大きい領域で小さくなりやすい.

確率分布のスコアが得られれば、確率分布から効率的にサンプリングできる.

確率分布を直接学習する代わりに確率分布のスコアを学習し、 スコアを使って生成モデルを実現するモデルをスコアベースモデル / SBMと呼ぶ.

スコア導入のメリット

分配関数が、現れない.

近似法

ランジュバン・モンテカルロ法

ランジュバン・モンテカルロ法は、スコアを使った MCMC 法.

 p(x) からサンプルを得ることが目的.



MCMC 法の問題点の一つである、周辺の確率が大きい皇甫を効率的に探す問題を解決できる.

スコアベースモデル / Score Base Model / SBM

確率分布を直接学習する代わりに、確率分布のスコアを学習し、スコアを使って生成モデルを実現するモデルを スコアベースモデル / Score Base Model / SBM と呼ぶ.

スコアマッチング

スコア関数の学習方法のひとつ.

スコア関数をニューラルネットワークで近似するというアイディア.

 s_{\theta}(x)\ :\ R^{d}\ \rightarrow\ R^{d}

明示的スコアマッチング / Explicit Score Matching / ESM

学習対象の分布のスコアとモデル間の 2 乗誤差が最小となるアプローチ.

この際、目標分布  p(x) で期待値をとる.

この目的関数を明示的スコアマッチング / Explicit Score Matching / ESMと呼ぶ.


\begin{eqnarray}
J_{ESM_{p}} (\theta)\  &=& E_{p(x)}\ \left[\ ||\ s(x)\ -\ s_{\theta}(x)\ ||^{2}\ \right] \\
  &=& E_{p(x)}\ \left[\ ||\ \nabla_{x}\ \log\ p(x)\ -\ s_{\theta}(x)\ ||^{2}\ \right]
\end{eqnarray}



しかし、一般に、生成モデルの学習には訓練データのみが与えれスコアは未知である.

暗黙的スコアマッチング / Implicit Score Matching / ISM

訓練データのみから明示的スコアマッチングを使って学習できない問題を解きたい.

そこで、スコアを使わずに定義する.

 J_{ISM_{p}} (\theta)\ =\ E_{p(x)}\ \left[\ \displaystyle \frac{1}{2}\ ||\ s_{\theta}(x)\ ||^{2}\ +\ tr(\nabla_{x}\ s_{\theta}(x))\ \right]


  •  s_{\theta}(x) : モデルが表す推定されたスコア
  •  \nabla_{x}\ s_{\theta}(x) :   s_{\theta} の各成分について再度  x で勾配をとったヘッセ行列
  •  tr : 行列のトレース (対角成分の和) を計算



この暗黙的スコアマッチングは、スコアを使わないにも関わらず、明示的スコアマッチングとパラメータに依存しない定数項を 除いて等しくなる.

 J_{ESM_{p}} (\theta)\ =\ J_{ISM_{p}} (\theta)\ +\ C_{1}

デノイジングスコアマッチング / Denosing Score Matching / DSM

暗黙的スコアマッチングを適用するには、以下のような問題がある.

  1. 計算量が大きい
  2. 過学習が起こりやすい

参考