Backpropagation (1)
By nob. Category: Posts.
Mathematics
Derivative of a reciprocal function
$$
\biggl\{ \frac{1}{f(x)} \biggr\}' = - \frac{f'(x)}{\{f(x)\}^2}
$$
Derivative of the exponential function
$$
(e^{x})' = e^{x}
$$
Derivative of a composite function (chain rule)
When \(y = f(u)\) and \(u = g(x)\),
$$
\frac{dy}{dx} = \frac{dy}{du}\frac{du}{dx}
$$
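These three rules are all we need below. As a quick sanity check, here is a minimal SymPy sketch verifying the reciprocal rule and the chain rule on the sigmoid-style expressions used in this post (SymPy is an assumption here; the post itself only uses NumPy):

import sympy as sp

x = sp.symbols('x')
f = 1 + sp.exp(-x)

# reciprocal rule: (1/f)' == -f'/f**2
lhs = sp.diff(1 / f, x)
rhs = -sp.diff(f, x) / f**2
print(sp.simplify(lhs - rhs))  # 0

# chain rule: d/dx exp(-x) via y = exp(u), u = -x
u = sp.symbols('u')
dy_du = sp.diff(sp.exp(u), u).subs(u, -x)  # dy/du evaluated at u = -x
du_dx = sp.diff(-x, x)                     # du/dx = -1
print(sp.simplify(dy_du * du_dx - sp.diff(sp.exp(-x), x)))  # 0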
Neural network relations
Activation function
$$
a(z) = \frac{1}{1+e^{-z}}
$$
$$
\begin{align}
a'(z) &= -\frac{(1+e^{-z})'}{(1+e^{-z})^2} \\
&= \frac{e^{-z}}{(1+e^{-z})^2} \\
&= \frac{1+e^{-z}-1}{(1+e^{-z})^2} \\
&= \frac{1}{1+e^{-z}}\left(1 - \frac{1}{1+e^{-z}}\right) \\
&= a(z)(1-a(z))
\end{align}
$$
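The last line, \(a'(z) = a(z)(1-a(z))\), is exactly the identity the deactivate helper in the code below relies on. A minimal finite-difference check:

import numpy as np

def a(z):
    return 1 / (1 + np.exp(-z))

z = np.linspace(-5, 5, 11)
h = 1e-6
numeric = (a(z + h) - a(z - h)) / (2 * h)  # central-difference approximation of a'(z)
analytic = a(z) * (1 - a(z))               # the closed form derived above
print(np.max(np.abs(numeric - analytic)))  # agreement to roughly 1e-10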
Input layer
\(4 \times 3 = 12\) pixels, monochrome with two gray levels (binary).
$$
a_{i}^{1} = x_{i} \; \; \; (i = 1, \dots, 12)
$$
Hidden layer
$$
\begin{align}
z_1^{2} &= w_{1,1}^{2} x_{1} + \cdots + w_{1,12}^{2} x_{12} + b_{1}^{2} \\
z_2^{2} &= w_{2,1}^{2} x_{1} + \cdots + w_{2,12}^{2} x_{12} + b_{2}^{2} \\
z_3^{2} &= w_{3,1}^{2} x_{1} + \cdots + w_{3,12}^{2} x_{12} + b_{3}^{2} \\
a_1^{2} &= a(z_{1}^{2}) \\
a_2^{2} &= a(z_{2}^{2}) \\
a_3^{2} &= a(z_{3}^{2})
\end{align}
$$
Output layer
$$
\begin{align}
z_1^{3} &= w_{1,1}^3 a_1^{2} + w_{1,2}^3 a_2^{2} + w_{1,3}^3 a_3^{2}+ b_1^3 \\
z_2^{3} &= w_{2,1}^3 a_1^{2} + w_{2,2}^3 a_2^{2} + w_{2,3}^3 a_3^{2} +b_2^3 \\
a_1^{3} &= a(z_1^{3}) \\
a_2^{3} &= a(z_2^{3})
\end{align}
$$
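Taken together, the hidden and output layers are two affine maps, each followed by the sigmoid. A minimal NumPy sketch of this forward pass (the weights here are random placeholders, not the trained values used later):

import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=12).astype(float)  # one 4x3 binary image, flattened
w2, b2 = rng.normal(size=(3, 12)), rng.normal(size=3)
w3, b3 = rng.normal(size=(2, 3)), rng.normal(size=2)

def a(z):
    return 1 / (1 + np.exp(-z))

a2 = a(w2 @ x + b2)   # hidden layer: z_j^2 = sum_i w_{j,i}^2 x_i + b_j^2
a3 = a(w3 @ a2 + b3)  # output layer: z_j^3 = sum_i w_{j,i}^3 a_i^2 + b_j^3
print(a3)             # two outputs in (0, 1)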
Squared error for the whole network
The factor \(1/2\) is included to simplify the derivative calculations.
$$
C = \frac{1}{2}\left\{ (t_1 - a_1^{3})^2 + (t_2 - a_2^{3})^2 \right\}
$$
Cost function
Since there are 64 training images, the cost function \(C_\Gamma\) is:
$$
C_\Gamma = \sum_{i=1}^{64} C_i
$$
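In code, \(C_\Gamma\) is just the per-example squared errors summed over the training set. A minimal sketch with dummy values (the cost helper matches the one defined in the program below):

import numpy as np

def cost(t, a):
    return ((t - a) ** 2 / 2).sum()  # per-example squared error C

T = np.array([[1, 0], [0, 1]])          # two dummy labels
A = np.array([[0.9, 0.2], [0.1, 0.7]])  # corresponding network outputs
print(sum(cost(t, a) for t, a in zip(T, A)))  # C_Gamma = C_1 + C_2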
Finding the parameters
Hidden layer
$$
\begin{align}
\frac{\partial C_\Gamma}{\partial w_{1,1}^{2}} &= 0, \dots ,
\frac{\partial C_\Gamma}{\partial w_{1,12}^{2}} = 0,
\frac{\partial C_\Gamma}{\partial b_{1}^{2}} = 0 \\
\frac{\partial C_\Gamma}{\partial w_{2,1}^{2}} &= 0, \dots ,
\frac{\partial C_\Gamma}{\partial w_{2,12}^{2}} = 0,
\frac{\partial C_\Gamma}{\partial b_{2}^{2}} = 0 \\
\frac{\partial C_\Gamma}{\partial w_{3,1}^{2}} &= 0, \dots ,
\frac{\partial C_\Gamma}{\partial w_{3,12}^{2}} = 0,
\frac{\partial C_\Gamma}{\partial b_{3}^{2}} = 0
\end{align}
$$
Output layer
$$
\begin{align}
\frac{\partial C_\Gamma}{\partial w_{1,1}^{3}} &= 0,
\frac{\partial C_\Gamma}{\partial w_{1,2}^{3}} = 0,
\frac{\partial C_\Gamma}{\partial w_{1,3}^{3}} = 0,
\frac{\partial C_\Gamma}{\partial b_{1}^{3}} = 0 \\
\frac{\partial C_\Gamma}{\partial w_{2,1}^{3}} &= 0,
\frac{\partial C_\Gamma}{\partial w_{2,2}^{3}} = 0,
\frac{\partial C_\Gamma}{\partial w_{2,3}^{3}} = 0,
\frac{\partial C_\Gamma}{\partial b_{2}^{3}} = 0
\end{align}
$$
Applying gradient descent
$$
(\varDelta w_{1,1}^{2}, \dots, \varDelta b_{1}^{2}, \dots, \varDelta w_{1,1}^{3}, \dots, \varDelta b_{1}^{3}, \dots) =
-\eta \bigl(
\frac{\partial C_\Gamma}{\partial w_{1,1}^{2}}, \dots, \frac{\partial C_\Gamma}{\partial b_{1}^{2}}, \dots,
\frac{\partial C_\Gamma}{\partial w_{1,1}^{3}}, \dots, \frac{\partial C_\Gamma}{\partial b_{1}^{3}}, \dots
\bigr)
$$
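Each step moves every parameter against its partial derivative, scaled by the learning rate \(\eta\). A minimal sketch of one update on a single parameter array (the gradient values here are placeholders):

import numpy as np

eta = 0.2                                       # learning rate eta
w = np.array([[0.5, -0.3], [0.1, 0.8]])         # some parameter array
grad = np.array([[0.05, -0.02], [0.00, 0.10]])  # placeholder for dC_Gamma/dw
w = w - eta * grad                              # Delta w = -eta * dC_Gamma/dw
print(w)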
Defining the unit error
Definition
$$
\delta_{j}^{l} = \frac{\partial C}{\partial z_{j}^{l}} \; \; \; (l=2, 3)
$$
Computing \(\frac{\partial C}{\partial w_{1,1}^{2}}\)
$$
\begin{align}
\frac{\partial C}{\partial w_{1,1}^{2}} &= \frac{\partial C}{\partial z_{1}^{2}}\frac{\partial z_{1}^{2}}{\partial w_{1,1}^{2}} \\
&= \frac{\partial C}{\partial z_{1}^{2}} x_{1} \\
&= \delta_{1}^{2} x_{1} \\
&= \delta_{1}^2 a_{1}^{1}
\end{align}
$$
Computing \(\frac{\partial C}{\partial w_{1,1}^{3}}\)
$$
\begin{align}
\frac{\partial C}{\partial w_{1,1}^{3}} &= \frac{\partial C}{\partial z_{1}^{3}}\frac{\partial z_{1}^{3}}{\partial w_{1,1}^{3}} \\
&= \frac{\partial C}{\partial z_{1}^{3}} a_{1}^{2} \\
&= \delta_{1}^3 a_{1}^{2}
\end{align}
$$
Expressing the weight gradients in terms of the unit errors
$$
\frac{\partial C}{\partial w_{j,i}^{l}} = \delta_{j}^{l}a_{i}^{l-1}
$$
Expressing the bias gradients in terms of the unit errors
$$
\frac{\partial C}{\partial b_{j}^{l}} = \delta_{j}^{l} \;\;\;(l=2,3)
$$
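These two formulas map directly onto NumPy: the weight gradient is the outer product of a layer's unit errors with the previous layer's activations, and the bias gradient is the unit-error vector itself. A minimal sketch with illustrative values, mirroring the d3[:, np.newaxis] * a2 lines in the training loop below:

import numpy as np

delta = np.array([0.1, -0.2])          # unit errors delta_j^l of a layer with 2 units
a_prev = np.array([0.5, 0.3, 0.9])     # activations a_i^{l-1} of the previous layer

dC_dw = delta[:, np.newaxis] * a_prev  # dC/dw_{j,i}^l = delta_j^l a_i^{l-1}, shape (2, 3)
dC_db = delta                          # dC/db_j^l = delta_j^l
print(dC_dw)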
Output-layer error
$$
\begin{align}
\delta_{j}^{3} &= \frac{\partial C}{\partial z_{j}^{3}} \\
&= \frac{\partial C}{\partial a_{j}^{3}} \frac{\partial a_{j}^{3}}{\partial z_{j}^{3}} \\
&= \frac{\partial C}{\partial a_{j}^{3}} a'(z_{j}^{3})
\end{align}
$$
Computing \(\delta_{1}^{3}\)
$$
\begin{align}
\delta_{1}^{3} &= \frac{\partial C}{\partial a_{1}^{3}} a'(z_{1}^{3}) \\
&= (t_1 - a_{1}^{3})(-1)\, a'(z_{1}^{3}) \\
&= (a_{1}^{3} - t_1)\, a'(z_{1}^{3})
\end{align}
$$
Hidden-layer error
$$
\begin{align}
\delta_{1}^{2} &= \frac{\partial C}{\partial z_{1}^{2}} \\
&= \frac{\partial C}{\partial z_{1}^{3}}\frac{\partial z_{1}^{3}}{\partial a_{1}^{2}}\frac{\partial a_{1}^{2}}{\partial z_{1}^{2}} +
\frac{\partial C}{\partial z_{2}^{3}}\frac{\partial z_{2}^{3}}{\partial a_{1}^{2}}\frac{\partial a_{1}^{2}}{\partial z_{1}^{2}} \\
&= (\delta_{1}^{3}w_{1,1}^{3} + \delta_{2}^{3}w_{2,1}^{3}) a'(z_{1}^{2}) \\
&= (\delta_{1}^{3}, \delta_{2}^{3}) \cdot (w_{1,1}^{3}, w_{2,1}^{3})\, a'(z_{1}^{2})
\end{align}
$$
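Before trusting these backpropagation formulas, it is worth verifying them against numerical differentiation. A minimal self-contained sketch on a tiny 2-3-2 network (sizes chosen arbitrarily for the check; this is not the 12-3-2 network used below):

import numpy as np

rng = np.random.default_rng(0)
x = rng.random(2)                        # a tiny input
t = np.array([1.0, 0.0])                 # a tiny target
w2, b2 = rng.normal(size=(3, 2)), rng.normal(size=3)
w3, b3 = rng.normal(size=(2, 3)), rng.normal(size=2)

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def C(w2, b2, w3, b3):
    a2 = sigmoid(w2 @ x + b2)
    a3 = sigmoid(w3 @ a2 + b3)
    return ((t - a3) ** 2 / 2).sum()

# backpropagation gradient for one weight: dC/dw2[0, 0] = delta_1^2 * x_1
a2 = sigmoid(w2 @ x + b2)
a3 = sigmoid(w3 @ a2 + b3)
d3 = (a3 - t) * a3 * (1 - a3)            # delta^3 = (a^3 - t) a'(z^3)
d2 = (w3.T @ d3) * a2 * (1 - a2)         # delta^2 = (W^3)^T delta^3 * a'(z^2)
backprop = d2[0] * x[0]

# central-difference approximation of the same derivative
h = 1e-6
wp, wm = w2.copy(), w2.copy()
wp[0, 0] += h
wm[0, 0] -= h
numeric = (C(wp, b2, w3, b3) - C(wm, b2, w3, b3)) / (2 * h)
print(backprop, numeric)                 # the two values agree to ~1e-9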
Classifying 4x3-pixel, monochrome two-level images
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-1 * x))

def activate(z):
    return sigmoid(z)

def deactivate(a):
    # derivative of the sigmoid expressed via its output: a'(z) = a(z)(1 - a(z))
    return a * (1 - a)

def sigma(w, x, b):
    # weighted input z = Wx + b
    return np.dot(w, x) + b

def cost(t, x):
    # squared error C = (1/2) * sum_j (t_j - x_j)^2
    return ((t - x) ** 2 / 2).sum()
X = np.array(
[
[1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
[0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
[1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0],
[1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1],
[1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0],
[0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0],
[1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1],
[1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
[0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1],
[0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1],
[0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1],
[0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1],
[1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1],
[1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1],
[1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1],
[1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1],
[0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
[1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1],
[1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
[1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1],
[1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1],
[0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0],
[1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0],
[1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0],
[0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1],
[1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1],
[1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
[0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1],
[1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0],
[1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0],
[1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0],
[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0],
[1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
[1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1],
[0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
[0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0],
[0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
]
)
T = np.array(
[
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[0, 1],
]
)
np.set_printoptions(precision=3)
rng = np.random.default_rng(123)
# input size: 12, hidden layer units: 3
w2 = np.array(
[
[0.49038557898997, 0.348475676796501, 0.0725879008695083, 0.837472826850604, -0.0706798311519743, -3.6169369170322, -0.53557819719488, -0.0228584789393108, -1.71745249082217, -1.45563751579807, -0.555799932254451, 0.852476539980059],
[0.442372911956926, -0.536877487857221, 1.00782536916829, 1.07196001297575, -0.732814485632708, 0.822959617857012, -0.453282364154155, -0.0138979392949318, -0.0274233258563056, -0.426670298661898, 1.87560275441379, -2.30528048189891],
[0.654393041569443, -1.38856820257739, 1.24648311661583, 0.0572877158406771, -0.183237472237546, -0.74305066513479, -0.460930664925325, 0.331118557255208, 0.449470835925128, -1.29645372413246, 1.56850561324256, -0.470667153317658],
]
) # fmt: skip
b2 = np.array([-0.185002, 0.525677, -1.168623])
# input size: 3 (= hidden layer units), output layer units: 2
w3 = np.array(
[
[0.3880031194962, 0.803384989025837, 0.0292864334994403],
[0.0254467679708455, -0.790397993881956, 1.55313793058729],
]
)
b3 = np.array([-1.438040, -1.379338])
dw2 = np.zeros_like(w2)
dw3 = np.zeros_like(w3)
db2 = np.zeros_like(b2)
db3 = np.zeros_like(b3)
costs = []
learning_rate = 0.2
epochs = 50
debug = True
for epoch in range(epochs):
    c = 0
    for i, (x, t) in enumerate(zip(X, T)):
        # forward pass: hidden layer (layer 2)
        z2 = sigma(w2, x, b2)
        a2 = activate(z2)
        # forward pass: output layer (layer 3)
        z3 = sigma(w3, a2, b3)
        a3 = activate(z3)
        # accumulate the cost over all training examples
        c = c + cost(t, a3)
        # backpropagation: accumulate the gradients
        d3 = (a3 - t) * deactivate(a3)
        dw3 = dw3 + d3[:, np.newaxis] * a2
        db3 = db3 + d3
        d2 = np.dot(d3, w3) * deactivate(a2)
        dw2 = dw2 + d2[:, np.newaxis] * x
        db2 = db2 + d2
    costs.append(c)
    # gradient-descent update, then reset the accumulators
    w2 = w2 - learning_rate * dw2
    w3 = w3 - learning_rate * dw3
    b2 = b2 - learning_rate * db2
    b3 = b3 - learning_rate * db3
    dw2.fill(0)
    dw3.fill(0)
    db2.fill(0)
    db3.fill(0)
    if debug:
        print("=" * 100)
        print("[", epoch, "]", "cost=", c)
        print("=" * 100)
        print("layer 2 (weight)")
        print(w2.reshape(3, 4, 3))
        print("layer 2 (bias)")
        print(b2)
        print("layer 3 (weight)")
        print(w3.reshape(2, 3))
        print("layer 3 (bias)")
        print(b3)
====================================================================================================
[ 0 ] cost= 20.255347730696887
====================================================================================================
layer 2 (weight)
[[[ 0.482  0.335  0.077]
  [ 0.84  -0.091 -3.614]
  [-0.533 -0.041 -1.713]
  [-1.456 -0.572  0.855]]
 [[ 0.446 -0.575  1.067]
  [ 1.168 -0.851  0.902]
  [-0.346 -0.143  0.055]
  [-0.369  1.838 -2.226]]
 [[ 0.753 -1.23   1.239]
  [ 0.054  0.009 -0.726]
  [-0.464  0.516  0.475]
  [-1.273  1.746 -0.438]]]
layer 2 (bias)
[-0.201  0.501 -0.982]
layer 3 (weight)
[[ 0.28   1.191  0.056]
 [ 0.257 -0.369  1.759]]
layer 3 (bias)
[-0.94  -0.727]

(debug output for epochs 1 through 48 omitted; the cost decreases monotonically: 14.43 after epoch 1, 1.03 after epoch 20, 0.35 after epoch 40, 0.25 after epoch 48)

====================================================================================================
[ 49 ] cost= 0.24523190314535037
====================================================================================================
layer 2 (weight)
[[[ 0.44   0.792 -0.116]
  [ 0.859  0.303 -3.7  ]
  [-0.485  0.317 -1.941]
  [-1.433 -0.118  0.651]]
 [[ 0.638 -2.057  1.521]
  [ 1.848 -1.628  1.155]
  [ 0.786 -1.375  0.777]
  [ 0.488  0.562 -0.818]]
 [[-0.112  0.458 -0.116]
  [-1.051  0.869 -1.304]
  [-1.345  1.31  -0.52 ]
  [-1.75   3.228 -1.616]]]
layer 2 (bias)
[ 0.251 -0.382  0.042]
layer 3 (weight)
[[-1.314  3.59  -3.055]
 [ 1.449 -2.42   4.068]]
layer 3 (bias)
[-0.326 -0.945]
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot()
ax.plot(costs, marker="o", markersize=5)
ax.grid()
plt.minorticks_on()
plt.show()
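After the loop finishes, a quick sanity check of the fit (a sketch reusing the trained w2, b2, w3, b3 and the X, T arrays from above; it counts how many of the 64 images land on the correct output unit):

correct = 0
for x, t in zip(X, T):
    a3 = activate(sigma(w3, activate(sigma(w2, x, b2)), b3))
    correct += int(np.argmax(a3) == np.argmax(t))
print(correct, "/", len(X))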
Classifying XOR
rng = np.random.default_rng(123)
X = np.array(
[
[0, 0],
[0, 1],
[1, 0],
[1, 1],
]
)
T = np.array(
[
[0],
[1],
[1],
[0],
]
)
w2 = rng.random(size=(2, 2)) - 0.5
b2 = rng.random(2) - 0.5
w3 = rng.random(size=(1, 2)) - 0.5
b3 = rng.random(1) - 0.5
dw2 = np.zeros_like(w2)
dw3 = np.zeros_like(w3)
db2 = np.zeros_like(b2)
db3 = np.zeros_like(b3)
costs = []
learning_rate = 0.5
epochs = 3000
debug = False
for epoch in range(epochs):
    c = 0
    for i, (x, t) in enumerate(zip(X, T)):
        # forward pass: hidden layer (layer 2)
        z2 = sigma(w2, x, b2)
        a2 = activate(z2)
        # forward pass: output layer (layer 3)
        z3 = sigma(w3, a2, b3)
        a3 = activate(z3)
        # accumulate the cost over all training examples
        c = c + cost(t, a3)
        # backpropagation: accumulate the gradients
        d3 = (a3 - t) * deactivate(a3)
        dw3 = dw3 + d3[:, np.newaxis] * a2
        db3 = db3 + d3
        d2 = np.dot(d3, w3) * deactivate(a2)
        dw2 = dw2 + d2[:, np.newaxis] * x
        db2 = db2 + d2
    costs.append(c)
    # gradient-descent update, then reset the accumulators
    w2 = w2 - learning_rate * dw2
    w3 = w3 - learning_rate * dw3
    b2 = b2 - learning_rate * db2
    b3 = b3 - learning_rate * db3
    dw2.fill(0)
    dw3.fill(0)
    db2.fill(0)
    db3.fill(0)
    if debug:
        print("=" * 100)
        print("[", epoch, "]", "cost=", c)
        print("=" * 100)
        print("layer 2 (weight)")
        print(w2)
        print("layer 2 (bias)")
        print(b2)
        print("layer 3 (weight)")
        print(w3)
        print("layer 3 (bias)")
        print(b3)
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot()
ax.plot(costs, marker="o", markersize=2)
ax.grid()
plt.minorticks_on()
plt.show()
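Finally, a truth-table check of the trained XOR network (again a sketch reusing the parameters left by the loop above):

for x, t in zip(X, T):
    a3 = activate(sigma(w3, activate(sigma(w2, x, b2)), b3))
    print(x, "->", np.round(a3, 3), "target", t)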
Source of the training data
- 涌井良幸, 涌井貞美. ディープラーニングがわかる数学入門. 技術評論社, April 2017. ISBN 978-4-7741-8814-0.