オムライスの備忘録

数学・統計学・機械学習・プログラミングに関することを記す

【深層学習】BERT #実装編 #03

この記事の読者

深層学習・ディープラーニングの手法の1つである「BERT」について知りたい.


キーワード・知ってると理解がしやすい

  • BERT

yhayato1320.hatenablog.com


yhayato1320.hatenablog.com

Index

環境とライブラリ

#01 と同様の環境

yhayato1320.hatenablog.com

実装

#02 では、BERT を稼働して、ベクトルが出力されることを確認した.


yhayato1320.hatenablog.com

#03 では、具体的な問題に取り組む. それは、文章の穴埋めだ.

文章の穴埋め

複数の穴埋め

例えば、以下のような入力を作ったとする.


今日は [MASK] [MASK] へ行く.


今回のように [MASK] に入る語彙の候補が 32000通りもあるので、 組み合わせの数は膨大になる.
そこで、ナイーブな方法として貪欲法がある.

予測する箇所が複数あるから、候補が多くなってしますので、 予測する箇所を分割して、それぞれを予測する.


しかし、BERT は、文章を前から順番に生成するような文章生成が得意ではない. (RNN系 のような前の単語から予測するネットワークではなく、 Attention系 のような全体を見ているネットワークだからか)

貪欲法では、次の単語として予測確率(スコア)が高いものを選ぶ操作を繰り返すだけだが、 最終的な文章全体としてのスコアが、高い文章を作成する方法としてビームサーチ / Beam Searchがある.

ビームサーチは、[MASK] を含む文章が与えられたときに、 まず一つ目の [MASK] を (例えばスコア上位10の)トークンで置き換えた 10 の文章を作る.
そして、次は得られた 10 の文章のそれぞれに対して、次の [MASK] を、また上位 10 のトークンで置き換えた 10 の文章を作ります.

これにより、現時点で、2 つの [MASK] が穴埋めされた 100 の文章が得られている.

そして、次にこの 100 個の中から合計スコアの高い 10 の文章をさらに選ぶ.

ここでいう合計スコアは、それまでに穴埋めされたトークンのスコアを合計したもの.

その後は、この 10 の文章の次の [MASK] に対し同じ処理を繰り返す.

ビームサーチを利用した「複数の穴埋め処理」の実行のために、3 つの関数を作成.

  • make_mask
  • predict_mask_topk
  • beam_search

メインの処理は、いくつかの文章に対し、マスクをランダムに複数作成し、
それぞれのマスキング済みの文章の穴埋めをビームサーチを用いて行う.

参考