abc.ttf是字体文件,用于在图片上标注预测标签
model.py是AlexNet模型代码
train.py用于训练
test.py用于单张图片测试
MNIST文件夹存放数据集文件
将train.py的结构调整后,三个文件的代码如下:
train.py
'''
复现AlexNet
MNIST 数据集 大小28x28
'''
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as transform
from torch.utils.data import DataLoader
from model import AlexNet
# ---------------------------------------------------------------------------
# MNIST download and preparation.
# The 28x28 MNIST digits are resized to 32x32, the input size that the
# fully connected head of AlexNet in model.py is dimensioned for.
# ---------------------------------------------------------------------------
mnist_transform = transform.Compose([
    transform.Resize((32, 32)),
    transform.ToTensor(),
])
data_train = MNIST('./MNIST/data',
                   train=True,
                   download=True,
                   transform=mnist_transform)
# `train` must be the boolean False: the string 'False' is truthy and would
# silently load the training split a second time.
data_test = MNIST('./MNIST/data',
                  train=False,
                  download=True,
                  transform=mnist_transform)
data_train_loader = DataLoader(data_train, batch_size=32, shuffle=True)
# Evaluation does not depend on sample order, so the test set is not shuffled.
data_test_loader = DataLoader(data_test, batch_size=32, shuffle=False)
alexNet = AlexNet()

# Training setup: switch the network to training mode, then build the loss
# function and the SGD optimiser.
alexNet.train()
lr = 0.01
criterion = nn.CrossEntropyLoss()  # loss over the network's class logits
optimizer = torch.optim.SGD(
    alexNet.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=5e-4,
)

# Running totals for the progress print-out. `index` is only used by the
# optional periodic-checkpoint snippet inside the loop.
train_loss = 0
correct = 0
total = 0
index = 0

# Training: a single pass over the training set. After every batch, print the
# running mean loss and the running accuracy (in percent).
for batch_idx, (inputs, targets) in enumerate(data_train_loader):
    optimizer.zero_grad()
    outputs = alexNet(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    # Bookkeeping: the predicted class is the index of the largest logit per row.
    train_loss += loss.item()
    predicted = outputs.max(1)[1]
    total += targets.size(0)
    correct += (predicted == targets).sum().item()

    mean_loss = train_loss / (batch_idx + 1)
    accuracy = 100 * correct / total
    print(batch_idx, 'Loss: %.3f | Acc: %.3f' % (mean_loss, accuracy))

    # Optional periodic checkpoint (disabled):
    # if batch_idx % 300 == 0:
    #     index += 1
    #     torch.save(alexNet, './MINIST_AlexNet_{}.pt'.format(index))

# Save the model as it stands after the final batch.
torch.save(alexNet, './MINIST_AlexNet_last.pt')
model.py
import torch
import torch.nn as nn
# Do not import AlexNet from `model` here: this file *is* model.py, and a
# self-import fails while the module is still being initialised.


class AlexNet(nn.Module):
    """Small AlexNet-style CNN for 32x32 single-channel digit images.

    Two stages of conv(5x5) + ReLU + maxpool(3x3, stride 2), followed by
    three fully connected layers. For a 32x32 input the spatial size goes
    32 -> 28 -> 13 -> 9 -> 4, so the flattened feature vector has
    64 * 4 * 4 = 1024 entries, which is what ``fc1`` expects.

    Args:
        in_channels: number of input image channels (default 1, as for MNIST).
        num_classes: number of output logits (default 10, one per digit).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 5),  # in_channels -> 64 channels, 5x5 kernel
            nn.ReLU(True),
        )
        self.max_pool1 = nn.MaxPool2d(3, 2)  # 3x3 window, stride 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, 5),  # 64 -> 64 channels, 5x5 kernel
            nn.ReLU(True),
        )
        self.max_pool2 = nn.MaxPool2d(3, 2)  # 3x3 window, stride 2
        self.fc1 = nn.Sequential(  # 1024 -> 384
            nn.Linear(1024, 384),
            nn.ReLU(True),
        )
        self.fc2 = nn.Sequential(  # 384 -> 192
            nn.Linear(384, 192),
            nn.ReLU(True),
        )
        self.fc3 = nn.Linear(192, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) logits."""
        x = self.max_pool1(self.conv1(x))
        x = self.max_pool2(self.conv2(x))
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)
test.py:测试单张图片
import torch
import torch.nn as nn
from model import AlexNet
from PIL import Image
import numpy as np
import torchvision.transforms as transform
model_path = './MINIST_AlexNet_last.pt'
# The checkpoint is a whole pickled nn.Module (train.py saves it with
# torch.save(alexNet, ...)), so full unpickling must be allowed explicitly:
# recent torch versions default torch.load to weights_only=True and would
# refuse this file. Older versions lack the parameter, hence the check.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
import inspect
_load_kwargs = {'map_location': 'cpu'}
if 'weights_only' in inspect.signature(torch.load).parameters:
    _load_kwargs['weights_only'] = False
