PyTorchで文字認識
Published:
By nob — Category: Posts
Tags: 文字認識 OpenCV Pillow PyTorch Intel-Extension-for-PyTorch Python
Intel Extension for PyTorchを有効にする
import torch
import intel_extension_for_pytorch as ipex # isort: skip
Datasetを作成する
カスタムDatasetクラスを定義する
画像のファイル名とラベルを記載した注釈ファイルを作成し、注釈ファイルを読み込んでDataSetを作成する。
注釈ファイルの内容
table-1-a-0001.png 1
table-1-a-0002.png n
table-1-a-0003.png /
[ファイル名] [ラベル]
学習には使わないが、2桁の数字の画像も注釈ファイルに含めておく。
参考:
import os
from torch.utils.data import Dataset
from torchvision.io import read_image
class ImageDataset(Dataset):
    """Dataset of character images listed in an annotation file.

    The annotation file holds one "<file_name> <label>" pair per line.
    Lines whose label is not in ``label_names`` are skipped, so the file
    may also list entries that are not used for training (e.g. two-digit
    numbers).
    """

    def __init__(
        self,
        annotation_file,
        label_names,
        image_dir,
        transform=None,
        target_transform=None,
    ):
        self.label_names = label_names
        self.annotations = self.load_annotation_file(annotation_file)
        self.image_dir = image_dir
        self.transform = transform
        self.target_transform = target_transform

    def load_annotation_file(self, annotation_file):
        """Return [file_name, label_name] pairs whose label is known."""
        with open(annotation_file, "r") as f:
            return [
                [file_name, label_name]
                for file_name, label_name in (line.split() for line in f)
                if label_name in self.label_names
            ]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, i):
        file_name, label_name = self.annotations[i]
        image = read_image(os.path.join(self.image_dir, file_name))
        if self.transform:
            image = self.transform(image)
        # One-hot encode the label.
        label = torch.zeros(len(self.label_names))
        label[self.label_names.index(label_name)] = 1
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
Datasetの前処理をする
画像によって大きさが異なるので、大きさを等しくするための変換処理をする。
変換処理は
torchvision.transforms
モジュールや
torchvision.transforms.v2
モジュールを使って行う。
v1よりv2の方が実行速度が速い ということなので、他のtransformsと組み合わせて使えるようにv2のインターフェースで変換処理を作成する。
参考:
- V1 or V2? Which one should I use?
- How to write your own v2 transforms
- PyTorchのtransforms.v2を文字の画像に適用する
変換処理
- 画像ごとに余白の大きさが異なるので、画像から余白を取り除く
- 画像の大きさを揃える
PIL.ImageOps.pad を使うと画像のサイズ変更・余白の追加・センタリングをまとめて実行できる。
変換処理をまとめる。
from nobituk import PillowPad, Trim
from torchvision.transforms import v2
# Target size every image is normalized to.
image_width = 90
image_height = 120

# Preprocessing pipeline: trim margins, pad/center to a uniform size,
# convert to single-channel grayscale, and scale pixels to [0, 1] floats.
transforms = v2.Compose([
    Trim(),
    PillowPad(size=(image_width, image_height)),
    v2.Grayscale(),
    v2.ToDtype(torch.float32, scale=True),
])
Datasetを作成する。注釈ファイルには2桁の数字も含まれるのでラベルを指定する。
# Digits 0-9 plus the symbols that appear in the tables.
label_names = [str(digit) for digit in range(10)] + ["+", "-", "/", "n", "N"]

image_dir = "data/pism/image/"

# Table "a" is used for training, "b" for testing, "c" for validation.
train_dataset = ImageDataset(
    "data/pism/image/table-1-a-label.txt",
    label_names,
    image_dir,
    transform=transforms,
)
print(len(train_dataset))

test_dataset = ImageDataset(
    "data/pism/image/table-1-b-label.txt",
    label_names,
    image_dir,
    transform=transforms,
)
print(len(test_dataset))

validation_dataset = ImageDataset(
    "data/pism/image/table-1-c-label.txt",
    label_names,
    image_dir,
    transform=transforms,
)
print(len(validation_dataset))
336
289
293
注釈ファイルの行数を数えてDatasetの件数を確認する。
$ for f in data/pism/image/table-1-*-label.txt; do echo $f $(cat $f|cut -d " " -f 2|grep -o "^.$"|wc -l); done
data/pism/image/table-1-a-label.txt 336
data/pism/image/table-1-b-label.txt 289
data/pism/image/table-1-c-label.txt 293
Datasetを確認する
訓練Datasetの正解ラベルごとの画像を数えてみると偏りがあった。
from collections import Counter

