基于Pytorch深度学习神经网络CNN花朵图像识别系统

第一步:导入头文件

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim

第二步:数据预处理

def loadtraindata():
    # 路径
    path = "data\\train"
    trainset = torchvision.datasets.ImageFolder(path, transform=transforms.Compose([
        # 将图片缩放到指定大小(h,w)或者保持长宽比并缩放最短的边到int大小
        transforms.Resize((32, 32)),
        transforms.ToTensor()])
                                                )
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    return trainloader

第三步:构建模型网络

开源网络结构

第四步:预测(提供正确率,混淆矩阵结果)

def test():
    net = reload_net()
    net.eval()
    testloader = loadtestdata()
    sum = 0
    num = 0
    labels_list = []
    predicted_list = []
    for images, labels in testloader:
        sum += 1
        outputs = net(Variable(images))
        _, predicted = torch.max(outputs.data, 1)
        if labels == predicted:
            num += 1

        labels_list.append(labels.numpy()[0])
        predicted_list.append(predicted.numpy()[0])
    print('正确率: ' + str(num*1.0 / sum))
    cm = confusion_matrix(labels_list, predicted_list)
    print(cm) #打印混淆矩阵的值
    plot_confusion_matrix(cm, classes, "Confusion Matrix")
    plt.show()

结果:

1)代码结构

基于Pytorch深度学习神经网络CNN花朵图像识别系统_第1张图片

2)预测的一些结果

基于Pytorch深度学习神经网络CNN花朵图像识别系统_第2张图片

3)带界面运行结果

基于Pytorch深度学习神经网络CNN花朵图像识别系统_第3张图片

你可能感兴趣的:(pytorch小demo,图像识别,深度学习,神经网络)