blog

PyTorchで文字認識

Published:

By nob

Category: Posts

Tags: 文字認識 OpenCV Pillow PyTorch Intel-Extension-for-PyTorch Python

Intel Extension for PyTorchを有効にする

import torch

import intel_extension_for_pytorch as ipex  # isort: skip

Datasetを作成する

カスタムDatasetクラスを定義する

画像のファイル名とラベルを記載した注釈ファイルを作成し、注釈ファイルを読み込んでDataSetを作成する。

注釈ファイルの内容

table-1-a-0001.png 1
table-1-a-0002.png n
table-1-a-0003.png /
[ファイル名] [ラベル]

学習には使わないが、2桁の数字の画像も注釈ファイルに含めておく。

参考:

import os

from torch.utils.data import Dataset
from torchvision.io import read_image


class ImageDataset(Dataset):
    """Dataset of character images described by an annotation file.

    Each annotation line has the form ``<file name> <label>``. Lines whose
    label is not listed in ``label_names`` are skipped, so the annotation
    file may also contain entries (e.g. two-digit numbers) that are not
    used for training.
    """

    def __init__(
        self,
        annotation_file,
        label_names,
        image_dir,
        transform=None,
        target_transform=None,
    ):
        """
        Args:
            annotation_file: path to the annotation text file.
            label_names: labels to keep; their list order defines the
                one-hot encoding index of each label.
            image_dir: directory that contains the image files.
            transform: optional transform applied to each image tensor.
            target_transform: optional transform applied to each label.
        """
        self.label_names = label_names
        self.annotations = self.load_annotation_file(annotation_file)
        self.image_dir = image_dir
        self.transform = transform
        self.target_transform = target_transform

    def load_annotation_file(self, annotation_file):
        """Return ``[file_name, label_name]`` pairs for known labels."""
        annotations = []
        # Iterate the file directly (idiomatic) with an explicit encoding
        # so the result does not depend on the platform default.
        with open(annotation_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines instead of crashing
                file_name, label_name = line.split()
                if label_name in self.label_names:
                    annotations.append([file_name, label_name])
        return annotations

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, i):
        """Return ``(image, one_hot_label)`` for the i-th annotation."""
        image_path = os.path.join(self.image_dir, self.annotations[i][0])
        image = read_image(image_path)
        if self.transform:
            image = self.transform(image)

        label_name = self.annotations[i][1]
        # One-hot encode the label
        label = torch.zeros(len(self.label_names))
        label_index = self.label_names.index(label_name)
        label[label_index] = 1
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

Datasetの前処理をする

画像によって大きさが異なるので、大きさを等しくするための変換処理をする。

変換処理は torchvision.transforms モジュールや torchvision.transforms.v2 モジュールを使って行う。

v1よりv2の方が実行速度が速い ということなので、他のtransformsと組み合わせて使えるようにv2のインターフェースで変換処理を作成する。

参考:

変換処理

  1. 画像ごとに余白の大きさが異なるので、画像から余白を取り除く
  2. 画像の大きさを揃える

    PIL.ImageOps.pad を使うと画像のサイズ変更・余白の追加・センタリングをまとめて実行できる。

変換処理をまとめる。

# Trim and PillowPad are custom transforms from the local "nobituk"
# package; the remaining steps come from torchvision's v2 transforms.
from nobituk import PillowPad, Trim
from torchvision.transforms import v2

# Common target size every image is padded/resized to.
image_width = 90
image_height = 120

transforms = v2.Compose(
    [
        Trim(),  # remove the margin, whose size differs per image
        PillowPad(size=(image_width, image_height)),  # resize + pad + center
        v2.Grayscale(),
        v2.ToDtype(torch.float32, scale=True),  # float32, values scaled to [0, 1]
    ]
)

Datasetを作成する。注釈ファイルには2桁の数字も含まれるのでラベルを指定する。

# Labels used for recognition: digits 0-9 plus the symbols that appear in
# the tables. Annotation entries outside this list (e.g. two-digit
# numbers) are filtered out by ImageDataset.
label_names = [str(i) for i in range(10)]
label_names.extend(["+", "-", "/", "n", "N"])

