blog

PDFから表の画像を取り出してDataFrameにする(2)

Published:

By nob

Category: Posts

Tags: 文字認識 OCR Intel-Extension-for-PyTorch PyTorch CNN OpenCV Pillow Python

前回:「PDFから表の画像を取り出してDataFrameにする」 では文字認識にOCRライブラリである Tesseract を使った。

今回は PyTorch を使う。

前回Tesseractを使ったのだが文字ごとに個別の前処理(フィルタ)を実装するのが面倒だった。

tesstrain を使ってTesseractをTrainingするというアイデアもあったのだが、PyTorchに興味があるので使ってみる。

私のPCにはIntel Arc AシリーズのGPUが搭載されているのでIntel Extension for PyTorch(2.1.30+xpu)を使った。インストール手順は Welcome to Intel® Extension for PyTorch* Documentation! 記載の通り。

Intel Extension for PyTorchを有効にする

作業の前にIntel Extension for PyTorchを有効にしておく。

import torch

import intel_extension_for_pytorch as ipex  # isort: skip

警告メッセージのフォーマットを変更する

import re
import sys


def custom_warning(message, category, filename, lineno, file=None, line=None):
    filename = re.sub("^.*?lib", "{virtualenv}/lib", filename)
    sys.stderr.write(
        "{filename}:{lineno}\n{category}:{message}\n".format(
            message=message,
            category=category,
            filename=filename,
            lineno=lineno,
        )
    )

データを準備する

前回 保存した画像ファイルを使う。

表には2桁の数字が含まれるので、画像を1文字づつに分割して文字認識する。

参考:

import os

import cv2
import numpy as np
from PIL import Image
from torchvision.io import read_image
from torchvision.transforms.v2 import functional


