オムライスの備忘録

数学・統計学・機械学習・プログラミングに関することを記す

【深層学習】Scaled Dot Product Attention

Index

Scaled Dot Product Attention とは

Attention の仕組みの中で利用されるスコア関数のひとつ.

yhayato1320.hatenablog.com

諸定義

 n 個の入力(トークン)で構成されいる時系列データ(文章) を処理することを考える.

ある層の出力
時刻  i の出力ベクトル:  x_i\ (i\ =\ 1,\ \cdots,\ n)

重みパラメータ
 W^{Q}
 W^{K}
 W^{V}

 d次元ベクトル Query、 Key、 Value
 
\begin{align}
q_{i}&\ =\ x_{i} W^{Q} \\
k_{i}&\ =\ x_{i} W^{K} \\
v_{i}&\ =\ x_{i} W^{V}
\end{align}

アルゴリズム

それぞれのトークンは、これら三つのベクトル(Query / Key / Value) により特徴付られる. (新たな特徴量として、それぞれの用途で使われる.)

出力の計算

これらのベクトルの入力として、それぞれのトークンに対して、ベクトル  a_{i} を出力する.

 a_{i}\ =\ \displaystyle \sum_{j=1}^{n} \alpha_{i,\ j} v_{j}


ここで、 \alpha \leq 0,\ \displaystyle \sum_{j} \alpha_{i,\ j}\ =\ 1
つまり、 a_{i}Value ベクトルの重み付き平均


重み  a_{i,\ j} は、 i 番目のトークンを処理する際に、  j 番目のトークンの重要度を表す.

重要度の計算

重み  \alpha_{i,\ j} の計算は、Key と Query から決まる.


一般的には、 i 番目のトークンに対応する Query  q_{i} j 番目のトークンに対応する Key  k_{j} との関連度をなんらかの方法で評価し、 それに応じて重みを決める


つまり、 i 番目のトークンを処理するときには、その Queryと関連度(重要度)が大きい Key を持つトークンほど、影響を大きく与える  \alpha_{i,\ j} が利用される.



Scaled Dot Product と呼ばれる方法(スコア関数)で、Query と Key のスコア (トークン同士の関連度 / 重要度)を計算する.

 \tilde{\alpha}_{i,\ j}\ =\ \displaystyle \frac{q_{i}\ \cdot\ k_{j}}{\sqrt{d}}


分子は、Query と Key の内積 (Dot Product)であり、ベクトル同士の類似度のような役割を果たす. (似通っているベクトル同士であれば、大きい値を出してくれる.)
分母は、ベクトルの次元数  d が大きくなってしますと、内積の値も大きくなってしますので、  \sqrt{d} で調整している.


Softmax 関数で正規化して、重み  \alpha_{i,\ j} が完成.

スコア関数


\begin{align}
\tilde{\alpha}_{i,\ j}&\ =\ \displaystyle \frac{q_{i}\ \cdot\ k_{j}}{\sqrt{d}} \\
\alpha_{i,\ j}&\ =\ Softmax(\tilde{\alpha}_{i,\ j})
\end{align}



計算時の効率化

異なるトークンに対する出力は別々にするのではなく、一つの行列演算で効率よく計算することができる.

Attentionへの入力、Query、Key、Value、Attentionからの出力を以下のようにする.


\begin{align}
X&\ =\ (x_{1},\ \cdots,\ x_{n}) \\
\\
Q&\ =\ (q_{1},\ \cdots,\ q_{n}) \\
K&\ =\ (k_{1},\ \cdots,\ k_{n}) \\
V&\ =\ (v_{1},\ \cdots,\ v_{n}) \\
\\
A&\ =\ (a_{1},\ \cdots,\ a_{n}) \\
\end{align}


出力 A は以下のように表現でき、計算を並列化できる.

 
\begin{align}
Q&\ =\ XW^{Q} \\
K&\ =\ XW^{K} \\
V&\ =\ XW^{V} \\
\end{align}
 A\ =\ Attention(Q,\ K,\ V)\ =\ Softmax \left( \displaystyle \frac{Q\ K^{T}}{\sqrt{d}} \right) V

参考