# Train / test / validation datasets share the image directory and the
# preprocessing transforms; only the annotation file differs.
train_dataset = ImageDataset(
    "data/pism/image/table-1-a-label.txt",
    label_names,
    "data/pism/image/",
    transform=transforms,
)
print(len(train_dataset))

test_dataset = ImageDataset(
    "data/pism/image/table-1-b-label.txt",
    label_names,
    "data/pism/image/",
    transform=transforms,
)
print(len(test_dataset))

validation_dataset = ImageDataset(
    "data/pism/image/table-1-c-label.txt",
    label_names,
    "data/pism/image/",
    transform=transforms,
)
print(len(validation_dataset))
336
289
293

注釈ファイルの行数を数えてDatasetの件数を確認する。

$ for f in data/pism/image/table-1-*-label.txt; do echo $f $(cat $f|cut -d " " -f 2|grep -o "^.$"|wc -l); done
data/pism/image/table-1-a-label.txt 336
data/pism/image/table-1-b-label.txt 289
data/pism/image/table-1-c-label.txt 293

Datasetを確認する

訓練Datasetの正解ラベルごとの画像を数えてみると偏りがあった。

from collections import Counter  # stdlib replacement for the manual count

from IPython.display import display

# Count training images per ground-truth label. Counter is a dict
# subclass, so label_counters still behaves like the original dict built
# with a comprehension over set(labels) + list.count.
labels = [label_name for _, label_name in train_dataset.annotations]
label_counters = Counter(labels)
label_counts = sorted(label_counters.items(), key=lambda x: x[0])

display(label_counts)
[('+', 5),
 ('-', 110),
 ('/', 95),
 ('0', 6),
 ('1', 4),
 ('2', 2),
 ('3', 1),
 ('4', 6),
 ('5', 6),
 ('6', 1),
 ('7', 1),
 ('8', 2),
 ('9', 7),
 ('N', 45),
 ('n', 45)]

画像を拡張する(Image Augmentation)

正解ラベルごとの画像の偏りを無くすために訓練Datasetを拡張する。

画像の数が最も多いラベルに合わせて他のラベルの画像を増やす。

正解ラベルごとの不足数を数える。

# Find the most frequent label, then compute how many extra images every
# other label needs to reach that count.
max_label_name, max_count = max(label_counts, key=lambda x: x[1])

augment_tasks = [
    [name, shortage]
    for name, count in label_counts
    if (shortage := max_count - count) > 0
]

display(augment_tasks)
[['+', 105],
 ['/', 15],
 ['0', 104],
 ['1', 106],
 ['2', 108],
 ['3', 109],
 ['4', 104],
 ['5', 104],
 ['6', 109],
 ['7', 109],
 ['8', 108],
 ['9', 103],
 ['N', 65],
 ['n', 65]]

画像を拡張するための変換処理を作成する

torchvision.transforms.v2 モジュールには様々な変換処理が含まれている。変換処理の自作も可能である。

  • 画像の中心を移動する
  • ガウスぼかし

の2つを作成する。

参考:

画像の中心を移動させるような変換処理を定義する。

from torch import nn
from torchvision.transforms.v2 import functional


class Move(nn.Module):
    """Transform that shifts the image content by (dx, dy) pixels.

    Positive dx moves the content right, positive dy moves it down;
    negative values move it left/up. The vacated area is filled with
    black (pixel value 0).
    """

    def __init__(self, dx=0, dy=0):
        super().__init__()
        self.dx = dx
        self.dy = dy

    def forward(self, image):
        """Return a tensor whose content is shifted by (dx, dy)."""
        org = functional.to_pil_image(image)

        # Start from an all-black canvas of the same size and mode.
        tmp = org.copy()
        tmp.paste(0, (0, 0, *org.size))

        # (x, y): paste position; (left, upper, right, lower): crop box.
        x, y = 0, 0
        left, upper, right, lower = 0, 0, org.size[0], org.size[1]

        if self.dx >= 0:
            x = self.dx
            right = org.size[0] - self.dx
        else:
            left = -self.dx

        if self.dy >= 0:
            y = self.dy
            lower = org.size[1] - self.dy
        else:
            upper = -self.dy

        tmp.paste(org.crop((left, upper, right, lower)), (x, y))

        return functional.pil_to_tensor(tmp)

    def __repr__(self):
        return f"Move(dx={self.dx}, dy={self.dy})"

    # __str__ was a verbatim copy of __repr__; alias it instead of
    # duplicating the format string (DRY).
    __str__ = __repr__

