我使用的样本示例如下:
样本下载链接在文末。
训练时第一个 epoch 的验证准确率就达到了 99%(按单个字符统计)。
直接运行下面的代码即可。下面有两段代码:第一段用于训练,第二段用于加载训练好的模型进行识别。
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import numpy as np
from tqdm import tqdm
import torchvision.transforms as T
# Network input size as (height, width).
IMAGE_SHAPE = (24, 100)
# cv2/numpy image -> PIL image -> resized to IMAGE_SHAPE -> tensor (C, H, W).
transform = T.Compose([T.ToPILImage(), T.Resize(IMAGE_SHAPE), T.ToTensor()])
# Characters the model can predict; a character's position is its class index.
LABEL_MAP = list('0123456789')
# NOTE: the name is non-standard -- this is really the captcha length, since
# every captcha text has a fixed number of characters.
Max_label_len = 5
class MyDataset(Dataset):
    """Captcha images whose label text is encoded in the file name.

    Every file in ``data_path`` is one sample; everything before the first
    ``.`` of its name is the label (e.g. ``01234.png`` -> ``"01234"``).
    Items are ``(image, one_hot_label, raw_len)``:

    * ``image`` is the output of the module-level ``transform``;
    * ``one_hot_label`` is a float tensor of shape
      ``(max_label_len, label_map_len)``;
    * ``raw_len`` is the length of the label text.

    :param data_path: directory containing the sample images.
    :param label_map: iterable of characters; a character's position is its
        class index.
    :param max_label_len: required length of every label (the captcha length).
    """

    def __init__(self, data_path, label_map, max_label_len):
        super(MyDataset, self).__init__()
        # (image path, label text) pairs; the label is the file name up to its first '.'.
        self.data = [(os.path.join(data_path, file), file.split('.')[0]) for file in os.listdir(data_path)]
        self.label_map = [char for char in label_map]
        self.label_map_len = len(self.label_map)
        self.max_label_len = max_label_len

    def _encode_label(self, label):
        """Return ``label`` as a float one-hot tensor of shape (len(label), label_map_len).

        :raises ValueError: if the label length differs from ``max_label_len``
            (the model output and the training reshape assume a fixed length),
            or if a character is missing from ``label_map``.
        """
        if len(label) != self.max_label_len:
            raise ValueError('label %r has length %d, expected %d'
                             % (label, len(label), self.max_label_len))
        try:
            indices = [self.label_map.index(char) for char in label]
        except ValueError:
            raise ValueError('label %r contains a character not in label_map' % (label,)) from None
        indices = torch.as_tensor(indices, dtype=torch.int64)
        # Size the one-hot by this dataset's own label map, not the global LABEL_MAP.
        return F.one_hot(indices, num_classes=self.label_map_len).float()

    def __getitem__(self, index):
        file, label = self.data[index]
        # Presumably np.fromfile + imdecode is used instead of cv2.imread so that
        # non-ASCII (e.g. Chinese) paths work on Windows -- confirm if changing.
        im = cv2.imdecode(np.fromfile(file, dtype=np.uint8), cv2.IMREAD_COLOR)
        if im is None:
            raise ValueError('could not decode image file %r' % (file,))
        return transform(im), self._encode_label(label), len(label)

    def __len__(self):
        return len(self.data)
class Net(nn.Module):
    """ResNet-18 classifier that predicts every captcha character at once.

    The network emits ``max_label_len * num_classes`` values per image; callers
    reshape them to ``(batch, max_label_len, num_classes)`` and take the argmax
    over the last axis. Wrapping the backbone in a class makes it easy to
    customize the network structure later.

    :param max_label_len: characters per captcha; defaults to the module-level
        ``Max_label_len``.
    :param num_classes: size of the character set; defaults to ``len(LABEL_MAP)``.
    """

    def __init__(self, max_label_len=None, num_classes=None):
        super(Net, self).__init__()
        # Resolve defaults at call time so the module constants remain the default source.
        if max_label_len is None:
            max_label_len = Max_label_len
        if num_classes is None:
            num_classes = len(LABEL_MAP)
        self.max_label_len = max_label_len
        self.num_classes = num_classes
        # The attribute name "resnet18" prefixes the saved state_dict keys; keep it
        # so existing checkpoints still load.
        self.resnet18 = models.resnet18(num_classes=max_label_len * num_classes)

    def forward(self, x):
        """Return flat outputs of shape (batch, max_label_len * num_classes)."""
        return self.resnet18(x)
# Module-level loaders, built on import. Training batches are shuffled and
# loaded by 3 worker processes; validation uses batches of 4 in-process.
# NOTE(review): the data paths are relative to the current working directory.
train = DataLoader(
    MyDataset(r'../sample/train', label_map=LABEL_MAP, max_label_len=Max_label_len),
    batch_size=32,
    shuffle=True,
    num_workers=3,
)
test = DataLoader(
    MyDataset(r'../sample/test', label_map=LABEL_MAP, max_label_len=Max_label_len),
    batch_size=4,
    shuffle=True,
    num_workers=0,
)
def count_correct(out, label, max_label_len, num_classes):
    """Count correct characters and fully correct captchas in one batch.

    :param out: network output; viewed here as (batch, max_label_len, num_classes).
    :param label: one-hot targets of shape (batch, max_label_len, num_classes).
    :return: ``(correct_characters, correct_captchas)`` as Python ints.
    """
    predict = torch.argmax(out.view(-1, max_label_len, num_classes), dim=2)  # (batch, max_label_len)
    target = torch.argmax(label, dim=2)  # (batch, max_label_len)
    hits = predict == target
    return int(hits.sum().item()), int(hits.all(dim=1).sum().item())


