オムライスの備忘録

数学・統計学・機械学習・プログラミングに関することを記す

【深層学習】Sentence BERT / SBERT

この記事の読者

Sentence BERT / SBERT」について知りたい

キーワード・知ってると理解がしやすい

  • BERT
  • Siamese Network
  • Cosine Similarity
  • Triplet Network / Triplet Loss

yhayato1320.hatenablog.com

まとめ編 yhayato1320.hatenablog.com

Index

Sentence BERT / SBERT とは

BERT (2018) や RoBERTa(2019) やその亜種の「事前学習されたモデル」と、「Siamese Net」を利用して 再学習(ファインチューニング)して、良質な文章ベクトルを生成する手法.

yhayato1320.hatenablog.com

アルゴリズム

アーキテクチャ

BERT のアウトプットに Pooling を加えたネットワークを使用する.
Pooling は 3 種類あり、

  • CLS トークンを利用
  • ベクトルの出力の平均を利用(mean strategy / mean pooling)
  • 最大値を利用 (max strategy / max pooling)



2 つの文章が入力とされ、2 つのembedding された ベクトル情報を利用する.



再学習(ファインチューニング)する目的や学習データ、タスクに合わせて 2 つの損失 / 目的関数、ネットワークアーキテクチャを構築.

分類問題 + Siamese Network

Siamese Network で2 つのベクトルの差分を計算し、そのベクトルから分類確率を計算.



損失 / 目的関数は

 o = softmax( W_t (|u - v|))

回帰問題

2 つの文章ベクトルの類似度を Cosine Similarity (コサイン類似度)で計算. その類似度を回帰問題として予測する.



損失 / 目的関数はmean squared error loss.

Triplet Network

Anchor Sentence, Positive Sentence(Anchor Sentence と意味が同じ文章), Negative Sentence (Anchor Sentence と意味が反対 / 異なる文章) の 3 が入力するネットワークアーキテクチャを構築.

損失 / 目的関数は、

 max(|| s_a - s_p || - || s_a - s_n|| + \epsilon, 0)

参考