ガウスぼかしの変換処理はPyTorchに用意されている v2.GaussianBlur を利用する。

# Shift transforms: move the content by 5 px in six directions.
movers = [
    Move(dx, dy)
    for dx, dy in [(5, 0), (5, 5), (0, 5), (-5, 0), (-5, -5), (0, -5)]
]

# Gaussian blur transforms with increasing sigma.
blurrers = [
    v2.GaussianBlur(kernel_size=(5, 5), sigma=sigma)
    for sigma in [0.5, 0.6, 0.7, 0.8]
]

# Pool of candidate augmentation transforms. NOTE: this rebinds the name
# "transforms", which previously held the preprocessing pipeline.
transforms = []
transforms.extend(movers)
transforms.extend(blurrers)

画像を拡張するクラスを作成する

ここまでの処理をまとめてクラス化する。

import random


class AugmentedImageDataset(Dataset):
    """Dataset of augmented images that balances the label distribution.

    For every label with fewer samples than the most frequent label,
    randomly chosen source images of that label are passed through a
    randomly chosen transform until the shortage is filled.
    """

    def __init__(self, original_dataset, transforms):
        """
        Args:
            original_dataset: dataset exposing an ``annotations`` list of
                ``[file_name, label_name]`` pairs (e.g. ImageDataset).
            transforms: candidate augmentation transforms to sample from.
        """
        self.original_dataset = original_dataset
        self.original_label_names = [
            label_name for _, label_name in self.original_dataset.annotations
        ]
        self.transforms = transforms
        augment_tasks = self.get_augment_tasks()
        self.datasets = self.augment(augment_tasks)

    def get_augment_tasks(self):
        """Return ``[label_name, shortage]`` per under-represented label."""
        label_name_counters = {
            label_name: self.original_label_names.count(label_name)
            for label_name in set(self.original_label_names)
        }
        _, max_count = max(
            label_name_counters.items(), key=lambda x: x[1]
        )
        augment_tasks = [
            [label_name, max_count - count]
            for label_name, count in label_name_counters.items()
            if (max_count - count) > 0
        ]
        return augment_tasks

    def augment(self, augment_tasks):
        """Generate the augmented samples described by *augment_tasks*.

        Returns a list of ``(source_index, transform_name, image, label)``
        tuples; the first two fields are bookkeeping for inspection.
        """
        datasets = []
        for label_name, num in augment_tasks:
            original_indices = [
                i
                for i, x in enumerate(self.original_label_names)
                if x == label_name
            ]
            for _ in range(num):
                original_index = random.choice(original_indices)
                # BUG FIX: read from self.original_dataset instead of the
                # global train_dataset, which only worked by accident when
                # the caller passed that exact global.
                original_data, label = self.original_dataset[original_index]
                transform = random.choice(self.transforms)
                composite_transform = v2.Compose(
                    [
                        transform,
                        v2.Grayscale(),
                        v2.ToDtype(torch.float32, scale=True),
                    ]
                )
                transformed_data = composite_transform(original_data)
                datasets.append(
                    (original_index, str(transform), transformed_data, label)
                )
        return datasets

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, i):
        """Return ``(image, label)``, dropping the bookkeeping fields."""
        return self.datasets[i][2], self.datasets[i][3]
# Build the augmentation dataset from the (imbalanced) training dataset.
augmented_dataset = AugmentedImageDataset(train_dataset, transforms)
len(augmented_dataset)
1314

