オムライスの備忘録

数学・統計学・機械学習・プログラミングに関することを記す

【深層学習】リカレントニューラルネットワーク #実装編 #01

こんな方におすすめ

深層学習・ディープラーニングでも必要な「リカレントニューラルネットワーク (Recurrent Neural Network : RNN)」の実装について知りたい



この記事では、リカレントニューラルネットワーク (Recurrent Neural Network : RNN)の実装内容を扱います

アルゴリズムについては、こちらの記事にまとめてます yhayato1320.hatenablog.com



キーワード・知ってると理解がしやすい


目次

環境

  • Google Colaboratory を利用
  • 今回は、numpy のみを利用

RNN レイヤの実装

まずは、RNN の1ステップの処理を行う RNN class の実装です.

順伝播 / Forward

RNN の順伝播は以下の式です


h_t = tanh(h_{t-1}W_{h}\ +\ x_{t}W_{x}\ +\ b)


これを元に、forward function を実装します

計算グラフ

逆伝播の計算のために、上のRNNの順伝播の計算グラフを考えます

逆伝播 / Backward

計算グラフから、逆伝播の計算グラフを考えます

これを元に backward function を実装します

スクリプト

パラメータの更新が確認できます

Time RNN レイヤの実装

先ほどは、1 ステップ分の処理の実装でしたが、Time RNN レイヤは、T個のRNN レイヤから構成されます

また、隠れ状態を保持するかも決定できるような実装にします

順伝播 / Forward

RNN レイヤを複数作成して、パラメータを設置. 各RNN レイヤの forward function を実行します. 最後に隠れ状態を出力します

逆伝播 / Backward

保持している各RNN レイヤの backward function を実行します.

各RNN レイヤ間の逆伝播はこのようになっております

スクリプト

まとめ

  • RNN レイヤと、それらを時系列分のレイヤを持つ、Time RNN を実装
  • RNN class は、パラメータと順伝播、逆伝播の処理を持ち、RNN Time class はそれらの処理を取りまとめるように実装

参考