这其实是一个多标签分类问题,每个验证码图片有4个字符(标签),并且顺序固定;只要将卷积神经网络的最后一层稍加修改就能实现多标签分类。
如下图所示,我们的验证码一共有4个数字,每个数字编码为10位one-hot向量,4个数字拼接成40位:输出层的第[0-9]个输出值对应第一个字符的one-hot编码,[10-19]对应第二个字符,[20-29]对应第三个字符,[30-39]对应第四个字符;损失函数使用PyTorch的多标签分类损失nn.MultiLabelSoftMarginLoss。
训练集800张图片,测试集200张,每张图片为20×60(高×宽)的灰度图。
模型结构:
CNN (
(conv1): Sequential (
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU ()
(2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
)
(conv2): Sequential (
(0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU ()
(2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
)
(out): Linear (624 -> 40)
)
# coding: utf-8
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import math
import csv
import cv2
# Read the labels: each CSV row is expected to hold (image file stem, label
# string); only the first two columns are kept.
# NOTE(review): the label is assumed to be a 4-digit string (it is indexed
# character by character later) -- confirm against the generator of label.csv.
lables = []
with open('GenPics/label.csv') as csvfile:
    reader = csv.reader(csvfile)
    for line in reader:
        # csv.reader yields [] for blank lines (e.g. a trailing newline);
        # skip them instead of crashing with IndexError on line[0].
        if not line:
            continue
        tmpLine = [line[0], line[1]]
        lables.append(tmpLine)
# Load every image listed in the label file as a grayscale array, keeping
# images (X) and their label strings (y) in the same order.
X = []
y = []
picnum = len(lables)
print("picnum : ", picnum)
for img_stem, img_label in lables:
    img_name = "GenPics/" + img_stem + '.jpg'
    img = cv2.imread(img_name, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising;
    # stop here so a missing/corrupt file does not surface later as an
    # obscure dtype error in np.array / torch.from_numpy.
    if img is None:
        raise IOError("cannot read image: " + img_name)
    X.append(img)
    y.append(img_label)
# Flatten the label strings into one list of ints: the first four characters
# of label 0, then the first four of label 1, and so on.
tmp = [int(label[pos]) for label in y for pos in range(4)]
# Convert the images to a float tensor of shape (N, 1, H, W) scaled to [0, 1],
# and the flat digit list to one-hot targets of shape (N, 40).
X = np.array(X)
X = torch.from_numpy(X)
X = torch.unsqueeze(X, dim=1)  # add the channel axis expected by Conv2d
X = X.type(torch.FloatTensor)/255.
# One row per digit (4 per image). Derive the row count from the data instead
# of the former hard-coded 4000, which only matched a 1000-image dataset.
batch_size = len(tmp)
yt = torch.LongTensor(tmp)
yt = torch.unsqueeze(yt, 1)
yt_onehot = torch.zeros(batch_size, 10)
yt_onehot.scatter_(1, yt, 1)
# Regroup the 4 consecutive 10-way one-hot rows of each image into a single
# 40-wide target row.
yt_onehot = yt_onehot.view(-1, 40)
y = yt_onehot
# Split the data: the first 800 samples train the model, the rest are held
# out for testing (e.g. 800/200 for a 1000-image dataset).
train_x, test_x = X[:800], X[800:]
train_y, test_y = y[:800], y[800:]
# volatile=True (legacy PyTorch <0.4 API) marks the test input as
# inference-only so no autograd graph is recorded for it.
test_x = Variable(test_x, volatile=True)
# Define the model
class CNN(nn.Module):
    """Small two-stage CNN for fixed-length digit captchas.

    Two stages of conv(3x3, no padding) + ReLU + 2x2 max-pool, followed by a
    single linear layer producing ``num_outputs`` raw scores (by default
    40 = 4 digit positions x 10 classes, read as four consecutive 10-wide
    groups).

    Args:
        input_height: Height of the 1-channel input images (default 20).
        input_width: Width of the 1-channel input images (default 60).
        num_outputs: Size of the output layer (default 40).
    """

    def __init__(self, input_height=20, input_width=60, num_outputs=40):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=0,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, 3, 1, 0),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Derive the flattened feature size instead of hard-coding 16*3*13:
        # for the default 20x60 input this still yields 16*3*13 = 624.
        feat_h = self._stage_out(self._stage_out(input_height))
        feat_w = self._stage_out(self._stage_out(input_width))
        self.out = nn.Linear(16 * feat_h * feat_w, num_outputs)

    @staticmethod
    def _stage_out(size):
        """Return the side length after one conv(3, pad 0) + maxpool(2) stage."""
        # The valid 3x3 conv removes 2 pixels; the 2x2 pool halves (floor).
        return (size - 2) // 2

    def forward(self, x):
        """Map a (N, 1, H, W) batch to (N, num_outputs) raw scores."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten the feature maps per sample
        output = self.out(x)
        return output
# Instantiate the network and show its layer summary.
cnn = CNN()
print(str(cnn))
# Printout as produced by the PyTorch version the article was written with:
# CNN (
# (conv1): Sequential (
# (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
# (1): ReLU ()
# (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
# )
# (conv2): Sequential (
# (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
# (1): ReLU ()
# (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
# )
# (out): Linear (624 -> 40)
# )
# Optimizer and loss: Adam, plus a multi-label soft-margin loss over the 40
# outputs (four 10-way digit groups, one-hot targets).
batsize = 8
epochs = 10
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
loss_func = nn.MultiLabelSoftMarginLoss()
# Mini-batch training loop.
for epoch in range(epochs):
    losses = []
    # Integer ceil-division. The old int(math.ceil(a / b)) floors under
    # Python 2 integer division and would drop a final partial batch.
    iters = (train_x.shape[0] + batsize - 1) // batsize
    for i in range(iters):
        train_x_i = train_x[i*batsize: (i+1)*batsize]
        train_y_i = train_y[i*batsize: (i+1)*batsize]
        tx = Variable(train_x_i)
        ty = Variable(train_y_i)
        out = cnn(tx)
        loss = loss_func(out, ty)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Depending on the PyTorch version, .mean() here gives a Python
        # float or a 0-d tensor; float() normalises both to a plain number.
        losses.append(float(loss.data.mean()))
    print('[%d/%d] Loss: %.3f' % (epoch+1, epochs, np.mean(losses)))
# [1/10] Loss: 0.352
# [2/10] Loss: 0.322
# [3/10] Loss: 0.244
# [4/10] Loss: 0.100
# [5/10] Loss: 0.053
# [6/10] Loss: 0.040
# [7/10] Loss: 0.035
# [8/10] Loss: 0.031
# [9/10] Loss: 0.028
# [10/10] Loss: 0.026
# Evaluate exact-match accuracy on the held-out set: an image counts as
# correct only if all 4 predicted digits match its label.
test_output = cnn(test_x)
# Decode predictions and one-hot targets into (N, 4) digit arrays by taking
# the argmax inside each 10-wide group. Comparing against test_y (instead of
# lables[800 + i]) keeps this in sync with however the split was made.
pred_digits = test_output.data.numpy().reshape(-1, 4, 10).argmax(axis=2)
true_digits = test_y.numpy().reshape(-1, 4, 10).argmax(axis=2)
correct_num = int((pred_digits == true_digits).all(axis=1).sum())
total_num = len(pred_digits)
if total_num:
    print("Test accurate :", float(correct_num) / total_num)
else:
    # Avoid ZeroDivisionError when the dataset has no samples past the split.
    print("Test accurate : no test samples")
# Test accurate : 0.98
# Predict the 4-digit code of a single image file.
img_path = 'test2.jpg'
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail with a clear message rather than a numpy/torch type error.
if img is None:
    raise IOError("cannot read image: " + img_path)
imgArr = np.array(img)
imgArr = np.expand_dims(imgArr, axis=0)  # batch axis -> (1, H, W)
imgArr = torch.from_numpy(imgArr)
imgArr = torch.unsqueeze(imgArr, dim=1)  # channel axis -> (1, 1, H, W)
imgArr = imgArr.type(torch.FloatTensor)/255.  # same [0, 1] scaling as training
imgArr = Variable(imgArr, volatile=True)
pred_img = cnn(imgArr)
# The argmax within each 10-wide output group gives one digit per position.
digits = pred_img.data.numpy()[0].reshape(4, 10).argmax(axis=1)
c = ''.join(str(d) for d in digits)
print(c)
# 5955
# Display the test image for a visual check of the prediction above.
from matplotlib import pyplot as plt
img = plt.imread(img_path)
plt.imshow(img)
plt.show()
参考引用:https://github.com/junliangliu/captcha