blog

手書き数字の認識

Published:

By nob

Category: Posts

Tags: 機械学習 ニューラルネットワーク digits

8x8画素、モノクロ17階調の画像を識別してみる

データの読み込み

import csv

import matplotlib.pyplot as plt
import numpy as np


def load_digits(file, num_classes=10, max_value=16):
    """Load a digits CSV file into scaled features and one-hot labels.

    Every non-empty row is expected to hold the pixel values followed by the
    integer class label in its last column. Empty rows (for example a
    trailing blank line) are skipped instead of raising ``IndexError``.

    Args:
        file: Path of the CSV file to read.
        num_classes: Length of each one-hot label vector.
        max_value: Divisor applied to every pixel value; with pixels in
            ``0..max_value`` the features end up in ``[0, 1]``.

    Returns:
        A tuple ``(x, y)`` of float64 arrays. ``x`` holds one row of scaled
        pixel values per sample and ``y`` the matching one-hot label rows.

    Raises:
        ValueError: If a pixel or label cannot be parsed as a number, or the
            rows do not all have the same length.
        IndexError: If a label is not a valid index into a vector of
            ``num_classes`` entries.
    """
    pixel_rows = []
    labels = []
    # newline="" lets the csv module do its own newline handling, as the
    # csv documentation recommends for files passed to csv.reader.
    with open(file, "r", newline="") as f:
        for row in csv.reader(f):
            if not row:
                # csv.reader yields [] for blank lines; there is no label.
                continue
            pixel_rows.append(row[:-1])
            labels.append(int(row[-1]))

    x = np.array(pixel_rows, dtype=np.float64) / max_value

    # Build all one-hot rows at once: row i gets a 1 in column labels[i].
    label_idx = np.array(labels, dtype=np.intp)
    y = np.zeros((len(labels), num_classes), dtype=np.float64)
    y[np.arange(len(labels)), label_idx] = 1
    return x, y


# Print floats with three decimals so the 8x8 pixel grid stays readable.
np.set_printoptions(precision=3)

# Load the training and test splits as (features, one-hot labels) pairs.
x_train, y_train = load_digits("./data/digits/optdigits.tra")
x_test, y_test = load_digits("./data/digits/optdigits.tes")
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

# Inspect the first training sample: its pixels as an 8x8 grid, then its label.
first_image = x_train[0].reshape(8, 8)
print(first_image)
print(y_train[0])

# Draw the first ten training samples side by side as grayscale images.
fig = plt.figure(figsize=(10, 2))
for column, sample in enumerate(x_train[:10]):
    panel = fig.add_subplot(1, 10, column + 1)
    panel.imshow(sample.reshape(8, 8), cmap=plt.cm.gray)
    panel.axis("off")
fig.tight_layout()
plt.show()
(3823, 64) (3823, 10)
(1797, 64) (1797, 10)
[[0.    0.062 0.375 0.938 0.75  0.062 0.    0.   ]
 [0.    0.438 1.    0.375 0.375 0.625 0.    0.   ]
 [0.    0.5   1.    0.125 0.    0.688 0.125 0.   ]
 [0.    0.312 1.    0.188 0.    0.312 0.438 0.   ]
 [0.    0.438 0.812 0.188 0.    0.5   0.438 0.   ]
 [0.    0.25  0.75  0.    0.062 0.812 0.312 0.   ]
 [0.    0.    0.875 0.562 0.938 0.562 0.    0.   ]
 [0.    0.    0.375 0.875 0.438 0.062 0.    0.   ]]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

fig-1

識別

import matplotlib.pyplot as plt
import numpy as np
from mlp import (
    ActivationLayer,
    FullyConnectedLayer,
    InputLayer,
    MeanSquaredError,
    MultiLayerPerceptron,
    Sigmoid,
)

# Seeded generator so repeated runs draw the same random numbers. It is passed
# to both fully connected layers below, in this order (presumably for weight
# initialisation -- confirm in the mlp module).
rng = np.random.default_rng(123)

# Network definition. Assuming FullyConnectedLayer takes (inputs, outputs, rng),
# the layer sizes read as 64 -> 64 -> 10 with a sigmoid after each fully
# connected layer, and MeanSquaredError as the loss -- TODO confirm against mlp.
# NOTE(review): the exact meaning of num_batch (batch size vs. number of
# batches) is not visible here; check MultiLayerPerceptron.
mlp = MultiLayerPerceptron(
    [
        InputLayer(64),
        FullyConnectedLayer(64, 64, rng),
        ActivationLayer(Sigmoid()),
        FullyConnectedLayer(64, 10, rng),
        ActivationLayer(Sigmoid()),
    ],
    MeanSquaredError(),
    epochs=100,
    num_batch=100,
    learning_rate=0.01,
)
# Train on the training split loaded earlier.
mlp.fit(x_train, y_train)
import matplotlib.pyplot as plt

# Plot the values recorded in mlp.history (presumably the training loss over
# time -- confirm in mlp) with small markers, a grid and minor ticks.
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot()
ax.plot(mlp.history, marker="o", markersize=2)
ax.grid()
plt.minorticks_on()
plt.show()

fig-2

def accuracy(pred, truth):
    """Return the fraction of rows in ``pred`` that match ``truth`` exactly.

    A row counts as correct only when every element of the prediction equals
    the corresponding element of the target, i.e. their difference is zero
    in every column.
    """
    residual = truth - pred
    row_is_exact = np.all(residual == 0, axis=1)
    n_rows = pred.shape[0]
    return np.sum(row_is_exact) / n_rows


# Run the trained network on both splits, then report the exact-match
# accuracy of each set of predictions against its one-hot targets.
y_train_pred = mlp.predict(x_train)
y_test_pred = mlp.predict(x_test)
for caption, predicted, expected in (
    ("train score", y_train_pred, y_train),
    ("test score", y_test_pred, y_test),
):
    print(caption, accuracy(predicted, expected))
train score 0.9497776615223646
test score 0.9081803005008348

データの出典

Optical Recognition of Handwritten Digits (UCI Machine Learning Repository)