オムライスの備忘録

数学・統計学・機械学習・プログラミングに関することを記す

【深層学習】Generative Adversarial Network / GAN #アルゴリズム編

この記事の読者

深層学習・DeepLearningを利用している
「Generative Adversarial Network : GAN」の基本的な内容について知りたい.



この記事では、Generative Adversarial Network / GANの基礎概念のまとめる.

Index

Generative Adversarial Network / GAN

深層学習を用いた生成モデルの一つ.

日本語では、「敵対的生成ネットワーク」ともいわれる.



このフレームワークでは、2 つのモデルが相互に学習を行う.

1 つは、データの分布を捉える / 理解することで、
データを生成する Generator / 生成器 と呼ばれる生成モデル.

1 つは、サンプルデータが真のデータである確率を推定する
Discriminator / 識別器 と呼ばれる識別モデル.

この2つのニューラルネットワークを相互に学習することにより、
「質の良いデータの生成」と「精度の高い識別」を可能にするフレームワーク.

GAN構造

Generator の目的

Generator の目的は、「Discriminator が、識別の間違いを起こす確率」を最大化すること.

つまり、生成した偽物のデータを本物と誤識別させることにある.

2 人の MiniMax Game

このフレームワークは、どちらかの精度が上がれば良いというわけではない.

それぞれが、互いに向上しあって精度を高めていくことが望ましい.

Architecture

Generator と Discriminator のどちらも深層学習が用いられている.

本論文では、全結合層のみのネットワークを想定している.

入力データに画像を扱う場合、Convolution Neural Network / CNN を利用されることが多い.

その場合、DC GAN / Deep Convolution GAN と呼称される.

Generator

Generator は、ランダムなベクトルを入力とし、
元のデータセットからサンプリングしたようなデータを出力するように学習する.

入力のベクトルは、多変量正規分布が用いられることが多い.

このベクトルは、潜在変数 / latent variable と呼ばれることが多く、
その変数の範囲 (空間) を潜在空間 / latent space と呼ばれる.

また、潜在変数は、 z で表現されることが多い.



「Generator の目的」でも記述したが、Generator は、
Discriminator を騙すようなデータを生成することを目的として学習を行う.

Discriminator

Discriminator は、
入力されたデータが、元のデータセットからサンプリングされた真のデータか、
Generator が生成したデータかどうかを予測する.

これは、教師あり学習の枠組みの分類問題と考えることができる.

そのために、
訓練データセット(真のデータ)からランダムに選んだ本物の何件かのデータ
Generator が生成したデータ何件かのデータで構成されたデータの集合を用意する必要がある.

そして、真のデータのラベルを  1 に、生成されたデータのラベルを  0 として、
予測した分類確率との損失を計算できれば、教師あり学習の考えが適用できる.

定式化

Generator が生成する確率分布を  p_g として、
Generator は真のデータ分布  p_{data} からサンプリングされた  x から、
 p_g を真のデータ分布  p_{data} に近づけることが目的.

そのために、入力のノイズ (潜在) 変数  z の確率分布  p_z(z) を事前に定義する.
(正規分布が使用されることが多い)

そして、データ空間へのマッピングを行う Generator の処理を関数 Gとして、  G(z;\theta_g)\ (=\ G(z)\ ) と表す.
(ここで、 \theta_gニューラルネットのパラメータ)

次に、Discriminator を定義する.

Discriminator の処理を関数 D とし、入力  x と考えると、
出力は  D(x; \theta_d)\ (=\ D(x)\ ) となる.
(ここでも、 \theta_dニューラルネットのパラメータ)

 D(x) は、真の分布  p_{data} からサンプリングされたデータである確率を出力する.

それぞれの目的

Discriminator は、
「生成された偽物のデータ」と、「学習データの本物のデータ」を分類するための学習を行う.

Discriminator の 最大化したい関数 (Object Function) は、  \log(D(x)) と表現できる.

Discriminator は通常の分類問題であるので、
真のデータ  x の分類確率  D(x) を最大化できればよい.



そして、Discriminator の学習と同時に、Generator の学習も行う.

Generator の 最小化したい関数 (Loss Function) は、  \log(\ 1 - D(\ G(z)\ )\ ) と表現できる.

 G(z) は、Generator によって、生成されたデータを示す.

 D(\ G(z)\ ) は、生成されたデータが入力されたときの、
Discriminator が本物のデータであると判断した確率である.

Discriminator からしたら、この値は、小さい方が良い.

逆に、Generator からしたら、この値を大きく (Maximize) したい.

ので、 1 - D(\ G(z)\ ) を最小化すことが Generator の目的になる.



目的関数


\begin{align}
V(D, G)\ =&\ E_{\ x \sim p_{data}(x)}\ [\ \log(D(x))\ ] \\
+&\ E_{\ z\ \sim\ p_{z}(z)}\ [\ \log(\ 1 - D(\ G(z)\ )\ )\ ]
\end{align}



この目的関数を最適化した  G D を得たい.

 \displaystyle \min_{G} \max_{D} V(D, G)



バッチ内のデータの平均値を目的関数 / 損失関数にするために、期待値をとっている.

 E[X]


学習の話

Discriminator の学習のみが進んでしまうことを設計上、防ぐ必要がある.

そのために、Discriminator  k 回の学習ステップに対し、Generator  1 回の学習ステップを行うようにする.

このような仕組みにすることで、Generator の出力がゆっくり変わるようにすることで、 Discriminator の最適解が、徐々に、真のデータ分布に適用していく( 近づいていく / 理解していく).

ここで、Generator の出力が変わらない現象(モード崩壊)が起きると、
Discriminator が学習するデータ分布が固定され、真のデータ分布を正しく理解ができない.





  • 青い破線 : Discriminator の識別分布 (分類確率)
  • 黒い点線 : 真のデータ分布
  • 緑の実践 : Generator が生成したデータの分布

  • 下の水平軸
    •  z : ランダム (均一) にサンプリングされた潜在変数
    •  x : データの定義域
    •  G(z) = x がどのように写像されているのかを表している

モード崩壊 / Mode

Discriminator の学習のために、Generator が十分な変化を提供しない場合がある.

つまり、同じようなデータ生成していまう.

Generator は Discriminator を欺ければ良いので、騙せるデータが分かれば、そのようなデータばかり生成してしまう.

-> 生成してしまうような局所解に陥ってしまって、抜け出せない.

-> また、Discriminator も、ある決まったパターンのデータが来ると、騙されてしまうような局所解に陥ってしまって、抜け出せない.

学習の段階

最初、Generator のデータ生成は、ランダムに生成しているため、再現度に関しては貧弱である.

そのとき、 \log(\ 1 - D(\ G(z)\ )\ ) が増加する.
(Discriminator の分類精度が高く、Generator の生成精度が低いため)

その後、Generator が学習を進めるにつれ、 \log(\ 1 - D(\ G(z)\ )\ ) を最小化していく.
(つまり、 D(G(x)) を最大化しようとする)

まとめ

Generative Adversarial Network : GAN は Generator (生成器) と Discriminator (識別器)の2つのニューラルネットワークで構成されている.

Generator はデータを生成し、Discriminator はデータの真偽を見破る役割を担っている.

相互に学習することで、データの生成と識別の精度を向上させていく.

参考

  • Generative Adversarial Networks
    • [2014]
    • Abstract
    • 3 Adversarial nets
    • arxiv.org

書籍

Web サイト