网络模型(Seq2Seq-注意力机制-编解码)

概念

用于处理序列问题:翻译(N vs M,输入与输出长度可以不同)、信息提取(N vs 1)、生成(1 vs N)。
[图片:网络模型(Seq2Seq-注意力机制-编解码)_第1张图片]
普通 RNN(N vs N 结构)要求输入序列和输出序列等长,Seq2Seq 可以解决输入序列与输出序列不等长的问题。

实验(验证码识别)

数据集:生成 4 位数字的验证码图片(测试集和训练集各 1000 张),图片名称为 index.code.jpg,截取 code 作为标签。

网络结构:

  • 编码:全连接 + 标准化(BN)+ 激活(ReLU)+ LSTM。
  • 解码:LSTM + 全连接 + softmax(多分类)。

优化器:Adam。

损失函数:均方差(MSELoss)。

输出:4 个长度为 10 的向量(与 one-hot 标签对应),每个向量中最大值的索引即为该位识别出的数字。

生成验证码

import os
import random

from PIL import Image, ImageDraw, ImageFilter, ImageFont


def rand_char():
    """Return one random decimal digit as a single-character string."""
    # Draw a code point between those of "0" and "9" (inclusive).
    first, last = ord("0"), ord("9")
    code_point = random.randint(first, last)
    return chr(code_point)


def rand_bg():
    """Return a random background color as an (r, g, b) tuple.

    Each channel is drawn independently from the inclusive range [50, 150].
    """
    low, high = 50, 150
    return tuple(random.randint(low, high) for _ in range(3))


def rand_color():
    """Return a random digit color as an (r, g, b) tuple.

    Each channel is drawn independently from the inclusive range [100, 255].
    """
    red = random.randint(100, 255)
    green = random.randint(100, 255)
    blue = random.randint(100, 255)
    return (red, green, blue)


# Canvas size: four 60-pixel-wide cells, one digit per cell.
width = 240
height = 60
# NOTE(review): "arial.ttf" must be resolvable by Pillow on this machine - confirm.
font = ImageFont.truetype("arial.ttf", size=36)

# Render a separate batch of images for each output directory. Saving the very
# same image into both directories would make the test set a copy of the
# training set, so test accuracy would only measure memorization.
# NOTE(review): the training script in this article reads "data/train_dataset"
# and "data/test_dataset" - confirm the generated files are moved/renamed there.
for out_dir in ("data", "test"):
    # Image.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    for i in range(1000):
        img = Image.new("RGB", (width, height), (255, 255, 255))
        draw = ImageDraw.ImageDraw(img)
        # Paint the background: every pixel gets its own random color.
        for x in range(width):
            for y in range(height):
                draw.point((x, y), rand_bg())
        # Draw the four digits, one per 60-pixel cell.
        chrs = []
        for n in range(4):
            each = rand_char()
            chrs.append(each)
            draw.text((n * 60 + 10, 10), each, rand_color(), font)

        # filter() returns a new image, so rebind the name that is saved below.
        img = img.filter(ImageFilter.BLUR)
        # File name is "<index>.<code>.jpg"; the dataset parses the label from <code>.
        img.save("{}/{}.{}.jpg".format(out_dir, i, "".join(chrs)))

数据集

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import os
from PIL import Image