alexnet = torch.load(model_path, **_load_kwargs)
alexnet.eval()  # inference mode
def readImage(path='3.jpg'):
    """Read an image file and convert it into the input format AlexNet expects.

    The image is resized to 32x32, reduced to a single grayscale channel and
    converted by ToTensor into a float tensor of shape (1, 32, 32).

    NOTE(review): the model was trained on MNIST, whose digits are light on a
    dark background; dark-on-light photos may need inverting -- confirm.

    Args:
        path: path of the image file to read.

    Returns:
        The preprocessed image tensor (without a batch dimension).
    """
    preprocess = transform.Compose([
        transform.Resize((32, 32)),
        transform.Grayscale(1),
        transform.ToTensor(),
    ])
    # The context manager closes the file handle once the pixels are processed.
    with Image.open(path) as img:
        return preprocess(img)
# Preprocess the default image and add a leading batch dimension:
# (1, 32, 32) -> (1, 1, 32, 32).
img = readImage()
img.unsqueeze_(0)

# Prediction: the index of the largest logit is the predicted digit.
# no_grad avoids building an autograd graph, which inference never uses.
with torch.no_grad():
    _, pre = alexnet(img).max(1)
print(pre.item())
更新后的test.py代码(在图片上标注预测标签并保存):
import torch
import torch.nn as nn
from model import AlexNet
from PIL import Image,ImageDraw,ImageFont
import numpy as np
import torchvision.transforms as transform
model_path = './MINIST_AlexNet_last.pt'
# The checkpoint is a whole pickled nn.Module (train.py saves it with
# torch.save(alexNet, ...)), so full unpickling must be allowed explicitly:
# recent torch versions default torch.load to weights_only=True and would
# refuse this file. Older versions lack the parameter, hence the check.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
import inspect
_load_kwargs = {'map_location': 'cpu'}
if 'weights_only' in inspect.signature(torch.load).parameters:
    _load_kwargs['weights_only'] = False
alexnet = torch.load(model_path, **_load_kwargs)
alexnet.eval()  # inference mode
def readImage(path='2.jpg'):
    """Read an image file and convert it into the input format AlexNet expects.

    The image is resized to 32x32, reduced to a single grayscale channel and
    converted by ToTensor into a float tensor of shape (1, 32, 32).

    NOTE(review): the model was trained on MNIST, whose digits are light on a
    dark background; dark-on-light photos may need inverting -- confirm.

    Args:
        path: path of the image file to read.

    Returns:
        The preprocessed image tensor (without a batch dimension).
    """
    preprocess = transform.Compose([
        transform.Resize((32, 32)),
        transform.Grayscale(1),
        transform.ToTensor(),
    ])
    # The context manager closes the file handle once the pixels are processed.
    with Image.open(path) as img:
        return preprocess(img)
def DrawImageTxt(imageFile, targetImageFile, txtnum,
                 font_path='abc.ttf', font_size=100, color=(255, 255, 0)):
    """Draw a predicted label onto an image and save the result.

    Args:
        imageFile: path of the image to annotate.
        targetImageFile: path the annotated image is written to.
        txtnum: text to draw (the predicted digit as a string), placed at
            the top-left corner.
        font_path: TrueType font file used for the label.
        font_size: font size passed to ImageFont.truetype.
        color: RGB fill colour of the text (yellow by default).
    """
    font = ImageFont.truetype(font_path, font_size)
    # Work on an RGB copy: an RGB colour tuple cannot be drawn on a
    # single-channel ("L") image such as a grayscale digit scan, and JPEG
    # cannot store RGBA or palette images.
    with Image.open(imageFile) as src:
        im = src.convert('RGB')
    draw = ImageDraw.Draw(im)
    draw.text((0, 0), txtnum, fill=color, font=font)
    im.save(targetImageFile)
if __name__ == "__main__":
    imageFile = './5.jpg'
    targetImageFile = './5_pre.jpg'

    # Preprocess and add the batch dimension: (1, 32, 32) -> (1, 1, 32, 32).
    img = readImage(imageFile)
    img.unsqueeze_(0)

    # Prediction: index of the largest logit. Inference needs no autograd graph.
    with torch.no_grad():
        _, pre = alexnet(img).max(1)

    # Write the predicted digit onto a copy of the input image.
    txtnum = str(pre.item())
    DrawImageTxt(imageFile, targetImageFile, txtnum)
测试图片
预测结果:
所有文件我已经上传到CSDN,欢迎下载!
下载链接:https://download.csdn.net/download/qq_41964545/15600435
如果您觉得有用的话,不妨帮我点个关注和赞哦!