いぬおさんのおもしろ数学実験室

おいしい紅茶でも飲みながら数学、物理、工学、プログラミング、そして読書を楽しみましょう

バッチ正規化の逆伝播の式を導く

バッチ正規化の話です。前回書いたとおり、画像認識のネットワークでデータを処理するとき、「各層で、特定のいくつかのユニットだけが値が大きくなる」という現象が起こります(「アクティべーションの分布が偏る」)。出力層でも「2」ばかりが出力されたりするのです。各層にはたくさんのユニットがあるのですから、もっと均等になるのが望ましそうです。そこで、ある層で無理矢理ユニットの値を平均と分散を使って標準化(正規化)し、データの偏りを減らしてやるのです。これをバッチ正規化と言います。逆伝播の仕組みを入れるのが面倒になりますが、今回はその式を導きます。

 

ぼくはC#で実装し、ユニットの誤差を偏微分の定義に基づいた計算と、これから示す逆伝播の理屈から求め、ほぼ一致することを確認しました(「勾配確認」と言っていいか分かりませんが、ここではそう呼びましょう)。以下に示すとおり、式自体は手間をかければ求まるのですが、大変だったのは実装です。ぼくがハマったのはバッチ正規化を入れた場合の誤差の扱いでした。今考えれば当たり前なのですがバッチ内のデータすべての誤差の総和を小さくする、と考えなければならなかったのでした。式が間違っているのか、式が正しくてもコーディングがまずいのかと考えて、もう10回くらいずつ丁寧に点検しました。間違いなさそうなのに勾配確認では一致しない。数日間、悩みに悩んで寝ようとしているときにボンヤリ考えていて、「あっ!!」と気がつきました。これもよい経験です。あとは実際にMNISTやCIFAR-10で認識率がどこまでいくか試すことになります。

 

ミニバッチのサイズをNとし、0≦n≦Nとします。nはミニバッチ内のデータ(サンプル)の番号です。n番のデータをネットワークに流したとき、第L層のj番目のユニットの値を次のように表します。

ミニバッチ全体(N個の画像データ)をネットワークに流したときの誤差をJとおきます。Jは、2乗誤差を使うなら次の通り。

Hは出力層(第K層)のユニット数。また例えばミニバッチ内の5番目のデータのラベル(正解)は2だとすると、

(One-Hot表現)

標準化に使う平均値と分散は、次の式で求めます。ここではユニットの番号jを固定し、N個のデータを相手にしています。

『深層学習(改訂第2版)』によると、平均や分散を取る範囲を変えると別の正規化の方法が導出される、とあります。例えばn, jを動かして平均を取る方法もあるということです。

標準化は次の式で行います。

ここで

です。値が0に近くなったとき計算を安定させるために正の小さな数ε(=0.0000001とか)を使っています。

γ、βはN個のデータを流す間に学習によって値を更新しますが、ここではγ=1、β=0(定数)としておきます。今回の目的はバッチ正規化のロジックの検討で、これがうまく行ってから考えればよいと思います(γ、βの更新の理屈は易しいです)。

 

では逆伝播の式を導きます。微分法の連鎖律によると

Σの変数は、混乱しないよう適宜変更して書きます。

m=nのときのみ1、それ以外では0となります。これを用いて、例えば以下のようになるのです。

なお、次の式も使っています。

同じ変数なら微分すれば1、そうでなければ0だからこうなるのです。

これらを使うと

以上から、

というわけで、お疲れ様!! これで逆伝播の式が導けました!!