class MyDataset(Dataset):
    """Captcha image dataset whose labels are parsed from the file names.

    Files are expected to be named ``<index>.<code>.jpg``: the text between the
    first and second dot is the 4-digit code and becomes the label.
    """

    # Number of digits in one captcha and number of classes per digit.
    CODE_LENGTH = 4
    NUM_CLASSES = 10
    # Only directory entries with these suffixes are treated as samples.
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, path):
        """Index the image files located directly inside ``path``."""
        # Convert to a tensor, then normalize each of the three channels
        # with mean 0.5 and std 0.5.
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        # Skip stray entries (e.g. ".DS_Store", sub-directories) whose names
        # cannot be parsed into a label, and sort so that the sample order does
        # not depend on the file system.
        self.imgs = sorted(
            name for name in os.listdir(path)
            if name.lower().endswith(self._IMAGE_SUFFIXES)
        )
        self.path = path

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``(image_tensor, one_hot_label)`` for the file at ``index``."""
        name = self.imgs[index]
        with Image.open(os.path.join(self.path, name)) as img:
            # Force three channels so the 3-channel Normalize always applies.
            img = self.transform(img.convert("RGB"))
        label = self.one_hot(name.split(".")[1])
        return img, label

    def one_hot(self, x):
        """Encode the digit string ``x`` as a (4, 10) one-hot float tensor.

        Row ``i`` has a single 1 at column ``int(x[i])``. Raises ``IndexError``
        when ``x`` has fewer than 4 characters and ``ValueError`` when a
        character is not a digit.
        """
        result = torch.zeros(self.CODE_LENGTH, self.NUM_CLASSES)
        for position in range(self.CODE_LENGTH):
            result[position, int(x[position])] = 1
        return result

网络

import torch
from torch import nn
from torch.nn import functional as f


# Encoder
class Encoder(nn.Module):
    """Encode an image batch into one 128-dimensional vector per image.

    The input is viewed as 240 columns of 180 values each (per the author's
    notes: [n, c, h, w] with c*h = 180 and w = 240, i.e. the horizontal captcha
    is cut vertically). Every column goes through the MLP, the resulting
    240-step sequence goes through the LSTM, and the LSTM output at the last
    step is returned.
    """

    def __init__(self):
        super().__init__()
        # Per-column feature extractor: fully connected -> batch norm -> ReLU.
        self.mlp = nn.Sequential(
            nn.Linear(180, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        self.lstm = nn.LSTM(input_size=128, hidden_size=128, num_layers=2, batch_first=True)

    def forward(self, x):
        """Return a [n, 128] tensor for an input reshapeable to [n, 180, 240]."""
        # [n, c, h, w] -> [n, c*h, w] -> [n, w, c*h]
        columns = x.reshape(-1, 180, 240).transpose(1, 2)
        # [n, w, c*h] -> [n*w, c*h] -> [n*w, 128]: each column is one MLP input.
        features = self.mlp(columns.reshape(-1, 180))
        # [n*w, 128] -> [n, w, 128]: w is the sequence length fed to the LSTM.
        sequence = features.reshape(-1, 240, 128)
        hidden_states, _ = self.lstm(sequence)
        # [n, w, 128] -> [n, 128]: keep only the output of the last step.
        return hidden_states[:, -1]


# Decoder
class Decoder(nn.Module):
    """Decode a 128-dimensional vector into four 10-way probability vectors.

    The same input vector is presented to the LSTM at each of 4 steps; every
    step's output is mapped to 10 values and passed through a softmax.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=128, hidden_size=128, num_layers=2, batch_first=True)
        # Output layer: fully connected, 10 values (one per digit class) per step.
        self.mlp = nn.Linear(128, 10)

    def forward(self, x):
        """Return a [n, 4, 10] tensor whose last axis sums to 1."""
        # [n, 128] -> [n, 1, 128] -> [n, 4, 128]: repeat the vector for 4 steps.
        steps = x.reshape(-1, 1, 128).expand(-1, 4, 128)
        hidden_states, _ = self.lstm(steps)
        # [n, 4, 128] -> [n*4, 128] -> [n*4, 10]
        scores = self.mlp(hidden_states.reshape(-1, 128))
        # [n*4, 10] -> [n, 4, 10], then softmax over the 10 classes of each digit.
        return f.softmax(scores.reshape(-1, 4, 10), dim=2)


# Full network
class MyNet(nn.Module):
    """Chain the encoder and the decoder into a single module."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Pass ``x`` through the encoder and feed its output to the decoder."""
        return self.decoder(self.encoder(x))

训练

from dataset import MyDataset
from net import MyNet

import torch
from torch import nn
from torch.utils.data import DataLoader
import os
import numpy as np


batch_size = 100
net_path = r"modules/mynet.pth"

# Switch between training (True) and evaluation (False).
is_train = True

# Dataset locations.
train_path = r"data/train_dataset"
test_path = r"data/test_dataset"

# Training shuffles the samples and loads them with 4 worker processes;
# evaluation keeps the dataset order and uses the DataLoader defaults.
dataset = MyDataset(train_path if is_train else test_path)
_loader_kwargs = {"shuffle": True, "num_workers": 4} if is_train else {"shuffle": False}
dataloader = DataLoader(dataset, batch_size, **_loader_kwargs)


if __name__ == '__main__':
    # Resume from the saved network when a checkpoint file exists; otherwise
    # start from a freshly initialized one.
    # NOTE(review): the checkpoint is the whole pickled module (see torch.save
    # below). Newer PyTorch releases restrict what torch.load unpickles by
    # default (weights_only) - confirm this call still works on the target
    # version, or migrate to saving/loading state_dict().
    if os.path.isfile(net_path):
        net = torch.load(net_path)
    else:
        net = MyNet()
    opt = torch.optim.Adam([{"params": net.encoder.parameters()}, {"params": net.decoder.parameters()}])
    loss_fn = nn.MSELoss()

    if is_train:
        # torch.save does not create missing directories, so make sure the
        # checkpoint directory exists before the first epoch finishes.
        save_dir = os.path.dirname(net_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # Training: runs until interrupted, saving the network after every
        # full pass over the data loader.
        net.train()
        while True:
            for i, (x, y) in enumerate(dataloader):
                out = net(x)
                loss = loss_fn(out, y)
                opt.zero_grad()
                loss.backward()
                opt.step()
                # Outputs and labels are one-hot shaped: take the index of the
                # maximum along the class axis to get the predicted digits.
                result = torch.argmax(out, 2).numpy()
                label = torch.argmax(y, 2).numpy()
                # A sample counts as correct only when all 4 digits match.
                acc = np.mean(np.all(result == label, axis=1))
                print("i:{},loss:{:.5},acc:{:.3}".format(i, loss.item(), acc))
            # Save the network.
            torch.save(net, net_path)
    else:
        # Evaluation: no gradients are needed, so skip building autograd graphs.
        net.eval()
        with torch.no_grad():
            for x, y in dataloader:
                out = net(x)
                # Only the first sample of each batch is printed.
                result = torch.argmax(out[0], 1)
                print("result:{}".format(result))
                label = torch.argmax(y[0], 1)
                print("label:{}".format(label))

你可能感兴趣的:(AI)