こんな方におすすめ
深層学習・ディープラーニングでも必要な「リカレントニューラルネットワーク (Recurrent Neural Network : RNN)」の実装について知りたい
この記事では、リカレントニューラルネットワーク (Recurrent Neural Network : RNN)の実装内容を扱います
アルゴリズムについては、こちらの記事にまとめてます yhayato1320.hatenablog.com
キーワード・知ってると理解がしやすい
- RNN アルゴリズム
目次
環境
- Google Colaboratory を利用
- 今回は、numpy のみを利用
RNN レイヤの実装
まずは、RNN の1ステップの処理を行う RNN class の実装です.
順伝播 / Forward
RNN の順伝播は以下の式です
これを元に、forward
function を実装します
計算グラフ
逆伝播の計算のために、上のRNNの順伝播の計算グラフを考えます
逆伝播 / Backward
計算グラフから、逆伝播の計算グラフを考えます
これを元に backward
function を実装します
スクリプト
パラメータの更新が確認できます
Time RNN レイヤの実装
先ほどは、1 ステップ分の処理の実装でしたが、Time RNN レイヤは、T個のRNN レイヤから構成されます
また、隠れ状態を保持するかも決定できるような実装にします
順伝播 / Forward
RNN レイヤを複数作成して、パラメータを設置. 各RNN レイヤの forward
function を実行します.
最後に隠れ状態を出力します
逆伝播 / Backward
保持している各RNN レイヤの backward
function を実行します.
各RNN レイヤ間の逆伝播はこのようになっております
スクリプト
まとめ
- RNN レイヤと、それらを時系列分のレイヤを持つ、Time RNN を実装
- RNN class は、パラメータと順伝播、逆伝播の処理を持ち、RNN Time class はそれらの処理を取りまとめるように実装
参考
- ゼロから作るDeep Learning 2
- 5 リカレントニューラルネットワーク