代码来源
GitHub - mivlab/AI_course
一、对于图像的训练
数据集:https://pan.baidu.com/s/18Fz9Cpj0Lf9BC7As8frZrw 提取码:xhgk
注意:训练时需要添加 --datapath 参数,即训练集所在的目录
训练集共有60000张0-9的手写字符图片
"""Train a small CNN on MNIST-style digit images and save a checkpoint per epoch.

Usage: pass ``--datapath <training image directory>``; see ``--help`` for the
remaining options. Checkpoints are written to ``output/params_<epoch>.pth``.
"""
import argparse
import math
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms, models

from dataloader import mnist_loader as ml
from models.cnn import Net
from toonnx import to_onnx

# Spellings of --use_cuda that must switch CUDA off. Without an explicit
# parser argparse stores the raw string, and any non-empty string (including
# "False") is truthy.
_FALSE_STRINGS = frozenset({'', '0', 'false', 'f', 'no', 'n', 'off'})


def _str2bool(value):
    """Return False for a recognised "false" spelling, True for anything else."""
    return str(value).strip().lower() not in _FALSE_STRINGS


parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
parser.add_argument('--epochs', type=int, default=30, help='number of epochs to train')
# Accepts a bare `--use_cuda` as well as `--use_cuda True` / `--use_cuda False`.
parser.add_argument('--use_cuda', type=_str2bool, nargs='?', const=True, default=False,
                    help='using CUDA for training')
args = parser.parse_args()
args.cuda = args.use_cuda and torch.cuda.is_available()
if args.cuda:
    torch.backends.cudnn.benchmark = True


def _train_one_epoch(model, loader, loss_func, optimizer, device, epoch):
    """Run one optimisation pass over *loader*; return (summed loss, correct count)."""
    model.train()
    num_batches = len(loader)
    loss_sum = 0
    correct_sum = 0
    for batch, (batch_x, batch_y) in enumerate(loader):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        # Original author's note: input 256x3x28x28 -> output 256x10
        # (at the default batch size).
        out = model(batch_x)
        loss = loss_func(out, batch_y)
        loss_sum += loss.item()
        pred = torch.max(out, 1)[1]
        batch_correct = (pred == batch_y).sum().item()
        correct_sum += batch_correct
        print('epoch: %2d/%d batch %3d/%d  Train Loss: %.3f, Acc: %.3f'
              % (epoch + 1, args.epochs, batch, num_batches,
                 loss.item(), batch_correct / len(batch_x)))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss_sum, correct_sum


def _evaluate(model, loader, loss_func, device):
    """Score *model* on *loader* without tracking gradients; return (summed loss, correct count)."""
    model.eval()
    loss_sum = 0
    correct_sum = 0
    # No backward pass happens here, so skip building the autograd graph.
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            out = model(batch_x)
            loss_sum += loss_func(out, batch_y).item()
            pred = torch.max(out, 1)[1]
            correct_sum += (pred == batch_y).sum().item()
    return loss_sum, correct_sum


def train():
    """Build the train/val lists, train ``Net`` for ``args.epochs`` epochs and save weights.

    Reads its configuration from the module-level ``args``. Every epoch prints
    per-batch and per-epoch loss/accuracy, evaluates on the validation split and
    writes ``output/params_<epoch>.pth``.
    """
    os.makedirs('./output', exist_ok=True)
    # The lists are regenerated on every run (the original guarded this with
    # `if True`); presumably image_list indexes the images under datapath and
    # shuffle_split divides that list into train/val -- confirm in mnist_loader.
    ml.image_list(args.datapath, 'output/total.txt')
    ml.shuffle_split('output/total.txt', 'output/train.txt', 'output/val.txt')

    train_data = ml.MyDataset(txt='output/train.txt', transform=transforms.ToTensor())
    val_data = ml.MyDataset(txt='output/val.txt', transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model = Net()
    # Alternatives left by the original author:
    #   model = models.resnet18(num_classes=10)   # use a built-in torchvision model
    #   model.load_state_dict(torch.load('./output/params_10.pth'))   # resume
    if args.cuda:
        print('training with cuda')
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)
    # Multiply the learning rate by 0.1 after epochs 10 and 20.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], 0.1)
    loss_func = nn.CrossEntropyLoss()

    # Same value the original computed inline for the averaged-loss denominators.
    train_batches = math.ceil(len(train_data) / args.batch_size)
    val_batches = math.ceil(len(val_data) / args.batch_size)

    for epoch in range(args.epochs):
        # training-----------------------------------
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, loss_func, optimizer, device, epoch)
        scheduler.step()  # update the learning rate once per epoch
        print('Train Loss: %.6f, Acc: %.3f'
              % (train_loss / train_batches, train_acc / len(train_data)))

        # evaluation--------------------------------
        eval_loss, eval_acc = _evaluate(model, val_loader, loss_func, device)
        print('Val Loss: %.6f, Acc: %.3f'
              % (eval_loss / val_batches, eval_acc / len(val_data)))

        # save model --------------------------------
        # Only the weights are saved, once per epoch.
        torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')

    # Optional ONNX export: to_onnx(model, 3, 28, 28, 'params.onnx')


if __name__ == '__main__':
    train()
以上为训练代码,训练的主要过程是通过cv2读取.jpg格式的图片进行训练;
分为30个epoch,每个epoch有188个batch(默认每个batch 256张图片);
训练网络的结构是三个卷积块连接,每一个卷积块有卷积,激活,最大池化操作。
以上是训练一个epoch得到的结果,可以看到验证集的准确率大约为0.96,还是很不错的。
在训练代码中使用的pytorch的函数可以在pytorch的主页中进行搜索,可以了解该函数的用法及作用。
二、测试过程
以下为测试代码
"""Load trained weights, export the model to ONNX and classify a single image."""
import cv2
import torch
from torch.autograd import Variable
from torchvision import transforms

from models.cnn import Net
from toonnx import to_onnx

use_cuda = False
image_path = '4_00440.jpg'

# Decide the device once instead of repeating the CUDA check at every use.
device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')

model = Net()
# map_location='cpu' lets a checkpoint saved during a GPU run load on a
# CPU-only machine; the model is moved to the target device afterwards.
model.load_state_dict(torch.load('output/params_1.pth', map_location='cpu'))
# model = torch.load('output/model.pth')
model.eval()
model.to(device)
to_onnx(model, 3, 28, 28, 'output/params.onnx')

img = cv2.imread(image_path)
# cv2.imread signals failure by returning None rather than raising.
if img is None:
    raise FileNotFoundError('cannot read image: %s' % image_path)

# ToTensor output gets a leading batch dimension before going to the model.
img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
with torch.no_grad():  # inference only, no gradients needed
    prediction = model(img_tensor)
pred = torch.max(prediction, 1)[1]
print(pred)

cv2.imshow("image", img)
cv2.waitKey(0)
测试结果为tensor([4])
而我们屏幕上反馈的图片也是字符4,这说明我们的训练可以说是成功的。
以上就是python 手写字符识别的代码用例,有问题可留言。