RNN 可用于处理多种序列问题:序列标注(N vs N)、信息提取/分类(N vs 1)、序列生成(1 vs N)。
普通 RNN 要求输入序列和输出序列等长;Seq2Seq(编码器-解码器)结构可以解决输入序列与输出序列不等长(N vs M,例如机器翻译)的问题。
数据集:生成 4 位数字的验证码图片(240×60 像素,训练集和测试集各 1000 张),文件名格式为 `{index}.{code}.jpg`,读取时取文件名中间的 code 作为标签。
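网络结构:
- 编码器:
  - 把图片按列切开:每一列有 3×60=180 个值,共 240 个时间步。
  - 每一列先经过“全连接 + BN + ReLU”映射为 128 维特征。
  - 再经过 2 层 LSTM,取最后一个时间步的输出作为整张图片的编码。
- 解码器:
  - 把该编码重复 4 次,作为 2 层 LSTM 的输入。
  - 每一步经全连接层输出 10 类得分,再做 softmax。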
优化器:Adam。
损失函数:均方误差(MSELoss),在 softmax 输出与 one-hot 标签之间计算。
输出:4 个 10 维概率向量(每位数字一个,与 one-hot 标签对应),取每个向量中最大值的索引作为该位的预测数字。
生成验证码
import os
import random

from PIL import Image, ImageDraw, ImageFilter, ImageFont
# Random digit for one captcha position.
def rand_char():
    """Return one random decimal digit ('0'-'9') as a one-character string."""
    digit = random.randint(0, 9)
    return str(digit)
# Random background colour for the noise pixels.
def rand_bg():
    """Return an (r, g, b) tuple; each channel is uniform in 50..150 inclusive."""
    return tuple(random.randint(50, 150) for _channel in range(3))
# Random colour for the digit glyphs.
def rand_color():
    """Return an (r, g, b) tuple; each channel is uniform in 100..255 inclusive."""
    low, high = 100, 255
    red = random.randint(low, high)
    green = random.randint(low, high)
    blue = random.randint(low, high)
    return (red, green, blue)
# Captcha geometry. The encoder reshapes every image to (3 * 60, 240), so the
# size must stay 240 x 60; each of the 4 digits gets its own 60 px-wide cell.
width = 240
height = 60
font = ImageFont.truetype("arial.ttf", size=36)

# Output directories; these match train_path / test_path in the training script.
train_dir = "data/train_dataset"
test_dir = "data/test_dataset"


def _make_captcha():
    """Draw one random 4-digit captcha and return ``(image, code)``."""
    img = Image.new("RGB", (width, height), (255, 255, 255))
    # Noisy background: an independent random colour for every pixel
    # (a single putdata call instead of width * height draw.point calls).
    img.putdata([rand_bg() for _ in range(width * height)])
    draw = ImageDraw.Draw(img)
    # Write the digits, one per 60 px cell.
    chrs = []
    for n in range(4):
        each = rand_char()
        chrs.append(each)
        draw.text((n * 60 + 10, 10), each, rand_color(), font)
    # Blur the finished image; filter() returns a new image, so rebind img.
    img = img.filter(ImageFilter.BLUR)
    return img, "".join(chrs)


# Generate two independent sets (1000 images each) so the test set is not a
# copy of the training set.
for out_dir in (train_dir, test_dir):
    os.makedirs(out_dir, exist_ok=True)
    for i in range(1000):
        img, code = _make_captcha()
        # File name "{index}.{code}.jpg": the dataset parses the label from it.
        img.save(os.path.join(out_dir, "{}.{}.jpg".format(i, code)))
数据集(dataset.py)
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import os
from PIL import Image
class MyDataset(Dataset):
    """Captcha dataset read from a directory of ``{index}.{code}.jpg`` files.

    Each item is ``(image, label)``:
    - ``image`` is the normalised image tensor.
    - ``label`` is a 4x10 one-hot tensor, one row per digit of ``code``.
    """

    # File suffixes treated as captcha images; other directory entries
    # (hidden/system files, notes, ...) are skipped instead of crashing the
    # label parsing.
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, path):
        # ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) maps
        # each channel to [-1, 1].
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        # Sorted because os.listdir order is arbitrary; this makes the
        # unshuffled (test) order reproducible.
        self.imgs = sorted(
            name for name in os.listdir(path)
            if name.lower().endswith(self._IMAGE_SUFFIXES)
        )
        self.path = path

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        name = self.imgs[index]
        # The context manager releases the file handle. convert("RGB")
        # guarantees the 3 channels that the 3-channel Normalize expects.
        with Image.open(os.path.join(self.path, name)) as pil_img:
            img = self.transform(pil_img.convert("RGB"))
        # File name is "{index}.{code}.jpg": the middle field is the label.
        label = self.one_hot(name.split(".")[1])
        return img, label

    # Convert the label string to one-hot format
    def one_hot(self, x):
        """Return a (4, 10) float tensor with a 1 at [i, digit_i] for the first 4 digits of ``x``.

        Raises IndexError if ``x`` has fewer than 4 characters and ValueError
        if one of them is not a decimal digit.
        """
        result = torch.zeros(4, 10)
        for i in range(4):
            result[i][int(x[i])] = 1
        return result
网络(net.py)
import torch
from torch import nn
from torch.nn import functional as f
# Encoder
class Encoder(nn.Module):
    """Encode a captcha image into one 128-dim feature vector per sample.

    The image is read column by column: 240 time steps of 180 values each.
    Every column goes through an MLP, then the sequence runs through a 2-layer
    LSTM whose last output is returned.
    """

    def __init__(self):
        super().__init__()
        # Per-column projection: Linear -> BatchNorm -> ReLU, 180 -> 128.
        self.mlp = nn.Sequential(
            nn.Linear(180, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        # Two stacked LSTM layers over the column sequence.
        self.lstm = nn.LSTM(128, 128, 2, batch_first=True)

    def forward(self, x):
        # Assumes x is [n, 3, 60, 240] (TODO confirm); only the per-sample
        # element count 180 * 240 is actually enforced by the reshape.
        # [n, c*h, w] -> [n, w, c*h] -> [n*w, c*h]: one row per image column,
        # because the captcha reads horizontally.
        columns = x.reshape(-1, 180, 240).permute(0, 2, 1).reshape(-1, 180)
        features = self.mlp(columns)
        # Back to a sequence of 240 steps: [n*w, 128] -> [n, w, 128].
        sequence = features.reshape(-1, 240, 128)
        hidden_states, _ = self.lstm(sequence)
        # Keep only the final time step: [n, w, 128] -> [n, 128].
        return hidden_states[:, -1]
# Decoder
class Decoder(nn.Module):
    """Decode a 128-dim feature vector into 4 per-digit probability vectors."""

    def __init__(self):
        super().__init__()
        # Two stacked LSTM layers, unrolled for 4 steps (one per digit).
        self.lstm = nn.LSTM(128, 128, 2, batch_first=True)
        # Output layer: Linear 128 -> 10 per step (one score per digit 0-9).
        # Keep the attribute name `mlp`: the training script saves the whole
        # pickled module, so renaming it would break loading saved models.
        self.mlp = nn.Linear(128, 10)

    def forward(self, x):
        # [n, 128] -> [n, 1, 128] -> [n, 4, 128]: the same encoding is the
        # input at each of the 4 decoder steps.
        steps = x.reshape(-1, 1, 128).expand(-1, 4, 128)
        hidden_states, _ = self.lstm(steps)
        # Score every step with the shared linear layer: [n*4, 128] -> [n*4, 10].
        logits = self.mlp(hidden_states.reshape(-1, 128))
        # [n*4, 10] -> [n, 4, 10], then softmax over the 10 classes (dim 2).
        return logits.reshape(-1, 4, 10).softmax(dim=2)
# Main network: encoder followed by decoder
class MyNet(nn.Module):
    """Seq2Seq captcha recogniser: image -> [n, 4, 10] digit probabilities."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))
训练
from dataset import MyDataset
from net import MyNet
import torch
from torch import nn
from torch.utils.data import DataLoader
import os
import numpy as np
# Hyper-parameters and checkpoint location.
batch_size = 100
net_path = r"modules/mynet.pth"
# True: train on train_path; False: evaluate on test_path.
is_train = True

# Dataset directories.
train_path = r"data/train_dataset"
test_path = r"data/test_dataset"

# Training shuffles and loads with 4 worker processes. Evaluation keeps the
# dataset order and loads in the main process (num_workers=0 is the default).
dataset = MyDataset(train_path if is_train else test_path)
dataloader = DataLoader(
    dataset,
    batch_size,
    shuffle=is_train,
    num_workers=4 if is_train else 0,
)
if __name__ == '__main__':
    # Load a previously saved network if one exists, otherwise start fresh.
    if os.path.isfile(net_path):
        # The checkpoint is the whole pickled nn.Module (see torch.save below),
        # so weights_only must be False: PyTorch >= 2.6 defaults it to True and
        # would refuse to unpickle MyNet. Only load checkpoints you created.
        # Everything here runs on CPU, hence map_location="cpu".
        net = torch.load(net_path, map_location="cpu", weights_only=False)
    else:
        net = MyNet()
    opt = torch.optim.Adam([{"params": net.encoder.parameters()}, {"params": net.decoder.parameters()}])
    loss_fn = nn.MSELoss()

    if is_train:
        # Training: runs until interrupted, saving a checkpoint after every epoch.
        net.train()
        # torch.save does not create missing directories.
        save_dir = os.path.dirname(net_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        while True:
            for i, (x, y) in enumerate(dataloader):
                out = net(x)
                loss = loss_fn(out, y)
                opt.zero_grad()
                loss.backward()
                opt.step()
                # Outputs/labels are [n, 4, 10]; the predicted digit is the
                # argmax over the 10 classes. A sample counts as correct only
                # if all 4 digits match.
                result = torch.argmax(out, 2).numpy()
                label = torch.argmax(y, 2).numpy()
                acc = np.mean(np.all(result == label, axis=1))
                print("i:{},loss:{:.5},acc:{:.3}".format(i, loss.item(), acc))
            # Save the network
            torch.save(net, net_path)
    else:
        # Testing: no gradients are needed.
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in dataloader:
                out = net(x)
                # Show the prediction and the label for the first sample of each batch.
                result = torch.argmax(out[0], 1)
                print("result:{}".format(result))
                label = torch.argmax(y[0], 1)
                print("label:{}".format(label))
                # Count samples whose 4 digits are all predicted correctly.
                hits = (torch.argmax(out, 2) == torch.argmax(y, 2)).all(dim=1)
                correct += int(hits.sum().item())
                total += y.size(0)
        if total:
            print("test acc:{:.3}".format(correct / total))