前回は「バグに詰まっております。」という感じで、中途半端なところで記事が終わってしまいました。
今回は完成編です。
今回のこのソースコードはgithubに上げておりますので全体像を見たい方、使ってみたい方はこちらで確認してください。
学習の流れ
- MNIST から画像・ラベル(正解)を取得
- 学習
- 学習で得たパラメーターたちを適当なフォルダに保存
- 学習したデータでテスト
実装
前回で実装はかなり終わっていました。(ほんとはアドベントカレンダーに向けての記事 1 つで終わらせる予定だったので。。)
以下は、唯一紹介していなかった推論に用いる関数です。
# maxのindexを返す
def inference6(a1, a2, a3, b1, b2, b3, x) do
y1 = fc(50, 784, x, a1, b1) |> relu(50)
y2 = fc(100, 50, y1, a2, b2) |> relu(100)
y = fc(10, 100, y2, a3, b3) |> softmax(10)
y
|> List.flatten()
|> Enum.with_index()
|> Enum.sort_by(&(elem(&1, 0)), &>=/2)
|> Enum.at(0)
|> elem(1)
end
この関数では最終的に y にそれぞれの数字である確率が入ります。 それを後半の処理で予想結果の数字として返しています。
性能(正答率)
今回僕が行った学習は
#今回の学習
epoch = 1
train_count = 2000
#大学の授業の最終課題との比較
epoch = 10
train_count = 60000
です。 なのでクソ少ないです。 理由は単純に時間がかかりすぎるためです。
大学の授業の最終課題(実装: C)では学習の終了は 20 分もかからなかった記憶があります。 今回、上記のようにかなり少ない学習だったにも関わらず、終了まで3 時間近くかかります。
正答率はどうなったかというと..。
testing ...
....................................................................................................
currect percentage (%) : 39.0
39.0%です! 学習が足りないためかなり低く感じるかもしれないですが、当てずっぽうよりはかなり当たっています。
また、大学の授業の最終課題(実装: C)epoch = 1
の段階で正答率が 30%ほどなのでほぼ同じ性能を出していると言えます。
大学の授業の最終課題では epoch = 10
では 80%近くまで正答率が上がるため、同じだけ学習すれば、同様程度まで性能を出せるものと予想できます。
(番外編)バグの原因はなんだったの?
バグの原因となっていた部分を見ていきます。
まず、誤差逆伝播(Softmax 層)の関数 softmaxwithlossbwd です。
Enum.withindex()はそれぞれの要素を {elm, index}
という風に index をつけて返してくれる関数です。
しかし、なぜか {elm, index}
で受け取った後、処理を行い、{index, elm}
という風に逆にして返していました。凡ミス〜
def softmaxwithloss_bwd(m, y, t) do
y = List.flatten(y)
[0,0,0,0,0,0,0,0,0,0]
|> Enum.with_index(0)
#- |> Enum.map(fn {ans, index} -> if index == t, do: {index-1, 1}, else: {index-1, 0} end)
+ |> Enum.map(fn {ans, index} -> if index == t, do: {1, index}, else: {0, index} end)
|> Enum.map(fn {ans, index} ->
Enum.at(y, index) - ans
end)
|> Enum.chunk_every(1)
end
次に、誤差逆伝播(fc 層)の関数 fcbwd です。
はじめに、もともと takewhile という珍しいものを使っていた部分です。
ここで行いたかった処理は rem(indexx, n) == index
の要素だけに絞るという処理です。
Enum.take_while がどのような関数なのかということですが、関数の通り条件に合致する要素を取得するというものです。 これだけ聞くと Enum.filter とほとんど役割は変わりませんが、比べてみると
iex(1)> Enum.take_while([1, 2, 3, 7, 9, 1, 1], fn x -> x < 3 end)
[1, 2]
iex(2)> Enum.filter([1, 2, 3, 7, 9, 1, 1], fn x -> x < 3 end)
[1, 2, 1, 1]
Enum.filter は条件に合致する物を全て、 Enum.take_while は条件に合致するもののうち、一番初めに合致して、合致しなくなるまでの部分のみを所得します。
なので上記のような違いが出ます。
僕は Enum.take_while なんてものはこれまで使ったことはなかったのですが、なぜかここで使っていました。 今回の目的では Enum.filter が適しているのでそちらに置き換えました。
その後の部分で、シンプルに計算する部分を間違えていたためここも修正しています。
def fc_bwd(m, n, x, dEdy, a, dEda, dEdx) do
#(省略)
dEdx =
Task.async(fn ->
dEdx
|> Enum.with_index()
|> Enum.map(fn {dedx, index} ->
+ dedx = 0
a
|> Enum.with_index()
#- |> Enum.take_while(fn {aa, indexx} -> rem(indexx, n) == index end)
+ |> Enum.filter(fn {aa, indexx} -> rem(indexx, n) == index end)
#- |> Enum.map(fn {aa, indexx} -> aa * Enum.at(dEdy, index) end)
+ |> Enum.map(fn {aa, indexx} ->
+ index_for_dedy = div(indexx, n)
+ aa * Enum.at(dEdy, index_for_dedy)
+ end)
|> Enum.sum()
end)
|> Enum.chunk_every(1)
end)
dEda = Task.await(dEda, 1000000)
dEdx = Task.await(dEdx, 1000000)
{dEda, dEdb, dEdx}
end
終わりに
今回学んだこととしては、睡眠不足の状態でプログラミングしない、ということにかぎりますね。。 実装も全体的に読みにくく、また、変なところで変なバグを生み出してしまいました。
来年の Elixir アドベントカレンダーも機会があれば参加してみたいです。 ここまで読んでくださった方ありがとうございました。