本記事ではAIを使用する時に登場するTransformerの基本的な仕組みを、筆者が学んだこと含めて纏めていきます。
今回はTransformerにおいてたびたび実行される、【Add & Norm】という処理について詳細にご紹介していきます。
今回解説する部分
今回解説する部分は、上図の赤枠内です。
上図の赤枠以外にも【Add & Norm】は存在しますが、本記事では分かりやすくするため、Multi-Head Attention後の処理のみ解説します。
Add(残差接続)
まず【Add & Norm】部分のAddとは何か、それは残差接続となります。
残差接続とは、残差ニューラルネットワークや残差ネットワーク、ResNetと呼ばれる、深層学習モデルです。
これだけでは全く意味が分からないため、より分かりやすく解説していきます。
まず残差接続は上図のように動いており、主な役割は事前に学習しておいたデータを用いるニューラルネットワーク(以下NN)から進化した仕組みです。
残差の意味について下記をご覧ください。
残差接続を理解するためにはNNの知識も必要なのですが、あまりにも本題から離れてしまうためNNの詳細は下記の記事をご覧ください。
上記の記事では、Transformer関係の記事を見ていたら頻出するDNNやCNN、RNNなどの解説もされているので、そちらも気になった方はご覧ください。
この残差接続はそんなNNから進化した1つの仕組みで、数式で表すと下記のようになります。
入力値が出力に直接影響するため、層が増えることによっておこる勾配消失の問題や元データが薄くなっていく問題を解消しています。
より概念的に理解する
残差接続とNNの違いとして勾配消失問題やデータが薄くなる問題を上げましたが、どうしてそのようなことが起こりえるのか、分かりやすく概念的にまとめていきます。
まずNNを概念として理解する場合、下図のようになります。
入力データをレイヤーが受け取り、その結果を出力していますが、入力と出力では同じように見えて違う結果が出力されています。
レイヤーの数が少なければ1番最初に入力されたデータの原型は保てますが、レイヤー数が多くなれば原型を保てなくなるのは自明です。
NNを例えるならば伝言ゲームであり、元の言葉(入力)は人(レイヤー)を介するにつれて答え(出力)がおかしくなっていく。
そんなNNと比較して、残差接続の概念として理解する場合、下図になります。
入力データをレイヤーが受け取るまでは同じで、出力結果として元データも併せて出力するのです。
残差接続を例えるならば、回答付伝言ゲームであり、自分よりも前に人達の回答を見られる伝言ゲームであるため、最終的に得られる結果(出力)は入力と大きく異なることはありません。
簡単に3回レイヤーを通す具体例を示すと、下記のようになります。
レイヤー1:
レイヤー2:
レイヤー3:
上記の数式を見れば、どれだけレイヤー数を経ても、1番最初に入力されたデータ()が出力結果に影響を及ぼし続けていることが分かります。
Norm(LayerNorm)
次に残差接続によって出力されたデータを正規化して整えていきます。
簡単に示すと正規化ですが、正しく表すとレイヤー正規化と呼ばれる処理を行っています。
レイヤー正規化の詳細については下記の記事をご覧ください。
レイヤー正規化を実行する意味を簡単に理解したいなら上記の記事がおすすめ。
しっかりとレイヤー正規化を知りたい人は、上記の記事がおすすめ。
他の正規化とどのように違うのか、なぜレイヤー正規化を使うのかを比較したい方は、上記の記事がおすすめ。
本記事ではまずはTransformerの構造や仕組みを理解することを優先するため、レイヤー正規化を一旦データの整列を都合よく正規化してくれる処理として処理します。
Add & Normを実行する意味
ここまで【Add & Norm】が、残差接続とレイヤー正規化を行っていると紹介してきましたが、ふと疑問に思いました
本記事で焦点を当てている部分は【Multi-Head Attention】の次に行われる【Add & Norm】であり、【Multi-Head Attention】内既に正規化されたデータが出力されているのに、どうしてまた正規化するのかと。
もっと言えば、どうして【Multi-Head Attention】で正規化までしたデータを再び変換する必要があるのかと。
この疑問を解決するため、下図をご覧ください。
上図の赤枠部分をよくよく見ると、【Add & Norm】に入力されるデータは、【Multi-Head Attention】から出力されるデータと、【Multi-Head Attention】に入力されるデータの2つなのです。
ここでAddつまり残差接続の特徴を思い出してみると、層が増えることによっておこる勾配消失の問題や元データが薄くなっていく問題を解消できるNNの1つです。
そして残差接続は値を変換していくとき元データを参照するため、勾配消失の問題や元データが薄くなっていく問題を解消できます。
つまり、この【Add & Norm】は1つ前の処理で変換されたデータが、元のデータとの関係を失わないよう補完していると捉えていいでしょう。
それを証明するように、上図の赤枠部分以外でも、【Add & Norm】がある場所では1つ前の処理が出力したデータと、1つ前の処理が入力されたデータの2つを受け取っています。
参考記事
Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385
上記が残差ネットワークのオリジナル論文です。
本記事の内容は以上となります。
今回の内容は非常に短いのですが、各処理後に行われる処理であるためちゃんとした理解が必要不可欠です。
ぜひ本記事でTransformerについて、私と共に学んでいきましょう。
本記事は以上です。お疲れさまでした。