XLogo

Research

Articles

Multi Head Attention 内部処理

MultiHeadAttention_Top

記事のまとめ

  1. Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V)=softmax\left( \frac {QK^T}{\sqrt {d_k}} \right) V
  2. ☝QとKの相関情報をVに負荷する

概要

Transformer_Arch2
本記事ではTransformerの内部に使用されているMulti Head Attentionの内部の層を一つずつ解説し、 なぜMulti Head Attentionが単語間の関連の理解で有用なのか説明します。
※この記事の図は2017年にGoogleにより発表された論文「"Attention is All you need"」内の図を引用しています。
※本記事ではMulti Head Attentionの中でも1入力からQ(クエリ), K(キー), V(バリュー)から作成するものに説明を限定してします。 一部のMulti Head AttentionではEncoderの出力をK, Vとして受け取り、Decoder内部からQを受け取っていますが、考え方は本記事の解説と同様です。
※本記事ではバッチサイズ(並列化数)を1と仮定します。

Transformer全体の内部構造や入出力に関しては以下の記事をご参照ください。

Transformerの内部構造と文の生成top

Transformerの内部構造と文の生成

Multi Head Attentionの内部構造

Multi_Head_Attention
Multi Head Attentionは複数(図ではhh個) のHeadと1つのConcatから構成されます。Headとは図のような, 3個のLinearと1個のScaled Dot-Product Attentionからなります。
入力Q(クエリ),K(キー),V(バリュー)はそれぞれ以下のような役割を持ちます。Q,K,Vはモデルの構造上このような役割を持っているに過ぎず、意味などを定義しているわけではありません。 また、今回解説する1つの入力からQ,K,Vを作るような場合はQとKの役割の違いはあまり考慮しなくてもよいです。 QとKの役割の違いは特に翻訳タスクに使用されるようなEncoderを使用するモデルにおいて登場します。
Multi Head Attentionのそれぞれの層は以下のような機能を持ちます。
以下では各層の入出力を詳しく説明します。

LinearLinear

Linear層ではQ,K,Vを作成します。 Q,K,Vは以下のような重み行列の線形変換で作られます。
Q=xWQQ=xW_Q
K=xWKK=xW_K
V=xWVV=xW_V
入力xxはどれも同じ入力であり、xxの行列の大きさはn×dmodeln \times d_{model}となります。nnはトークンの要素数です。 Transformerが入力を単語ごとにトークン化するモデルなら、nnは入力文の単語数に開始記号などを加えた数になります。

WQW_QWKW_Kdmodel×dkd_{model} \times d_kの行列、WVW_Vdmodel×dvd_{model} \times d_vの行列です。論文「"Attention is All you need"」ではdk=dv=dmodelhd_k=d_v=\frac {d_{model}} h(hhはHeadの数)であり 、Q,K,Vはどれも同じ大きさの行列で、大きさはn×dmodelhn \times \frac{d_{model}} hです。 これらはConcatでつなぎ合わされ、xxの元の大きさに戻ります。 headごとに重み行列が決まっており、それぞれ異なる視点で情報を判断します。

ScaledScaledDotProductDot-ProductAttentionAttention

Scaled Dot Product Attentionでは入力をQ,K,VQ, K ,Vとして出力yyは以下の式で表されます。
y=Attention(Q,K,V)=softmax(QKTdk)Vy=Attention(Q,K,V)=softmax\left( \frac {QK^T}{\sqrt {d_k}} \right) V
このブロックの操作は後で詳しく解説します。

ConcatConcat

Concatでは全てのheadの入力を結合して出力します。そのため、出力yyは以下の式のようになります。
y=Concat(head1,head2,head3,....,headh)y=Concat(head_1, head_2, head_3,...., head_h)
Concatの出力は入力を横に結合して出力します。
[ head1head2...headh]\begin {bmatrix} \ head_1 & head_2 & ... & head_h \end{bmatrix}
headひとつあたりのサイズはn×dmodelhn \times \frac{d_{model}} hであるため、出力yyのサイズはn×dmodeln \times d_{model}となります。
※ただし、dkdvd_k \ne d_vの場合、出力yyのサイズは異なります。この場合は、Concatの次のLinearで出力yyのサイズをn×dmodeln \times d_{model}にします。

以上のように、Multi Head Attentionでは入出力で行列の大きさは変わりません。

Scaled Dot-Product Attentionの構造

SDPA
Scaled Dot-Product AttentionはMulti Head Attentionの中核を担うブロックで、 上図がScaled Dot-Product Attentionの図で、これを式に直すと下の式になります。
y=Attention(Q,K,V)=softmax(QKTdk)Vy=Attention(Q,K,V)=softmax\left( \frac {QK^T}{\sqrt {d_k}} \right) V
この式がどのように単語間に関連を持たせているか、一つずつ説明していきます。

MatMul(Q,KT)MatMul(Q, K^T)

