本文出自实训学习内容,如有讲的不对的地方,敬请指正!
进行此工程之前,我们需要准备足够多的图像数据集(本文使用 CIFAR-10)。
1.
先编写如下的网络结构:卷积层负责提取特征,全连接层对特征进行分析并输出分类结果。
import torch.nn as nn
class CNNet(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images.

    The feature extractor applies three unpadded 3x3 convolutions with a
    2x2 max-pool after the first two, which shrinks a 32x32 input to
    4x4 with 64 channels (32 -> 30 -> 15 -> 13 -> 6 -> 4).  The
    classifier head maps those 4*4*64 features to ``num_classes`` scores.

    Args:
        num_classes: Number of output scores per image (default 10).
    """

    # Flattened size of the conv output for a 32x32 input: 64 channels * 4 * 4.
    _FEATURES = 4 * 4 * 64

    def __init__(self, num_classes=10):
        super().__init__()
        self.cnn_layer = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
        )
        self.mlp_layer = nn.Sequential(
            nn.Linear(self._FEATURES, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw class scores of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, 3, 32, 32)``.

        Raises:
            RuntimeError: If the spatial size is not 32x32, because the
                flattened features no longer match the first linear layer.
        """
        x = self.cnn_layer(x)
        # Flatten per sample.  A bare reshape(-1, 1024) would silently fold
        # extra spatial positions into the batch dimension for inputs that
        # are not 32x32; flatten(1) keeps the batch size and fails loudly.
        x = x.flatten(1)
        return self.mlp_layer(x)
2.
做一个训练器(上面的网络需保存为 MyNet.py,供下面的代码导入),对网络进行训练优化、提高识别精度
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from MyNet import CNNet
import torchvision.transforms as trans
from matplotlib.pylab import plt
import os
class Train:
    """Trains ``CNNet`` on CIFAR-10 and checkpoints the whole model to disk.

    If a checkpoint already exists at ``save_path`` it is loaded and
    training resumes from it; otherwise a freshly initialised network is
    used.

    Args:
        save_path: File the pickled model is loaded from and saved to.
    """

    def __init__(self, save_path="net.pt"):
        self.save_path = save_path
        self.net = CNNet()
        if os.path.isfile(save_path):
            # The checkpoint is a fully pickled nn.Module (see train()), so
            # it must be loaded with weights_only=False; PyTorch >= 2.6
            # defaults to True and would reject it.  Unpickling can execute
            # arbitrary code -- only load checkpoints you created yourself.
            self.net = torch.load(save_path, weights_only=False)
        # NOTE(review): MSE against one-hot targets is kept from the original
        # article; nn.CrossEntropyLoss is the conventional classification loss.
        self.loss_func = nn.MSELoss()
        self.optimier = torch.optim.Adam(self.net.parameters(), lr=2e-3)
        self.dataset = self.get_dataset()

    def get_dataset(self):
        """Return ``(train_set, test_set)`` CIFAR-10 datasets under ``dataset/``.

        The training split is downloaded if missing; the test split reuses
        the same files.  Images are converted to tensors and normalised per
        channel (the constants are presumably the CIFAR-10 mean/std).
        """
        transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize([0.4919, 0.4827, 0.4472], [0.2470, 0.2434, 0.2616]),
        ])
        trainData = CIFAR10(root="dataset", train=True, download=True, transform=transform)
        testData = CIFAR10(root="dataset", train=False, download=False, transform=transform)
        return trainData, testData

    def load_data(self, trainData, testData):
        """Wrap both datasets in loaders with a batch size of 500.

        Only the training loader is shuffled; evaluation order does not
        affect accuracy, so the test loader stays deterministic.
        """
        trainloader = DataLoader(dataset=trainData, batch_size=500, shuffle=True)
        testloader = DataLoader(dataset=testData, batch_size=500, shuffle=False)
        return trainloader, testloader

    def train(self, epochs=10):
        """Train for ``epochs`` epochs, evaluating and saving after each one.

        Args:
            epochs: Number of passes over the training set (default 10).
        """
        trainloader, testloader = self.load_data(self.dataset[0], self.dataset[1])
        for epoch in range(epochs):
            print("epochs{}".format(epoch))
            self.net.train()
            for index, (inputs, target) in enumerate(trainloader):
                output = self.net(inputs)
                # Fix the one-hot width to the network's output size; without
                # num_classes it is inferred from the largest label in the
                # batch and can be narrower than the output.
                one_hot = torch.nn.functional.one_hot(
                    target, num_classes=output.shape[1]
                )
                loss = self.loss_func(output, one_hot.float())
                if index % 10 == 0:
                    print("{}/{} loss:{}".format(index, len(trainloader), loss.item()))
                self.optimier.zero_grad()
                loss.backward()
                self.optimier.step()

            # Evaluation needs no gradients; skipping the autograd graph
            # saves memory and time.
            correct = 0
            self.net.eval()
            with torch.no_grad():
                for inputs, target in testloader:
                    predict = torch.argmax(self.net(inputs), dim=1)
                    correct += (predict == target).sum().item()
            print("精度:{}".format(correct / len(self.dataset[1])))
            # Saves the whole module (not just a state_dict) because the
            # inference script loads it back as a ready-to-call object.
            torch.save(self.net, self.save_path)
# Script entry point: build the trainer (loads or creates the model and the
# CIFAR-10 datasets) and run the full training loop.
if __name__ == '__main__':
    t = Train()
    t.train()
3.
trainer和net都已准备好后,接下来就开始进行识别测试
import torch
from PIL import Image
import torchvision.transforms as trans
path="horse.jpg"
lables=['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
def testImage():
net=torch.load("net.pt")
img=Image.open(path)
transform=trans.Compose([
trans.Resize(32),
trans.CenterCrop(32),
trans.ToTensor(),
trans.Normalize([0.4919, 0.4827, 0.4472], [0.2470, 0.2434, 0.2616])
])
img=transform(img).unsqueeze(0)
output=net(img)
index=output.argmax(1)
print("该图片是:{}".format(lables[index]))
# Script entry point: classify the default image with the saved model.
if __name__ == '__main__':
    testImage()