訓練Datasetと拡張Datasetを結合する

参考

from torch.utils.data import ConcatDataset

# Combine the augmented samples with the original training samples.
concat_dataset = ConcatDataset([augmented_dataset, train_dataset])

print(len(augmented_dataset), len(train_dataset), len(concat_dataset))
1314 336 1650

DataLoaderを作成する

from torch.utils.data import DataLoader

batch_size = 30

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    concat_dataset,
    batch_size=batch_size,
    shuffle=True,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# The validation set is evaluated as a single batch, in fixed order.
validation_loader = DataLoader(
    validation_dataset,
    batch_size=len(validation_dataset),
    shuffle=False,
)

ニューラルネットワークで文字を認識する

モデルを定義する

参考:

device = "xpu"  # Intel GPU device exposed via Intel Extension for PyTorch


class NeuralNetwork(nn.Module):
    """Fully-connected classifier: flatten, then two 512-unit ReLU hidden
    layers, then a linear layer emitting one logit per label.

    Attribute names (``flatten``, ``linear_relu_stack``) are part of the
    saved state_dict and must not change.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 512
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        """Return raw class logits for a batch of (flattened) images."""
        return self.linear_relu_stack(self.flatten(x))


# Input size = flattened grayscale image; one output logit per label.
model = NeuralNetwork(image_width * image_height, len(label_names)).to(device)
print(model)
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=10800, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=15, bias=True)
  )
)

損失関数を定義する

# ImageDataset yields one-hot targets, which CrossEntropyLoss accepts as
# class probabilities.
loss_func = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)

訓練する

from tqdm import tqdm


def train(epoch, dataloader, model, loss_func, optimizer):
    """Run one training epoch.

    Returns a pair of lists (loss history, accuracy history) containing
    one entry per batch, and shows progress via tqdm.
    """
    model.train()
    # Apply Intel Extension for PyTorch optimizations to model/optimizer.
    model, optimizer = ipex.optimize(model, optimizer=optimizer)

    loss_history, accuracy_history = [], []

    with tqdm(dataloader, unit="batch") as tbat:
        tbat.set_description(f"Epoch {epoch} Train")

        for X, y in tbat:
            X, y = X.to(device), y.to(device)
            pred = model(X)

            # Targets are one-hot rows, so argmax recovers class indices.
            batch_correct = (pred.argmax(dim=1) == y.argmax(dim=1)).sum().item()
            batch_accuracy = batch_correct / len(X)
            accuracy_history.append(batch_accuracy)

            loss = loss_func(pred, y)
            loss_history.append(loss.item())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            tbat.set_postfix(loss=loss.item(), accuracy=100 * batch_accuracy)

    return loss_history, accuracy_history
def test(epoch, dataloader, model, loss_func):
    """Evaluate the model over *dataloader* for one epoch.

    Prints the epoch-average accuracy and loss, and returns per-batch
    (loss history, accuracy history) lists.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    model.eval()
    model = ipex.optimize(model)  # IPEX inference optimizations

    # test_loss / test_correct accumulate sums, then become averages.
    test_loss, test_correct = 0, 0
    loss_history, accuracy_history = [], []

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)

            batch_loss = loss_func(pred, y).item()
            loss_history.append(batch_loss)
            test_loss += batch_loss

            batch_correct = (pred.argmax(dim=1) == y.argmax(dim=1)).sum().item()
            accuracy_history.append(batch_correct / len(X))
            test_correct += batch_correct

    test_loss /= num_batches
    test_correct /= size

    print(
        f"Epoch {epoch} Test Error: Accuracy: {(100*test_correct)}%, Avg loss: {test_loss}"
    )

    return loss_history, accuracy_history