MatMulは行列の積を計算する演算です。QQKTK^Tの積に操作について詳しく見てみると、 以下の図のようになります。 図の表記ではQ[a,b]Q_{[a, b]}QQaabb列の要素を指し、*はすべての要素を指します。
MatMul(Q,KT)MatMul(Q, K^T)K[,1]TK^T_{[*, 1]}K[,2]TK^T_{[*, 2]}...K[,n]TK^T_{[*, n]}
Q[1,]Q_{[1, *]}Q[1,]K[,1]TQ_{[1, *]} K^T_{[*, 1]}Q[1,]K[,2]TQ_{[1, *]} K^T_{[*, 2]}...Q[1,]K[,n]TQ_{[1, *]} K^T_{[*, n]}
Q[2,]Q_{[2, *]}Q[2,]K[,1]TQ_{[2, *]} K^T_{[*, 1]}Q[2,]K[,2]TQ_{[2, *]} K^T_{[*, 2]}...Q[2,]K[,n]TQ_{[2, *]} K^T_{[*, n]}
...............
Q[n,]Q_{[n, *]}Q[n,]K[,1]TQ_{[n, *]} K^T_{[*, 1]}Q[n,]K[,2]TQ_{[n, *]} K^T_{[*, 2]}...Q[n,]K[,n]TQ_{[n, *]} K^T_{[*, n]}
QQn×dkn \times d_kの行列であり、KTK^Tdk×nd_k \times nなので出力サイズはn×nn \times nになります。
また、Q=xWQQ=x W_QK=xWKK=x W_Kから、
MatMul(Q,KT)=xWQWKTxTMatMul(Q, K^T)=x W_Q W_K^T x^T
となり、本質的にはWQW_Qの行とWKW_Kの行を掛け合わせる操作になります。
行というのは各トークン(≒文中の単語)を表していますので、この演算により、重み行列WQ,WKW_Q, W_Kの行の相関関係の大きさを表現することで、単語間の関係を大まかに表現出来ます。

Scale(QKT)Scale ( {QK^T} )

MatMulMatMuln×dkn \times d_kdk×nd_k \times nの演算を行っていますが、これはd_Kに比例して計算量が増加し、同時に 絶対値が非常に大きな値が発生を許してしまうことになります。これでは後のsoftmax関数で正規化する際、不安定な確率分布になってしまいます。そのため、dk\sqrt d_kを使ってスケーリングします。
y=QKTdky= \frac {QK^T} {\sqrt d_k}
スケーリングに使う値としてはdk\sqrt d_kが最も適切であることが分かっています。こちらは機会があれば記事にします。

MaskMask

DecoderのMulti Head Attentionではトークンを入力して次のトークンを予測します。この時、次のトークン以降の情報を 得られないようにするため、上三角行列を- \inftyに設定します。- \inftyに設定することで後のsoftmax関数で正規化した際に確率として0に表現できるようになります。 以下に4×4の例を示します。
MaskK[,1]TK^T_{[*, 1]}K[,2]TK^T_{[*, 2]}K[,3]TK^T_{[*, 3]}K[,4]TK^T_{[*, 4]}
Q[1,]Q_{[1, *]}1.11.1-\infty-\infty-\infty
Q[2,]Q_{[2, *]}1.41.40.7-0.7-\infty-\infty
Q[3,]Q_{[3, *]}2.1-2.11.01.00.80.8-\infty
Q[4,]Q_{[4, *]}0.90.92.92.93.33.31.41.4
※表中の数字はダミーです このように上三角行列をすべて-\inftyにすることで、クエリは過去のキーは参照することが出来ますが、未来のキーは参照できなくなります。 この操作がどのような意味を持つのかはバリューをかける操作で分かります。

softmax(QKTdk)softmax \left( \frac {QK^T} {\sqrt d_k} \right)

softmaxは行列の各行の要素の値を0から1に補正し、全ての要素の和が1になるように計算します。
softmaxK[,1]TK^T_{[*, 1]}K[,2]TK^T_{[*, 2]}K[,3]TK^T_{[*, 3]}K[,4]TK^T_{[*, 4]}
Q[1,]Q_{[1, *]}1.001.000.000.000.000.000.000.00
Q[2,]Q_{[2, *]}0.670.670.330.330.000.000.000.00
Q[3,]Q_{[3, *]}0.020.020.540.540.440.440.000.00
Q[4,]Q_{[4, *]}0.050.050.350.350.520.520.080.08
以上のように、確率的にクエリとキーの関係を表現することが出来ます。

MatMul(softmax(QKTdk),V)MatMul \left( softmax \left( \frac {QK^T} {\sqrt d_k} \right) , V \right)

最後にsoftmaxした注意量とVV(バリュー)の行列積をとります。積を行うことでVVに情報量を付加することが出来ます。以下にVVの各トークンのベクトルをV1,V2,V3,V4V_1, V_2, V_3, V_4とした場合の例を示します。
[1.000.000.000.000.670.330.000.000.020.540.440.000.050.350.520.08][V1V2V3V4]=[1.00V10.67V1+0.33V20.02V1+0.54V2+0.44V30.05V1+0.35V2+0.52V3+0.08V4]\begin{bmatrix} 1.00 & 0.00 & 0.00 & 0.00 \\ 0.67 & 0.33 & 0.00 & 0.00 \\ 0.02 & 0.54 & 0.44 & 0.00 \\ 0.05 & 0.35 & 0.52 & 0.08 \end {bmatrix} \begin{bmatrix} V_1\\V_2\\V_3\\V_4 \end{bmatrix} = \begin{bmatrix}\begin{array}{l} 1.00V_1 \\ 0.67V_1 + 0.33V_2 \\ 0.02V_1+0.54V_2+0.44V_3\\0.05V_1+0.35V_2+0.52V_3+0.08V_4 \end{array} \end {bmatrix}
例えば、2行目の0.67V1+0.33V20.67V_1+ 0.33V_2は、(トークン1とトークン2の相関×トークン1の情報)と(トークン2とトークン2の相関×トークン2の情報) の和となり、トークン2がどのようなトークンなのかを推測することが出来ます。 この辺りの「情報」に関する概念はTransformerの内部で処理されており、解明されていない点も多いです。

Multi Head Attentionの特徴

Multi Head Attentionの構造から発生する特徴は以下の通りです。

まとめ

Conclusion
Multi Head Attentionの内部処理を順に以下にまとめます。
ご愛読ありがとうございます。