if __name__ == '__main__':
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_CLASSES = len(LABEL_MAP)
    model = Net()
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # NOTE(review): MSE on one-hot targets works, but CrossEntropyLoss over
    # (batch * Max_label_len, NUM_CLASSES) is the usual choice for classification.
    loss_func = nn.MSELoss()
    scheduler = StepLR(optimizer, step_size=2, gamma=0.7)
    # torch.save does not create missing directories.
    os.makedirs("models", exist_ok=True)
    for epoch in range(0, 100):
        # Train
        model.train()
        bar = tqdm(train, 'Training')
        for x, label, _ in bar:
            x, label = x.to(DEVICE), label.to(DEVICE)
            out = model(x)
            # One-hot (batch, Max_label_len, NUM_CLASSES) -> flat, matching the network output.
            label = label.view(-1, Max_label_len * NUM_CLASSES)
            loss = loss_func(out, label)
            # The usual three steps: clear gradients, backpropagate, update weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr = optimizer.param_groups[0]['lr']
            bar.set_description("Train epoch %d, loss %.4f, lr %.6f" % (epoch, loss.item(), lr))

        # Valid: eval() stops BatchNorm from updating its running statistics on
        # validation data; no_grad() avoids building an autograd graph.
        model.eval()
        bar = tqdm(test, 'Validating')
        correct = count = seq_correct = seq_count = 0
        with torch.no_grad():
            for x, label, _ in bar:
                x, label = x.to(DEVICE), label.to(DEVICE)
                out = model(x)
                loss = loss_func(out, label.view(-1, Max_label_len * NUM_CLASSES))
                char_hits, seq_hits = count_correct(out, label, Max_label_len, NUM_CLASSES)
                correct += char_hits
                seq_correct += seq_hits
                count += x.shape[0] * Max_label_len
                seq_count += x.shape[0]
                lr = optimizer.param_groups[0]['lr']
                # "acc" is per character; "seq acc" counts a captcha only if all characters match.
                bar.set_description("Eval epoch %d, acc %.4f, seq acc %.4f, loss %.4f, lr %.6f" % (
                    epoch, correct / count, seq_correct / seq_count, loss.item(), lr
                ))
        # One step per epoch (the epoch argument is deprecated): lr *= 0.7 every 2 epochs.
        scheduler.step()
        torch.save(model.state_dict(), "models/save_%d.model" % epoch)
import torch
from torch import nn
from torchvision import models
import cv2
import numpy as np
import torchvision.transforms as T
# Network input size as (height, width); must match the training script.
IMAGE_SHAPE = (24, 100)
# cv2/numpy image -> PIL image -> resized to IMAGE_SHAPE -> tensor (C, H, W).
transform = T.Compose([T.ToPILImage(), T.Resize(IMAGE_SHAPE), T.ToTensor()])
# Characters the model can predict; a character's position is its class index.
LABEL_MAP = list('0123456789')
# Fixed captcha length (number of characters per image).
Max_label_len = 5
class Net(nn.Module):
    """ResNet-18 classifier that predicts every captcha character at once.

    The network emits ``max_label_len * num_classes`` values per image; callers
    reshape them to ``(batch, max_label_len, num_classes)`` and take the argmax
    over the last axis.

    :param max_label_len: characters per captcha; defaults to the module-level
        ``Max_label_len``.
    :param num_classes: size of the character set; defaults to ``len(LABEL_MAP)``.
    """

    def __init__(self, max_label_len=None, num_classes=None):
        super(Net, self).__init__()
        # Resolve defaults at call time so the module constants remain the default source.
        if max_label_len is None:
            max_label_len = Max_label_len
        if num_classes is None:
            num_classes = len(LABEL_MAP)
        self.max_label_len = max_label_len
        self.num_classes = num_classes
        # The attribute name "resnet18" prefixes the saved state_dict keys; keep it
        # so checkpoints from the training script still load.
        self.resnet18 = models.resnet18(num_classes=max_label_len * num_classes)

    def forward(self, x):
        """Return flat outputs of shape (batch, max_label_len * num_classes)."""
        return self.resnet18(x)
# Whether to use the GPU. Inference runs on the CPU here; to use a GPU when
# available, set DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu').
DEVICE = torch.device('cpu')
model = Net()
model.to(DEVICE)
# Load the trained weights. map_location remaps tensors that were saved from a
# GPU during training onto DEVICE, so loading also works on CPU-only machines.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load("./models/save_9.model", map_location=DEVICE))
# Switch to evaluation mode (affects layers such as BatchNorm).
model.eval()
def captcha(im):
    """Recognize the captcha text in one image.

    :param im: image as returned by ``cv2.imread(path, cv2.IMREAD_COLOR)``.
    :return: predicted text, one ``LABEL_MAP`` character per position.
    :raises ValueError: if ``im`` is None (e.g. cv2.imread could not read the file).
    """
    if im is None:
        raise ValueError('captcha() got None; the image could not be read')
    x = transform(im).to(DEVICE).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():  # inference only; no autograd graph needed
        out = model(x)
    predict = torch.argmax(out.view(-1, Max_label_len, len(LABEL_MAP)), dim=2)
    # Map class indices back through LABEL_MAP; str(index) would only be right
    # while LABEL_MAP happens to be the digits 0-9 in order.
    return ''.join(LABEL_MAP[i] for i in predict[0].tolist())
# Usage: python <this script> <image path>
# Guarded so that importing this module (e.g. to reuse captcha()) does not try
# to read an image at import time.
if __name__ == '__main__':
    import sys
    path = sys.argv[1] if len(sys.argv) > 1 else 'path'
    im = cv2.imread(path, cv2.IMREAD_COLOR)
    ret = captcha(im)
    print(ret)
链接: https://pan.baidu.com/s/1bEeqtlCFLqvAQcH3GqKXqQ 提取码: 7ue5