如何进行一个简单的图像识别

如何进行一个简单的图像识别

本文整理自实训课程的学习内容，如有讲得不对的地方，敬请指正！

开始之前，我们需要准备足够多的图像数据集。本文使用 CIFAR-10 数据集（10 个类别的 32×32 彩色图片），训练代码会自动将其下载到 dataset 目录。
1.
首先编写如下的网络结构（保存为 MyNet.py，训练代码会从中导入 CNNet）：卷积层负责提取特征，全连接层负责分析特征并输出 10 个类别的结果。

import torch.nn as nn

class CNNet(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images (e.g. CIFAR-10).

    Three 3x3 convolutions (stride 1, no padding) with ReLU shrink a 32x32
    input to a 4x4x64 feature map. The first two convolutions are each
    followed by 2x2 max-pooling: 32 -> 30 -> 15 -> 13 -> 6 -> 4. A two-layer
    MLP then maps the 1024 flattened features to one raw score per class.
    """

    def __init__(self, num_classes=10):
        """Build the network.

        Args:
            num_classes: number of output scores (width of the last layer).
        """
        super(CNNet, self).__init__()
        self.cnn_layer = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
        )
        self.mlp_layer = nn.Sequential(
            nn.Linear(4 * 4 * 64, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class scores of shape (batch, num_classes) for x of shape (batch, 3, 32, 32)."""
        x = self.cnn_layer(x)
        # Flatten per sample, keeping the batch dimension explicit. The old
        # reshape(-1, 4*4*64) silently regrouped the features of a non-32x32
        # input into a different number of "samples". Now the Linear layer
        # raises on the size mismatch instead.
        x = x.flatten(1)
        x = self.mlp_layer(x)
        return x

2.
编写一个训练器，对网络进行训练优化，提高识别精度：

import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from MyNet import CNNet
import torchvision.transforms as trans
from matplotlib.pylab import plt
import os


class Train:
    """Train CNNet on CIFAR-10, checkpointing the whole model to ``net.pt`` each epoch.

    If ``net.pt`` already exists, training resumes from that saved model.
    """

    def __init__(self):
        self.net = CNNet()
        save_path = "net.pt"
        if os.path.isfile(save_path):
            # train() saves the whole pickled module, so it must be loaded with
            # weights_only=False. PyTorch >= 2.6 defaults to weights_only=True
            # and would refuse it. This unpickles: only load files you trust.
            self.net = torch.load(save_path, weights_only=False)
        # NOTE(review): MSE against one-hot targets is kept from the original
        # setup. nn.CrossEntropyLoss on class indices is the usual choice for
        # classification and may train better.
        self.loss_func = nn.MSELoss()
        # Created after the optional load so it optimizes the parameters of the
        # network actually being trained.
        self.optimier = torch.optim.Adam(self.net.parameters(), lr=2e-3)
        self.dataset = self.get_dataset()

    def get_dataset(self):
        """Return (train, test) CIFAR-10 datasets, downloading into ./dataset if needed."""
        transform = trans.Compose([
            trans.ToTensor(),
            # Per-channel mean/std; presumably CIFAR-10 statistics (verify).
            trans.Normalize([0.4919, 0.4827, 0.4472], [0.2470, 0.2434, 0.2616]),
        ])

        trainData = CIFAR10(root="dataset", train=True, download=True, transform=transform)
        testData = CIFAR10(root="dataset", train=False, download=False, transform=transform)
        return trainData, testData

    def load_data(self, trainData, testData):
        """Wrap both datasets in DataLoaders with batch size 500.

        Only the training data is shuffled; evaluation order does not affect accuracy.
        """
        trainloader = DataLoader(dataset=trainData, batch_size=500, shuffle=True)
        testloader = DataLoader(dataset=testData, batch_size=500, shuffle=False)
        return trainloader, testloader

    def train(self, epochs=10):
        """Train for ``epochs`` passes, evaluating and saving after every epoch.

        Prints the loss every 10 batches and the test-set accuracy after each
        epoch, then saves the whole model to ``net.pt``.
        """
        trainloader, testloader = self.load_data(self.dataset[0], self.dataset[1])
        for epoch in range(epochs):
            print("epochs{}".format(epoch))
            self.net.train()
            for index, (images, target) in enumerate(trainloader):
                output = self.net(images)
                # Pin the encoding width to the network's output width. Without
                # num_classes, one_hot sizes it by the largest label in the batch,
                # which can mismatch the output and break the loss.
                target = torch.nn.functional.one_hot(target, num_classes=output.size(1))
                loss = self.loss_func(output, target.float())
                if index % 10 == 0:
                    print("{}/{} loss:{}".format(index, len(trainloader), loss.item()))

                self.optimier.zero_grad()
                loss.backward()
                self.optimier.step()

            # Evaluation: no gradients needed, so skip building autograd graphs.
            self.net.eval()
            correct = 0
            with torch.no_grad():
                for images, target in testloader:
                    predict = torch.argmax(self.net(images), dim=1)
                    correct += (predict == target).sum().item()
            total = len(self.dataset[1])
            print("精度:{}".format(correct / total if total else 0.0))
            torch.save(self.net, "net.pt")

if __name__ == '__main__':
    # Script entry point: build the trainer (loads data, resumes from net.pt
    # if present) and run training.
    Train().train()

如何进行一个简单的图像识别_第1张图片
用此方法训练，测试集上的识别精度最高可达到 70% 左右。

如何进行一个简单的图像识别_第2张图片
3.
训练器（trainer）和网络（net）都准备好，并完成训练生成 net.pt 之后，就可以开始进行识别测试。

待识别的图片：horse.jpg
如何进行一个简单的图像识别_第3张图片
测试代码:

import torch
from PIL import Image
import torchvision.transforms as trans

# Image classified when the script is run directly.
path = "horse.jpg"
# Class names indexed by the network's predicted class index; assumed to
# follow the CIFAR-10 label order used during training (verify).
lables = "airplane automobile bird cat deer dog frog horse ship truck".split()

def testImage(image_path=None, model_path="net.pt"):
    """Classify one image with the trained network and print its predicted label.

    Args:
        image_path: image file to classify; defaults to the module-level ``path``.
        model_path: whole-model checkpoint saved by the trainer with torch.save.
    """
    if image_path is None:
        image_path = path
    # The checkpoint is a pickled nn.Module, so weights_only=False is required
    # on PyTorch >= 2.6 (where True became the default). Only load trusted files.
    net = torch.load(model_path, weights_only=False)
    net.eval()
    # The network expects exactly 3 channels. Grayscale, palette or RGBA files
    # would otherwise yield 1 or 4 channels. The with-block closes the file.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    transform = trans.Compose([
        trans.Resize(32),      # scale the shorter side to 32 px
        trans.CenterCrop(32),  # then take the central 32x32 patch
        trans.ToTensor(),
        trans.Normalize([0.4919, 0.4827, 0.4472], [0.2470, 0.2434, 0.2616]),
    ])
    batch = transform(img).unsqueeze(0)  # add batch dimension -> (1, 3, 32, 32)

    with torch.no_grad():
        output = net(batch)
    index = output.argmax(1).item()
    print("该图片是:{}".format(lables[index]))
if __name__ == '__main__':
    # Script entry point: classify the default image.
    testImage()

测试结果
如何进行一个简单的图像识别_第4张图片
测试结果
如何进行一个简单的图像识别_第5张图片
测试结果

你可能感兴趣的:(学习,大学)