from IPython.display import display

# Count training images per label; the distribution turns out to be skewed.
labels = [label_name for _, label_name in train_dataset.annotations]
# Counter is a single O(n) pass; labels.count() per unique label was O(n*k).
label_counters = Counter(labels)
label_counts = sorted(label_counters.items(), key=lambda x: x[0])
display(label_counts)
[('+', 5),
('-', 110),
('/', 95),
('0', 6),
('1', 4),
('2', 2),
('3', 1),
('4', 6),
('5', 6),
('6', 1),
('7', 1),
('8', 2),
('9', 7),
('N', 45),
('n', 45)]
画像を拡張する(Image Augmentation)
正解ラベルごとの画像の偏りを無くすために訓練Datasetを拡張する。
画像の数が最も多いラベルに合わせて他のラベルの画像を増やす。
正解ラベルごとの不足数を数える。
# For each label, how many images are missing relative to the most
# frequent label.
max_label_name, max_count = max(label_counts, key=lambda item: item[1])
augment_tasks = []
for label_name, count in label_counts:
    shortage = max_count - count
    if shortage > 0:
        augment_tasks.append([label_name, shortage])
display(augment_tasks)
[['+', 105],
['/', 15],
['0', 104],
['1', 106],
['2', 108],
['3', 109],
['4', 104],
['5', 104],
['6', 109],
['7', 109],
['8', 108],
['9', 103],
['N', 65],
['n', 65]]
画像を拡張するための変換処理を作成する
torchvision.transforms.v2
モジュールには様々な変換処理が含まれている。変換処理の自作も可能である。
- 画像の中心を移動する
- ガウスぼかし
の2つを作成する。
参考:
画像の中心を移動させるような変換処理を定義する。
from torch import nn
from torchvision.transforms.v2 import functional
class Move(nn.Module):
    """Translate the character inside the image by (dx, dy) pixels.

    Positive offsets shift right/down, negative ones left/up; the
    vacated area is filled with black (0).
    """

    def __init__(self, dx=0, dy=0):
        super().__init__()
        self.dx = dx
        self.dy = dy

    def forward(self, image):
        source = functional.to_pil_image(image)
        width, height = source.size
        # Start from an all-black canvas of the same size.
        canvas = source.copy()
        canvas.paste(0, (0, 0, width, height))
        # Paste position and crop box for the shifted content.
        x = max(self.dx, 0)
        y = max(self.dy, 0)
        box = (max(-self.dx, 0), max(-self.dy, 0), width - x, height - y)
        canvas.paste(source.crop(box), (x, y))
        return functional.pil_to_tensor(canvas)

    def __str__(self):
        return f"Move(dx={self.dx}, dy={self.dy})"

    __repr__ = __str__
ガウスぼかしの変換処理はPyTorchに用意されている v2.GaussianBlur を利用する。
# Six 5-pixel shifts: horizontal, vertical, and the two diagonals.
offsets = [(5, 0), (5, 5), (0, 5), (-5, 0), (-5, -5), (0, -5)]
movers = [Move(dx, dy) for dx, dy in offsets]

# Gaussian blurs of increasing strength.
blurrers = [
    v2.GaussianBlur(kernel_size=(5, 5), sigma=sigma)
    for sigma in (0.5, 0.6, 0.7, 0.8)
]

# Pool of candidate augmentation transforms.
transforms = movers + blurrers
画像を拡張するクラスを作成する
ここまでの処理をまとめてクラス化する。
import random
class AugmentedImageDataset(Dataset):
    """Augment a dataset so every label has as many samples as the most
    frequent one.

    For each under-represented label, random originals of that label are
    passed through a randomly chosen transform until the count matches
    the majority label.
    """

    def __init__(self, original_dataset, transforms):
        self.original_dataset = original_dataset
        self.original_label_names = [
            label_name for _, label_name in self.original_dataset.annotations
        ]
        self.transforms = transforms
        augment_tasks = self.get_augment_tasks()
        self.datasets = self.augment(augment_tasks)

    def get_augment_tasks(self):
        """Return [label_name, shortage] pairs for under-represented labels."""
        label_name_counters = {
            label_name: self.original_label_names.count(label_name)
            for label_name in set(self.original_label_names)
        }
        max_label_name, max_count = max(
            label_name_counters.items(), key=lambda x: x[1]
        )
        return [
            [label_name, max_count - count]
            for label_name, count in label_name_counters.items()
            if (max_count - count) > 0
        ]

    def augment(self, augment_tasks):
        """Generate the missing samples listed in augment_tasks."""
        datasets = []
        for label_name, num in augment_tasks:
            original_indices = [
                i
                for i, x in enumerate(self.original_label_names)
                if x == label_name
            ]
            for _ in range(num):
                original_index = random.choice(original_indices)
                # BUG FIX: read from self.original_dataset instead of the
                # global train_dataset, so the constructor argument is
                # actually honored for any dataset passed in.
                original_data, label = self.original_dataset[original_index]
                transform = random.choice(self.transforms)
                composite_transform = v2.Compose(
                    [
                        transform,
                        v2.Grayscale(),
                        v2.ToDtype(torch.float32, scale=True),
                    ]
                )
                transformed_data = composite_transform(original_data)
                datasets.append(
                    (original_index, str(transform), transformed_data, label)
                )
        return datasets

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, i):
        # Entries are (original_index, transform_name, image, label);
        # expose only (image, label) like a regular dataset.
        return self.datasets[i][2], self.datasets[i][3]