# Main loop: one training pass followed by one test pass per epoch,
# accumulating per-batch histories for the learning curves below.
epochs = 3
total_train_loss_history = []
total_train_accuracy_history = []
total_test_loss_history = []
total_test_accuracy_history = []
for epoch in range(epochs):
    # train
    epoch_train_loss_history, epoch_train_accuracy_history = train(
        epoch, train_loader, model, loss_func, optimizer
    )
    total_train_loss_history.extend(epoch_train_loss_history)
    total_train_accuracy_history.extend(epoch_train_accuracy_history)
    # test
    epoch_test_loss_history, epoch_test_accuracy_history = test(
        epoch, test_loader, model, loss_func
    )
    total_test_loss_history.extend(epoch_test_loss_history)
    total_test_accuracy_history.extend(epoch_test_accuracy_history)
Epoch 0 Train: 100%|██████████| 55/55 [00:00<00:00, 86.61batch/s, accuracy=100, loss=0.00905]


Epoch 0 Test Error: Accuracy: 100.0%, Avg loss: 0.007285525789484382


Epoch 1 Train: 100%|██████████| 55/55 [00:00<00:00, 166.94batch/s, accuracy=100, loss=0.00344]


Epoch 1 Test Error: Accuracy: 100.0%, Avg loss: 0.0023368961177766324


Epoch 2 Train: 100%|██████████| 55/55 [00:00<00:00, 168.76batch/s, accuracy=100, loss=0.00142]


Epoch 2 Test Error: Accuracy: 100.0%, Avg loss: 0.001543202588800341

学習曲線を表示する

import matplotlib.pyplot as plt

# Plot the per-batch loss (left) and accuracy (right) learning curves
# for both the training and the test runs.
fig, axs = plt.subplots(1, 2, figsize=(10, 4))

axs[0].set_title("loss")
axs[0].plot(total_train_loss_history, label="train")
axs[0].plot(total_test_loss_history, label="test")
axs[0].set_xlabel("#batch")
axs[0].legend()

axs[1].set_title("accuracy")
axs[1].plot(total_train_accuracy_history, label="train")
axs[1].plot(total_test_accuracy_history, label="test")
axs[1].set_xlabel("#batch")
axs[1].legend()

plt.tight_layout()
plt.show()

fig-1

モデルを保存する

# Persist only the learned parameters (state_dict), not the full module.
torch.save(model.state_dict(), "data/pism/53-2-231-table-1.pth")

保存したモデルを読み込む

# Recreate the architecture, then load the saved parameters into it.
model = NeuralNetwork(image_width * image_height, len(label_names)).to(device)
model.load_state_dict(torch.load("data/pism/53-2-231-table-1.pth"))
<All keys matched successfully>

読み込んだモデルで検証Datasetのラベルを予測する

def validate(dataloader, model):
    """Predict labels for every batch, print details of any wrong
    predictions, and print the overall accuracy."""
    model.eval()
    model = ipex.optimize(model)
    size = len(dataloader.dataset)
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            # Convert the logits to the same one-hot format as y
            y_pred = torch.zeros_like(pred)
            y_pred[torch.arange(len(y_pred)), pred.argmax(dim=1)] = 1
            # A cell sums to 2 only where prediction and truth are both 1
            correct += (y_pred + y == 2).sum().item()
            # NOTE(review): argwhere yields (row, col) pairs, so
            # .sum(dim=1) adds the sample index to the class index;
            # presumably the row index alone ([:, 0]) was intended —
            # verify (harmless here since no mismatches occurred).
            wrong_indices = torch.argwhere(y_pred - y == 1).sum(dim=1)
            if len(wrong_indices) > 0:
                print("wrong indices:", wrong_indices)
                print("predicted:", y_pred[wrong_indices])
                print(
                    "but the truth is:",
                    y[torch.argwhere(y - y_pred == 1).sum(dim=1)],
                )

    print(f"accuracy: {100*correct/size}% ({correct}/{size})")
# Limit tensor printing so any wrong-prediction dumps stay readable.
torch.set_printoptions(threshold=100)

validate(validation_loader, model)
validate(train_loader, model)
validate(test_loader, model)
accuracy: 100.0% (293/293)
accuracy: 100.0% (1650/1650)
accuracy: 100.0% (289/289)