def load_image(annotation_files, label_names, image_dir):

    data = []
    labels = []
    num_chars = []
    datasets = []

    for annotation_file in annotation_files:
        with open(annotation_file, "r") as f:
            while line := f.readline():
                file_name, label_name = line.strip().split()
                datasets.append([file_name, label_name])

    for file_name, label_name in datasets:
        if file_name.endswith("-0000.png"):
            # 1番目のカラムは読み飛ばす
            continue

        # 画像に含まれる文字
        chars = list(label_name)

        # 画像に含まれる文字数
        num_chars.append(len(chars))

        # ラベルの分割
        for char in chars:
            if char not in label_names:
                continue

            # ラベルをOne-Hotエンコーディングする
            label = torch.zeros(len(label_names))
            label_index = label_names.index(char)
            label[label_index] = 1
            labels.append(label)

        # 画像の読み込み
        image_path = os.path.join(image_dir, file_name)
        image = read_image(image_path)

        tmp = functional.to_pil_image(image)
        tmp = np.array(tmp)

        # 画像の分割
        # 少し線を太くする
        dilated = cv2.dilate(tmp, None, iterations=1)
        # 輪郭を抽出する
        contours, _ = cv2.findContours(
            dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        # バウンディングボックスを求める
        bboxes = []
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            bboxes.append([x, y, w, h])

        if len(chars) != len(bboxes):
            raise ValueError(
                "number of bounding box should be {} but found {}. image: {}".format(
                    len(chars), len(bboxes), file_name
                )
            )

        if len(chars) > 1:
            # "10"の画像を読み込んで"0","1"と輪郭が取れるので
            # 輪郭をx軸でソートする
            bboxes = sorted(bboxes, key=lambda bbox: bbox[0])

        for x, y, w, h in bboxes:
            data.append(
                functional.pil_to_tensor(
                    Image.fromarray(tmp[y : y + h, x : x + w])
                )
            )

    return data, labels, num_chars
teams = ["a", "b", "c"]

label_names = [str(i) for i in range(10)]
label_names.extend(["+", "-", "/", "n", "N"])

image_dir = "data/pism/image"
annotation_files = [
    os.path.join(image_dir, f"table-1-{team}-label.txt") for team in teams
]
annotation_files
['data/pism/image/table-1-a-label.txt',
 'data/pism/image/table-1-b-label.txt',
 'data/pism/image/table-1-c-label.txt']
data, labels, num_chars = load_image(annotation_files, label_names, image_dir)
len(data), len(labels), len(num_chars), sum(num_chars)
(1042, 1042, 980, 1042)

読み込んだデータを表示してみる

import matplotlib.pyplot as plt

fig, axs = plt.subplots(28, 40, figsize=(15, 10))
axs = axs.flatten()

i = j = 0

for k, num_char in enumerate(num_chars):
    # 2桁の数字が表示される列の場合
    if k % 35 in [0, 31, 32, 33, 34]:
        if num_char == 1:
            # 1桁ならダミーの画像を追加して表示する
            axs[j].imshow(Image.fromarray(np.zeros((1, 1))), cmap=plt.cm.gray)
            axs[j].set_axis_off()
            j += 1
        else:
            # 2桁なら画像を2つとも表示する
            axs[j].imshow(functional.to_pil_image(data[i]), cmap=plt.cm.gray)
            axs[j].set_axis_off()
            i += 1
            j += 1
    axs[j].imshow(functional.to_pil_image(data[i]), cmap=plt.cm.gray)
    axs[j].set_axis_off()
    i += 1
    j += 1


plt.tight_layout()
plt.show()

fig-1

データの変換処理を定義する

参考:

from nobituk import PillowPad, Trim
from torchvision.transforms import v2

image_width = 90
image_height = 120

transforms = v2.Compose(
    [
        Trim(),
        PillowPad(size=(image_width, image_height)),
        v2.Grayscale(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

Datasetを作成する

from torch.utils.data import Dataset


class TransformDataset(Dataset):
    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(data)

    def __getitem__(self, i):
        x = self.data[i]
        if self.transform:
            x = self.transform(self.data[i])
        t = self.target[i]
        if self.target_transform:
            t = self.target_transform(self.target[i])
        return x, t
dataset = TransformDataset(data, labels, transforms)

Datasetを分割する

from torch.utils.data import random_split

train_size = int(0.6 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

DataLoaderを作成する

from torch.utils.data import DataLoader

batch_size = 30
num_workers = 2

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

test_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

画像から文字を認識する

モデルを作成する

「PyTorchで文字認識」 ではニューラルネットワークを使ったが、今回はCNN(Convolutional Neural Network, 畳み込みニューラルネットワーク)を使う。

  • 入力

    90x120、1チャネル(白黒)

  • 出力

    0-9, -, N, n, /, + (15ラベル)

  • モデル

    • 畳み込み (入力チャネル: 1, 出力チャネル: 32, カーネル: (3, 3))
    • 活性化 (ReLU)
    • 最大値プーリング (カーネル: (2, 2))
    • 畳み込み (入力チャネル: 32, 出力チャネル: 64, カーネル: (3, 3))
    • 活性化 (ReLU)
    • 最大値プーリング (カーネル: (2, 2))
    • 全結合 (入力: n , 出力: 2048)
    • 活性化 (ReLU)
    • 全結合 (入力: 2048, 出力: 256)
    • 活性化 (ReLU)
    • 全結合 (入力: 256, 出力: 15)

全結合層の入力 n は、畳み込みや最大値プーリングのカーネル・ストライド・パディング等から以下のように順を追って計算する必要がある。

  • 入力

    1 * 90 * 120

  • 畳み込み

    32 * (90-3+1) * (120-3+1) = 32 * 88 * 118

  • 最大値プーリング

    32 * (88//2) * (118//2) = 32 * 44 * 59

  • 畳み込み

    64 * (44-3+1) * (59-3+1) = 64 * 42 * 57

  • 最大値プーリング

    64 * (42//2) * (57//2) = 64 * 21 * 28

  • 全結合層の入力 n

    64 * 21 * 28

しかし torch.nn.LazyLinear を使うと初回の入力時に入力サイズが自動的に計算されるため入力サイズの指定は不要である。

from torch import nn


class CNN(nn.Module):

    def __init__(self):

        super().__init__()

        conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=(3, 3),
            ),
            nn.ReLU(),
            nn.MaxPool2d(
                kernel_size=(2, 2),
            ),
        )

        conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=(3, 3),
            ),
            nn.ReLU(),
            nn.MaxPool2d(
                kernel_size=(2, 2),
            ),
        )

        convolve = nn.Sequential(
            conv1,
            conv2,
        )

        connect = nn.Sequential(
            nn.LazyLinear(2048),
            nn.ReLU(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 15),
        )

        self.layers = nn.Sequential(
            convolve,
            nn.Flatten(),
            connect,
        )

    def forward(self, x):
        return self.layers(x)
import warnings

device = "xpu"

with warnings.catch_warnings(
    action="default",
):
    warnings.showwarning = custom_warning

    model = CNN().to(device)
{virtualenv}/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180
<class 'UserWarning'>:Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
print(model)
CNN(
  (layers): Sequential(
    (0): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      )
      (1): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Sequential(
      (0): LazyLinear(in_features=0, out_features=2048, bias=True)
      (1): ReLU()
      (2): Linear(in_features=2048, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=15, bias=True)
    )
  )
)

損失関数を定義する

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

訓練する

from tqdm import tqdm


def train(epoch, dataloader, model, criterion, optimizer):

    model.train()
    model, optimizer = ipex.optimize(model, optimizer=optimizer)

    train_loss_history = []
    train_accuracy_history = []

    with tqdm(dataloader, unit="batch") as tbat:
        tbat.set_description(f"Epoch {epoch} Train")
        for images, labels in tbat:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            loss = criterion(outputs, labels)
            train_loss_history.append(loss.item())
            loss.backward()

            correct = (
                (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
            )
            accuracy = correct / len(images)
            train_accuracy_history.append(accuracy)

            optimizer.step()
            optimizer.zero_grad()

            tbat.set_postfix(loss=loss.item(), accuracy=100 * accuracy)

    return train_loss_history, train_accuracy_history
def test(epoch, dataloader, model, criterion):

    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    model.eval()
    model = ipex.optimize(model)

    test_loss, test_correct = 0, 0
    test_loss_history = []
    test_accuracy_history = []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            loss = criterion(outputs, labels).item()
            test_loss_history.append(loss)
            test_loss += loss

            correct = (
                (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
            )
            test_accuracy_history.append(correct / len(images))
            test_correct += correct

    test_correct /= size
    test_loss /= num_batches
    print(
        f"Epoch {epoch} Test Error: Accuracy: {(100*test_correct)}%, Avg loss: {test_loss}"
    )

    return test_loss_history, test_accuracy_history
epochs = 5
total_train_loss_history = []
total_train_accuracy_history = []
total_test_loss_history = []
total_test_accuracy_history = []

for epoch in range(epochs):

    # train
    epoch_loss_history, epoch_accuracy_history = train(
        epoch,
        train_loader,
        model,
        criterion,
        optimizer,
    )

    total_train_loss_history.extend(epoch_loss_history)
    total_train_accuracy_history.extend(epoch_accuracy_history)

    # test
    epoch_loss_history, epoch_accuracy_history = test(
        epoch,
        test_loader,
        model,
        criterion,
    )

    total_test_loss_history.extend(epoch_loss_history)
    total_test_accuracy_history.extend(epoch_accuracy_history)
Epoch 0 Train: 100%|██████████| 21/21 [00:02<00:00,  9.10batch/s, accuracy=100, loss=0.00975]


Epoch 0 Test Error: Accuracy: 98.56%, Avg loss: 0.05610149020018677


Epoch 1 Train: 100%|██████████| 21/21 [00:01<00:00, 11.60batch/s, accuracy=100, loss=0.0109]


Epoch 1 Test Error: Accuracy: 99.83999999999999%, Avg loss: 0.006129331724271954


Epoch 2 Train: 100%|██████████| 21/21 [00:01<00:00, 11.60batch/s, accuracy=100, loss=0.00412]


Epoch 2 Test Error: Accuracy: 99.83999999999999%, Avg loss: 0.009472846632152573


Epoch 3 Train: 100%|██████████| 21/21 [00:01<00:00, 11.61batch/s, accuracy=100, loss=0.00361]


Epoch 3 Test Error: Accuracy: 99.83999999999999%, Avg loss: 0.0019177654740772034


Epoch 4 Train: 100%|██████████| 21/21 [00:01<00:00, 11.59batch/s, accuracy=100, loss=1.92e-6]


Epoch 4 Test Error: Accuracy: 100.0%, Avg loss: 0.00019342863193319966

学習曲線を表示する

fig, axs = plt.subplots(1, 2, figsize=(10, 4))

axs[0].set_title("loss")
axs[0].plot(total_train_loss_history, label="train")
axs[0].plot(total_test_loss_history, label="test")
axs[0].set_xlabel("#batch")
axs[0].legend()

axs[1].set_title("accuracy")
axs[1].plot(total_train_accuracy_history, label="train")
axs[1].plot(total_test_accuracy_history, label="test")
axs[1].set_xlabel("#batch")
axs[1].legend()

plt.tight_layout()
plt.show()

fig-2

モデルを保存する

torch.save(model.state_dict(), "data/pism/53-2-231-table-1.cnn.pth")

モデルを読み込む

model.load_state_dict(torch.load("data/pism/53-2-231-table-1.cnn.pth"))
<All keys matched successfully>

読み込んだモデルでDataset全体のラベルを予測する

def predict(dataloader, model):
    model.eval()
    model = ipex.optimize(model)
    result = []
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            result.extend(outputs.tolist())
    return result
result = predict(dataloader, model)

DataFrameに変換する

PyTorchの予測結果を変換する

予測結果はこのような形式なので

result[0]
[-14.32275390625,
 20.943679809570312,
 -1.1709486246109009,
 -6.548696517944336,
 -7.440320014953613,
 1.206632375717163,
 -12.301752090454102,
 6.903934001922607,
 -6.3719305992126465,
 -2.277395725250244,
 -1.5444855690002441,
 -10.331376075744629,
 -6.116647720336914,
 -14.116082191467285,
 -20.127317428588867]

最大値を調べてラベルに変換する。

decoded_result = []
for r in result:
    decoded_result.append(label_names[r.index(max(r))])

予測結果は1文字ごとの予測であるので、2桁の数字のカラムは2文字分を結合する。

shift_table_data = []

i = 0

staff_shift_table_data = []
for j, num_char in enumerate(num_chars):

    if j > 0 and j % 35 == 0:
        shift_table_data.append(staff_shift_table_data)
        staff_shift_table_data = []

    if j % 35 in [0, 31, 32, 33, 34]:
        if num_char == 1:
            staff_shift_table_data.append(int(decoded_result[i]))
        else:
            num = decoded_result[i]
            i += 1
            num += decoded_result[i]
            staff_shift_table_data.append(int(num))
    else:
        staff_shift_table_data.append(decoded_result[i])

    i += 1

shift_table_data.append(staff_shift_table_data)

DataFrameに変換する

import pandas as pd

pd.options.display.max_columns = 40
pd.options.display.max_rows = 100

shift_table = pd.DataFrame(
    shift_table_data,
    columns=list(range(1, len(shift_table_data[0]) + 1)),
)
shift_table
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
0 1 n / / - N n / - / - - + N n / - - N n / - - N n / - / N n / 9 9 5 1
1 2 / N n / - - N n / / - N n / - - N n / - - N n / - - / / - - 9 11 5 0
2 3 N n / / - N n / - - N n / - - - / / - - - / - N n / - - / N 9 12 5 0
3 4 / - / / / - - - N n / - - N n / - - - / N n / - - N n / - / 10 12 4 0
4 5 - - N n / - - N n / - - / / N n / / - N n / / / N n / - - - 10 10 5 0
5 6 / - - N n / - - - N n / - - - N n / - - / - / / - - N n / / 9 13 4 0
6 7 - / - - / - + N n / / - N n / / / - N n / - - - / - - - N n 9 12 4 1
7 8 - N n / - + - + N n / - / N n / - - / N n / / / N n / / - - 10 8 5 2
8 9 - / / - N n / / - - / + - - N n / - - / - N n / / - N n / - 10 11 4 1
9 10 / / - N n / / - - N n / - - - / / N n / / - - - / - - - N n 10 12 4 0
10 11 / - - N n / / - N n / + - N n / - - - N n / / / - / / - N n 10 9 5 1
11 12 N n / / - - - / - N n / / - - N n / / / / - - - N n / - - - 10 12 4 0
12 13 / N n / - - N n / / - N n / / - N n / - / - - N n / - - / - 10 10 5 0
13 14 - - N n / - - - / - - - / / / / / N n / - - N n / - - / - N 10 13 4 0
14 15 n / / / - - + N n / / - - - N n / - - / N n / - - - / N n / 10 10 4 1
15 16 N n / - - N n / - - N n / - - / / / - - - N n / / - N n / / 10 10 5 0
16 17 - / - - N n / - / / - + N n / - - - N n / - - - / N n / / / 10 11 4 1
17 18 / - - / / N n / - - / - - + - N n / - - N n / / - N n / - / 10 11 4 1
18 19 - / / - / - - - / / / + - + - - - / - - - / - / - - - / - - 10 18 0 2
19 20 - - N n / / - N n / - - / - N n / - - - N n / / - N n / / - 9 11 5 0
20 21 - - - N n / - - N n / - - N n / / / N n / / - - / - N n / / 10 10 5 0
21 22 - N n / - - N n / / - N n / / - N n / / - N n / - - - / / - 10 10 5 0
22 23 N n / / - - + / - N n / - - / N n / - / - - / N n / - - N n 9 10 5 1
23 24 - - N n / - + - / / N n / / - - - / - - / - N n / - - N n / 9 12 4 1
24 25 n / / - - N n / - - / - N n / / - N n / - / - N n / / - - N 10 10 5 0
25 26 / / - - N n + / - - N n / - - / / - - N n / - - N n / / - - 9 12 4 1
26 27 n / / / - - N n / - - + - / - - N n / - + / N n / - - N n / 9 10 4 2
27 28 / / - - / - - - / / - N n / / - - - N n / - / - - - / - - N 10 15 3 0

結果を確認する

正しい勤務表とPyTorchで予測した勤務表を出力する。

answer = pd.read_csv(
    "data/pism/53-2-231-table-1.csv",
    header=None,
    skiprows=1,
    names=list(range(1, 36)),
)
answer
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
0 1 n / / - N n / - / - - + N n / - - N n / - - N n / - / N n / 9 9 5 1
1 2 / N n / - - N n / / - N n / - - N n / - - N n / - - / / - - 9 11 5 0
2 3 N n / / - N n / - - N n / - - - / / - - - / - N n / - - / N 9 12 5 0
3 4 / - / / / - - - N n / - - N n / - - - / N n / - - N n / - / 10 12 4 0
4 5 - - N n / - - N n / - - / / N n / / - N n / / / N n / - - - 10 10 5 0
5 6 / - - N n / - - - N n / - - - N n / - - / - / / - - N n / / 9 13 4 0
6 7 - / - - / - + N n / / - N n / / / - N n / - - - / - - - N n 9 12 4 1
7 8 - N n / - + - + N n / - / N n / - - / N n / / / N n / / - - 10 8 5 2
8 9 - / / - N n / / - - / + - - N n / - - / - N n / / - N n / - 10 11 4 1
9 10 / / - N n / / - - N n / - - - / / N n / / - - - / - - - N n 10 12 4 0
10 11 / - - N n / / - N n / + - N n / - - - N n / / / - / / - N n 10 9 5 1
11 12 N n / / - - - / - N n / / - - N n / / / / - - - N n / - - - 10 12 4 0
12 13 / N n / - - N n / / - N n / / - N n / - / - - N n / - - / - 10 10 5 0
13 14 - - N n / - - - / - - - / / / / / N n / - - N n / - - / - N 10 13 4 0
14 15 n / / / - - + N n / / - - - N n / - - / N n / - - - / N n / 10 10 4 1
15 16 N n / - - N n / - - N n / - - / / / - - - N n / / - N n / / 10 10 5 0
16 17 - / - - N n / - / / - + N n / - - - N n / - - - / N n / / / 10 11 4 1
17 18 / - - / / N n / - - / - - + - N n / - - N n / / - N n / - / 10 11 4 1
18 19 - / / - / - - - / / / + - + - - - / - - - / - / - - - / - - 10 18 0 2
19 20 - - N n / / - N n / - - / - N n / - - - N n / / - N n / / - 9 11 5 0
20 21 - - - N n / - - N n / - - N n / / / N n / / - - / - N n / / 10 10 5 0
21 22 - N n / - - N n / / - N n / / - N n / / - N n / - - - / / - 10 10 5 0
22 23 N n / / - - + / - N n / - - / N n / - / - - / N n / - - N n 9 10 5 1
23 24 - - N n / - + - / / N n / / - - - / - - / - N n / - - N n / 9 12 4 1
24 25 n / / - - N n / - - / - N n / / - N n / - / - N n / / - - N 10 10 5 0
25 26 / / - - N n + / - - N n / - - / / - - N n / - - N n / / - - 9 12 4 1
26 27 n / / / - - N n / - - + - / - - N n / - + / N n / - - N n / 9 10 4 2
27 28 / / - - / - - - / / - N n / / - - - N n / - / - - - / - - N 10 15 3 0
answer.compare(shift_table)