# Build the augmented dataset from the imbalanced training set.
augmented_dataset = AugmentedImageDataset(train_dataset, transforms)
# Bare expression: in a notebook this displays the sample count.
len(augmented_dataset)
1314
訓練Datasetと拡張Datasetを結合する
参考
from torch.utils.data import ConcatDataset
# Final training data: augmented samples plus the original ones.
concat_dataset = ConcatDataset([augmented_dataset, train_dataset])
print(len(augmented_dataset), len(train_dataset), len(concat_dataset))
1314 336 1650
DataLoaderを作成する
from torch.utils.data import DataLoader

batch_size = 30

# Shuffled mini-batches for training and testing.
train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Validation runs as one deterministic batch covering the whole dataset.
validation_loader = DataLoader(
    validation_dataset, batch_size=len(validation_dataset), shuffle=False
)
ニューラルネットワークで文字を認識する
モデルを定義する
参考:
- Learn the Basics
- Intel Extension for PyTorch - Python examples demonstrate usage of Python APIs
- ipex.optimize
device = "xpu"
class NeuralNetwork(nn.Module):
    """Simple MLP classifier: flatten -> 512 -> 512 -> out_features logits.

    Attribute names (``flatten``, ``linear_relu_stack``) are part of the
    state_dict layout and must not change.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, out_features),
        )

    def forward(self, x):
        # Flatten (N, C, H, W) input to (N, C*H*W) before the dense stack.
        return self.linear_relu_stack(self.flatten(x))
# One input per pixel (90*120 grayscale), one output logit per label.
model = NeuralNetwork(image_width * image_height, len(label_names)).to(device)
print(model)
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=10800, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=15, bias=True)
)
)
損失関数を定義する
# CrossEntropyLoss accepts the one-hot (probability) targets used here.
loss_func = nn.CrossEntropyLoss().to(device)
# Plain SGD; lr=0.2 is tuned for this small MLP.
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
訓練する
from tqdm import tqdm
def train(epoch, dataloader, model, loss_func, optimizer):
    """Run one training epoch.

    Returns per-batch (loss_history, accuracy_history) lists for later
    plotting. Progress is shown with tqdm.
    """
    model.train()
    # Apply IPEX optimizations to the model/optimizer pair.
    model, optimizer = ipex.optimize(model, optimizer=optimizer)
    loss_history, accuracy_history = [], []
    with tqdm(dataloader, unit="batch") as progress:
        progress.set_description(f"Epoch {epoch} Train")
        for features, targets in progress:
            features, targets = features.to(device), targets.to(device)
            pred = model(features)
            # Targets are one-hot, so compare argmax against argmax.
            n_correct = (pred.argmax(dim=1) == targets.argmax(dim=1)).sum().item()
            accuracy = n_correct / len(features)
            accuracy_history.append(accuracy)
            loss = loss_func(pred, targets)
            loss_history.append(loss.item())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            progress.set_postfix(loss=loss.item(), accuracy=100 * accuracy)
    return loss_history, accuracy_history
def test(epoch, dataloader, model, loss_func):
    """Evaluate the model on a test loader.

    Prints the epoch's average loss and overall accuracy, and returns
    the per-batch (loss_history, accuracy_history) lists.
    """
    dataset_size = len(dataloader.dataset)
    batch_count = len(dataloader)
    model.eval()
    model = ipex.optimize(model)
    loss_history, accuracy_history = [], []
    total_loss, total_correct = 0, 0
    with torch.no_grad():
        for features, targets in dataloader:
            features, targets = features.to(device), targets.to(device)
            pred = model(features)
            batch_loss = loss_func(pred, targets).item()
            loss_history.append(batch_loss)
            total_loss += batch_loss
            n_correct = (pred.argmax(dim=1) == targets.argmax(dim=1)).sum().item()
            accuracy_history.append(n_correct / len(features))
            total_correct += n_correct
    total_loss /= batch_count
    total_correct /= dataset_size
    print(
        f"Epoch {epoch} Test Error: Accuracy: {(100*total_correct)}%, Avg loss: {total_loss}"
    )
    return loss_history, accuracy_history
epochs = 3

# Flattened per-batch histories across all epochs, used for plotting.
total_train_loss_history = []
total_train_accuracy_history = []
total_test_loss_history = []
total_test_accuracy_history = []

for epoch in range(epochs):
    # One training pass over the (augmented) training data.
    train_losses, train_accuracies = train(
        epoch, train_loader, model, loss_func, optimizer
    )
    total_train_loss_history.extend(train_losses)
    total_train_accuracy_history.extend(train_accuracies)
    # Evaluate on the held-out test data.
    test_losses, test_accuracies = test(epoch, test_loader, model, loss_func)
    total_test_loss_history.extend(test_losses)
    total_test_accuracy_history.extend(test_accuracies)
Epoch 0 Train: 100%|██████████| 55/55 [00:00<00:00, 86.61batch/s, accuracy=100, loss=0.00905]
Epoch 0 Test Error: Accuracy: 100.0%, Avg loss: 0.007285525789484382
Epoch 1 Train: 100%|██████████| 55/55 [00:00<00:00, 166.94batch/s, accuracy=100, loss=0.00344]
Epoch 1 Test Error: Accuracy: 100.0%, Avg loss: 0.0023368961177766324
Epoch 2 Train: 100%|██████████| 55/55 [00:00<00:00, 168.76batch/s, accuracy=100, loss=0.00142]
Epoch 2 Test Error: Accuracy: 100.0%, Avg loss: 0.001543202588800341
学習曲線を表示する
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2, figsize=(10, 4))
# Left panel: per-batch loss; right panel: per-batch accuracy.
for ax, title, train_hist, test_hist in (
    (axs[0], "loss", total_train_loss_history, total_test_loss_history),
    (axs[1], "accuracy", total_train_accuracy_history, total_test_accuracy_history),
):
    ax.set_title(title)
    ax.plot(train_hist, label="train")
    ax.plot(test_hist, label="test")
    ax.set_xlabel("#batch")
    ax.legend()
plt.tight_layout()
plt.show()
モデルを保存する
torch.save(model.state_dict(), "data/pism/53-2-231-table-1.pth")
保存したモデルを読み込む
# Recreate the architecture, then restore the saved parameters into it.
model = NeuralNetwork(image_width * image_height, len(label_names)).to(device)
model.load_state_dict(torch.load("data/pism/53-2-231-table-1.pth"))
<All keys matched successfully>
読み込んだモデルで検証Datasetのラベルを予測する
def validate(dataloader, model):
    """Predict labels for every batch and print the overall accuracy.

    Mis-predicted rows are printed together with the expected one-hot
    labels for inspection.
    """
    model.eval()
    model = ipex.optimize(model)
    size = len(dataloader.dataset)
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            # Convert logits to the same one-hot format as y.
            y_pred = torch.zeros_like(pred)
            y_pred[torch.arange(len(y_pred)), pred.argmax(dim=1)] = 1
            # A row is correct when the two one-hot vectors coincide (1+1 == 2).
            correct += (y_pred + y == 2).sum().item()
            # BUG FIX: argwhere on a 2-D mask yields (row, column) pairs;
            # the old .sum(dim=1) produced row+column instead of the row
            # index. Take the row component [:, 0] to index the samples.
            wrong_indices = torch.argwhere(y_pred - y == 1)[:, 0]
            if len(wrong_indices) > 0:
                print("wrong indices:", wrong_indices)
                print("predicted:", y_pred[wrong_indices])
                print(
                    "but the truth is:",
                    y[torch.argwhere(y - y_pred == 1)[:, 0]],
                )
    print(f"accuracy: {100*correct/size}% ({correct}/{size})")
# Truncate long tensor printouts.
torch.set_printoptions(threshold=100)
# Check the reloaded model on all three splits.
validate(validation_loader, model)
validate(train_loader, model)
validate(test_loader, model)
accuracy: 100.0% (293/293)
accuracy: 100.0% (1650/1650)
accuracy: 100.0% (289/289)