Backpropagation (1)

Published:

By nob

Category: Posts

Tags: machine learning, neural networks, ディープラーニングがわかる数学入門, backpropagation, Python

Mathematics

Derivative of a reciprocal function

$$ \biggl\{ \frac{1}{f(x)} \biggr\}' = - \frac{f'(x)}{\{f(x)\}^2} $$

Derivative of the exponential function

$$ (e^{x})' = e^{x} $$

Derivative of a composite function (chain rule)

When \(y = f(u)\) and \(u = g(x)\),

$$ \frac{dy}{dx} = \frac{dy}{du}\frac{du}{dx} $$
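For example, taking \(y = e^{u}\) and \(u = -z\) gives the derivative of \(e^{-z}\), a step used in the sigmoid derivation below:

$$ \bigl(e^{-z}\bigr)' = \frac{dy}{du}\frac{du}{dz} = e^{u} \cdot (-1) = -e^{-z} $$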

Relations in the neural network

Activation function

$$ a(z) = \frac{1}{1+e^{-z}} $$
$$ \begin{align} a'(z) &= -\frac{(1+e^{-z})'}{(1+e^{-z})^2} \\ &= \frac{e^{-z}}{(1+e^{-z})^2} \\ &= \frac{1+e^{-z}-1}{(1+e^{-z})^2} \\ &= a(z)(1-a(z)) \end{align} $$
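The identity \(a'(z) = a(z)(1-a(z))\) is easy to check numerically; a minimal sketch comparing it against a central finite difference (the names here are illustrative):

import numpy as np

def a(z):
    return 1 / (1 + np.exp(-z))

z = np.linspace(-5, 5, 101)
h = 1e-6
numeric = (a(z + h) - a(z - h)) / (2 * h)  # finite-difference approximation of a'(z)
analytic = a(z) * (1 - a(z))               # the closed form derived above
print(np.abs(numeric - analytic).max())    # tiny (on the order of 1e-10): they agree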

Input layer

\(4 \times 3 = 12\) pixels, two-level monochrome

$$ a_{i}^{1} = x_{i} \; \; \; (i = 1, \dots, 12) $$

Hidden layer

$$ \begin{align} z_1^{2} &= w_{1,1}^{2} x_{1} + \cdots + w_{1,12}^{2} x_{12} + b_{1}^{2} \\ z_2^{2} &= w_{2,1}^{2} x_{1} + \cdots + w_{2,12}^{2} x_{12} + b_{2}^{2} \\ z_3^{2} &= w_{3,1}^{2} x_{1} + \cdots + w_{3,12}^{2} x_{12} + b_{3}^{2} \\ a_1^{2} &= a(z_{1}^{2}) \\ a_2^{2} &= a(z_{2}^{2}) \\ a_3^{2} &= a(z_{3}^{2}) \\ \end{align} $$

Output layer

$$ \begin{align} z_1^{3} &= w_{1,1}^3 a_1^{2} + w_{1,2}^3 a_2^{2} + w_{1,3}^3 a_3^{2}+ b_1^3 \\ z_2^{3} &= w_{2,1}^3 a_1^{2} + w_{2,2}^3 a_2^{2} + w_{2,3}^3 a_3^{2} +b_2^3 \\ a_1^{3} &= a(z_1^{3}) \\ a_2^{3} &= a(z_2^{3}) \\ \end{align} $$
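Both layers are just a matrix-vector product plus a bias, followed by the sigmoid. A minimal sketch of this forward pass with random placeholder parameters (the shapes, not the values, are what matter here):

import numpy as np

def a(z):                                                # sigmoid activation
    return 1 / (1 + np.exp(-z))

rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=12)                          # a flattened 4x3 two-level image
w2, b2 = rng.normal(size=(3, 12)), rng.normal(size=3)    # hidden layer: 12 -> 3
w3, b3 = rng.normal(size=(2, 3)), rng.normal(size=2)     # output layer: 3 -> 2

a2 = a(np.dot(w2, x) + b2)                               # z^2 = W^2 x + b^2, a^2 = a(z^2)
a3 = a(np.dot(w3, a2) + b3)                              # z^3 = W^3 a^2 + b^3, a^3 = a(z^3)
print(a3)                                                # two outputs in (0, 1)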

Squared error of the whole network

The factor \(1/2\) is included to simplify the derivatives.

$$ C = \frac{1}{2}\bigl\{ (t_1 - a_1^{3})^2 + (t_2 - a_2^{3})^2 \bigr\} $$

Cost function

Since there are 64 training images, the cost function \(C_\Gamma\) is expressed as:

$$ C_\Gamma = \sum_{i=1}^{64} C_i $$

Finding the parameters

The weights and biases that minimize \(C_\Gamma\) satisfy the conditions below, obtained by setting every partial derivative to zero.

Hidden layer

$$ \begin{align} \frac{\partial C_\Gamma}{\partial w_{1,1}^{2}} &= 0, \dots , \frac{\partial C_\Gamma}{\partial w_{1,12}^{2}} = 0, \frac{\partial C_\Gamma}{\partial b_{1}^{2}} = 0 \\ \frac{\partial C_\Gamma}{\partial w_{2,1}^{2}} &= 0, \dots , \frac{\partial C_\Gamma}{\partial w_{2,12}^{2}} = 0, \frac{\partial C_\Gamma}{\partial b_{2}^{2}} = 0 \\ \frac{\partial C_\Gamma}{\partial w_{3,1}^{2}} &= 0, \dots , \frac{\partial C_\Gamma}{\partial w_{3,12}^{2}} = 0, \frac{\partial C_\Gamma}{\partial b_{3}^{2}} = 0 \\ \end{align} $$

Output layer

$$ \begin{align} \frac{\partial C_\Gamma}{\partial w_{1,1}^{3}} &= 0, \frac{\partial C_\Gamma}{\partial w_{1,2}^{3}} = 0, \frac{\partial C_\Gamma}{\partial w_{1,3}^{3}} = 0, \frac{\partial C_\Gamma}{\partial b_{1}^{3}} = 0 \\ \frac{\partial C_\Gamma}{\partial w_{2,1}^{3}} &= 0, \frac{\partial C_\Gamma}{\partial w_{2,2}^{3}} = 0, \frac{\partial C_\Gamma}{\partial w_{2,3}^{3}} = 0, \frac{\partial C_\Gamma}{\partial b_{2}^{3}} = 0 \\ \end{align} $$

Applying gradient descent

Since the conditions above cannot be solved in closed form, the parameters are updated iteratively in the direction of steepest descent with learning rate \(\eta\).

$$ (\varDelta w_{1,1}^{2}, \dots, \varDelta b_{1}^{2}, \dots, \varDelta w_{1,1}^{3}, \dots, \varDelta b_{1}^{3}, \dots) = -\eta \Bigl( \frac{\partial C_\Gamma}{\partial w_{1,1}^{2}}, \dots, \frac{\partial C_\Gamma}{\partial b_{1}^{2}}, \dots, \frac{\partial C_\Gamma}{\partial w_{1,1}^{3}}, \dots, \frac{\partial C_\Gamma}{\partial b_{1}^{3}}, \dots \Bigr) $$
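The update just moves each parameter a small step \(\eta\) against its gradient. A minimal one-variable sketch of the same rule (the function is illustrative):

# minimize f(w) = (w - 3)^2; its gradient is f'(w) = 2 * (w - 3)
w = 0.0
eta = 0.2
for _ in range(50):
    grad = 2 * (w - 3)
    w = w - eta * grad  # the same update applied to every weight and bias above
print(w)                # ~3.0, the minimizer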

Defining the unit error

Definition

$$ \delta_{j}^{l} = \frac{\partial C}{\partial z_{j}^{l}} \; \; \; (l=2, 3) $$

Computing \(\frac{\partial C}{\partial w_{1,1}^{2}}\)

$$ \begin{align} \frac{\partial C}{\partial w_{1,1}^{2}} &= \frac{\partial C}{\partial z_{1}^{2}}\frac{\partial z_{1}^{2}}{\partial w_{1,1}^{2}} \\ &= \frac{\partial C}{\partial z_{1}^{2}} x_{1} \\ &= \delta_{1}^{2} x_{1} \\ &= \delta_{1}^{2} a_{1}^{1} \end{align} $$

Computing \(\frac{\partial C}{\partial w_{1,1}^{3}}\)

$$ \begin{align} \frac{\partial C}{\partial w_{1,1}^{3}} &= \frac{\partial C}{\partial z_{1}^{3}}\frac{\partial z_{1}^{3}}{\partial w_{1,1}^{3}} \\ &= \frac{\partial C}{\partial z_{1}^{3}} a_{1}^{2} \\ &= \delta_{1}^{3} a_{1}^{2} \end{align} $$

Expressing the weight gradients with the unit error

$$ \frac{\partial C}{\partial w_{j,i}^{l}} = \delta_{j}^{l}a_{i}^{l-1} $$

Expressing the bias gradients with the unit error

$$ \frac{\partial C}{\partial b_{j}^{l}} = \delta_{j}^{l} \;\;\;(l=2,3) $$
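These two formulas are exactly what the training loop below accumulates: each weight gradient is a unit error times the previous layer's activation, and each bias gradient is the unit error itself. A minimal sketch with illustrative values:

import numpy as np

delta = np.array([0.1, -0.2])        # unit errors delta_j^l of some layer
a_prev = np.array([0.5, 0.3, 0.9])   # activations a_i^{l-1} of the previous layer

dw = delta[:, np.newaxis] * a_prev   # dC/dw_{j,i}^l = delta_j^l * a_i^{l-1}, shape (2, 3)
db = delta                           # dC/db_j^l = delta_j^l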

Output-layer error

$$ \begin{align} \delta_{j}^{3} &= \frac{\partial C}{\partial z_{j}^{3}} \\ &= \frac{\partial C}{\partial a_{j}^{3}} \frac{\partial a_{j}^{3}}{\partial z_{j}^{3}} \\ &= \frac{\partial C}{\partial a_{j}^{3}} a'(z_{j}^{3}) \end{align} $$

Computing \(\delta_{1}^{3}\)

$$ \begin{align} \delta_{1}^{3} &= \frac{\partial C}{\partial a_{1}^{3}} a'(z_{1}^{3}) \\ &= (t_1 - a_{1}^{3})(-1) \, a'(z_{1}^{3}) \\ &= (a_{1}^{3} - t_1) \, a'(z_{1}^{3}) \end{align} $$
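With the squared error and the sigmoid this is a single NumPy expression; a minimal sketch with illustrative values:

import numpy as np

a3 = np.array([0.8, 0.3])      # output activations (illustrative)
t = np.array([1.0, 0.0])       # teacher signal
d3 = (a3 - t) * a3 * (1 - a3)  # delta^3 = (a^3 - t) a'(z^3), using a'(z) = a(1 - a)
print(d3)                      # [-0.032  0.063]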

Hidden-layer error

$$ \begin{align} \delta_{1}^{2} &= \frac{\partial C}{\partial z_{1}^{2}} \\ &= \frac{\partial C}{\partial z_{1}^{3}}\frac{\partial z_{1}^{3}}{\partial a_{1}^{2}}\frac{\partial a_{1}^{2}}{\partial z_{1}^{2}} + \frac{\partial C}{\partial z_{2}^{3}}\frac{\partial z_{2}^{3}}{\partial a_{1}^{2}}\frac{\partial a_{1}^{2}}{\partial z_{1}^{2}} \\ &= (\delta_{1}^{3}w_{1,1}^{3} + \delta_{2}^{3}w_{2,1}^{3}) a'(z_{1}^{2}) \\ &= (\delta_{1}^{3}, \delta_{2}^{3}) \cdot (w_{1,1}^{3}, w_{2,1}^{3}) a'(z_{1}^{2}) \\ \end{align} $$
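In vector form, the output-layer errors are propagated backwards through the weights \(w_{j,i}^{3}\) and scaled by \(a'(z^{2})\); a minimal sketch continuing the illustrative values above:

import numpy as np

d3 = np.array([-0.032, 0.063])       # output-layer errors
w3 = np.array([[0.4, 0.8, 0.1],
               [0.0, -0.8, 1.6]])    # output-layer weights (illustrative)
a2 = np.array([0.5, 0.3, 0.9])       # hidden-layer activations

d2 = np.dot(d3, w3) * a2 * (1 - a2)  # delta^2_i = (sum_j delta^3_j w_{j,i}^3) a'(z^2_i)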

Classifying 4x3-pixel, two-level monochrome images

import numpy as np


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def activate(z):
    # activation function a(z)
    return sigmoid(z)


def deactivate(a):
    # derivative of the sigmoid expressed via the activation: a'(z) = a(1 - a)
    return a * (1 - a)


def sigma(w, x, b):
    # weighted input z = Wx + b
    return np.dot(w, x) + b


def cost(t, x):
    # squared error C = (1/2) * sum_j (t_j - x_j)^2
    return ((t - x) ** 2 / 2).sum()


X = np.array(
    [
        [1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
        [0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0],
        [0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0],
        [0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0],
        [1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1],
        [1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1],
        [1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1],
        [0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
        [0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1],
        [0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1],
        [0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1],
        [0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1],
        [0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1],
        [1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1],
        [1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1],
        [1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1],
        [1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
        [1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1],
        [1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
        [1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1],
        [1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1],
        [0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0],
        [1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0],
        [1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0],
        [0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1],
        [1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1],
        [1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
        [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1],
        [1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0],
        [1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0],
        [1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0],
        [1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0],
        [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
        [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1],
        [0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
    ]
)
T = np.array(
    [
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
    ]
)

np.set_printoptions(precision=3)
rng = np.random.default_rng(123)

# inputs: 12, hidden-layer units: 3
w2 = np.array(
    [
        [0.49038557898997, 0.348475676796501, 0.0725879008695083, 0.837472826850604, -0.0706798311519743, -3.6169369170322, -0.53557819719488, -0.0228584789393108, -1.71745249082217, -1.45563751579807, -0.555799932254451, 0.852476539980059],
        [0.442372911956926, -0.536877487857221, 1.00782536916829, 1.07196001297575, -0.732814485632708, 0.822959617857012, -0.453282364154155, -0.0138979392949318, -0.0274233258563056, -0.426670298661898, 1.87560275441379, -2.30528048189891],
        [0.654393041569443, -1.38856820257739, 1.24648311661583, 0.0572877158406771, -0.183237472237546, -0.74305066513479, -0.460930664925325, 0.331118557255208, 0.449470835925128, -1.29645372413246, 1.56850561324256, -0.470667153317658],
    ]
)  # fmt: skip

b2 = np.array([-0.185002, 0.525677, -1.168623])

# inputs: 3 (= hidden-layer units), output-layer units: 2
w3 = np.array(
    [
        [0.3880031194962, 0.803384989025837, 0.0292864334994403],
        [0.0254467679708455, -0.790397993881956, 1.55313793058729],
    ]
)

b3 = np.array([-1.438040, -1.379338])

dw2 = np.zeros_like(w2)
dw3 = np.zeros_like(w3)
db2 = np.zeros_like(b2)
db3 = np.zeros_like(b3)

costs = []
learning_rate = 0.2
epochs = 50
debug = True

for epoch in range(epochs):
    c = 0
    for i, (x, t) in enumerate(zip(X, T)):

        # output (layer 2)
        z2 = sigma(w2, x, b2)
        a2 = activate(z2)

        # output (layer 3)
        z3 = sigma(w3, a2, b3)
        a3 = activate(z3)

        # cost
        c = c + cost(t, a3)

        # backpropagation
        d3 = (a3 - t) * deactivate(a3)
        dw3 = dw3 + d3[:, np.newaxis] * a2
        db3 = db3 + d3

        d2 = np.dot(d3, w3) * deactivate(a2)
        dw2 = dw2 + d2[:, np.newaxis] * x
        db2 = db2 + d2

    costs.append(c)

    w2 = w2 - learning_rate * dw2
    w3 = w3 - learning_rate * dw3
    b2 = b2 - learning_rate * db2
    b3 = b3 - learning_rate * db3

    dw2.fill(0)
    dw3.fill(0)
    db2.fill(0)
    db3.fill(0)

    if debug:
        print("=" * 100)
        print("[", epoch, "]", "cost=", c)
        print("=" * 100)
        print("layer 2 (weight)")
        print(w2.reshape(3, 4, 3))
        print("layer 2 (bias)")
        print(b2)
        print("layer 3 (weight)")
        print(w3.reshape(2, 3))
        print("layer 3 (bias)")
        print(b3)
====================================================================================================
[ 0 ] cost= 20.255347730696887
====================================================================================================
layer 2 (weight)
[[[ 0.482  0.335  0.077]
  [ 0.84  -0.091 -3.614]
  [-0.533 -0.041 -1.713]
  [-1.456 -0.572  0.855]]

 [[ 0.446 -0.575  1.067]
  [ 1.168 -0.851  0.902]
  [-0.346 -0.143  0.055]
  [-0.369  1.838 -2.226]]

 [[ 0.753 -1.23   1.239]
  [ 0.054  0.009 -0.726]
  [-0.464  0.516  0.475]
  [-1.273  1.746 -0.438]]]
layer 2 (bias)
[-0.201  0.501 -0.982]
layer 3 (weight)
[[ 0.28   1.191  0.056]
 [ 0.257 -0.369  1.759]]
layer 3 (bias)
[-0.94  -0.727]

(debug output for epochs 1 through 48 omitted; the cost falls monotonically, e.g. 14.43 at epoch 1, 2.21 at epoch 10, and 0.35 at epoch 40)

====================================================================================================
[ 49 ] cost= 0.24523190314535037
====================================================================================================
layer 2 (weight)
[[[ 0.44   0.792 -0.116]
  [ 0.859  0.303 -3.7  ]
  [-0.485  0.317 -1.941]
  [-1.433 -0.118  0.651]]

 [[ 0.638 -2.057  1.521]
  [ 1.848 -1.628  1.155]
  [ 0.786 -1.375  0.777]
  [ 0.488  0.562 -0.818]]

 [[-0.112  0.458 -0.116]
  [-1.051  0.869 -1.304]
  [-1.345  1.31  -0.52 ]
  [-1.75   3.228 -1.616]]]
layer 2 (bias)
[ 0.251 -0.382  0.042]
layer 3 (weight)
[[-1.314  3.59  -3.055]
 [ 1.449 -2.42   4.068]]
layer 3 (bias)
[-0.326 -0.945]
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot()
ax.plot(costs, marker="o", markersize=5)
ax.grid()
plt.minorticks_on()
plt.show()

fig-1: plot of the cost per epoch
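To see how well the trained network separates the two classes, rerun the forward pass and compare the index of the largest output with the teacher label; a minimal check reusing the variables defined above:

# forward pass with the trained parameters; argmax picks the predicted class
predicted = [activate(sigma(w3, activate(sigma(w2, x, b2)), b3)).argmax() for x in X]
print(sum(p == t.argmax() for p, t in zip(predicted, T)), "/", len(X), "correct")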

Classifying XOR

rng = np.random.default_rng(123)

X = np.array(
    [
        [0, 0],
        [0, 1],
        [1, 0],
        [1, 1],
    ]
)
T = np.array(
    [
        [0],
        [1],
        [1],
        [0],
    ]
)

w2 = rng.random(size=(2, 2)) - 0.5
b2 = rng.random(2) - 0.5

w3 = rng.random(size=(1, 2)) - 0.5
b3 = rng.random(1) - 0.5

dw2 = np.zeros_like(w2)
dw3 = np.zeros_like(w3)
db2 = np.zeros_like(b2)
db3 = np.zeros_like(b3)

costs = []
learning_rate = 0.5
epochs = 3000
debug = False

for epoch in range(epochs):
    c = 0
    for i, (x, t) in enumerate(zip(X, T)):

        # output (layer 2)
        z2 = sigma(w2, x, b2)
        a2 = activate(z2)

        # output (layer 3)
        z3 = sigma(w3, a2, b3)
        a3 = activate(z3)

        # cost
        c = c + cost(t, a3)

        # backpropagation
        d3 = (a3 - t) * deactivate(a3)
        dw3 = dw3 + d3[:, np.newaxis] * a2
        db3 = db3 + d3

        d2 = np.dot(d3, w3) * deactivate(a2)
        dw2 = dw2 + d2[:, np.newaxis] * x
        db2 = db2 + d2

    costs.append(c)

    w2 = w2 - learning_rate * dw2
    w3 = w3 - learning_rate * dw3
    b2 = b2 - learning_rate * db2
    b3 = b3 - learning_rate * db3

    dw2.fill(0)
    dw3.fill(0)
    db2.fill(0)
    db3.fill(0)

    if debug:
        print("=" * 100)
        print("[", epoch, "]", "cost=", c)
        print("=" * 100)
        print("layer 2 (weight)")
        print(w2)
        print("layer 2 (bias)")
        print(b2)
        print("layer 3 (weight)")
        print(w3)
        print("layer 3 (bias)")
        print(b3)
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot()
ax.plot(costs, marker="o", markersize=2)
ax.grid()
plt.minorticks_on()
plt.show()

fig-2: plot of the cost per epoch
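The XOR network can be checked the same way; outputs close to 0 or 1 for the four patterns show that it has learned XOR. A minimal check reusing the variables above:

# evaluate the trained XOR network on the four input patterns
for x, t in zip(X, T):
    a3 = activate(sigma(w3, activate(sigma(w2, x, b2)), b3))
    print(x, "->", a3, "(target", t, ")")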

Source of the training data