PyTorch CIFAR-10 图像分类

图片的加载与显示

import torch
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np

# Training-time preprocessing: random horizontal flips and random grayscale
# conversion act as light augmentation; Normalize then applies
# (x - 0.5) / 0.5 to every channel.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation-time preprocessing: same scaling but no random augmentation, so
# that test images are identical on every pass.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)

# Human-readable class names, indexed by the integer label of each sample.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Inspect the array layout of the first training sample.
image = trainset[0][0]
image = np.array(image)
c, h, w = image.shape
print('数据的格式', image.shape)
# Go from channel-first (c, h, w) to channel-last (h, w, c) by permuting the
# axes. A plain reshape would keep the flat element order and therefore mix
# pixels from different channels/rows instead of moving the channel axis.
image = np.transpose(image, (1, 2, 0))
print('修改后数据的格式', image.shape)

# Show `num` randomly chosen training images in a grid with `col` columns.
num = 40
col = 8
row = (num + col - 1) // col  # enough rows even when num is not a multiple of col
# Lower bound 0 so that the first sample can be drawn as well.
index = np.random.randint(0, len(trainset), num)
for i in range(num):
    plt.subplot(row, col, i + 1)
    plt.xticks([])  # hide the x-axis ticks
    plt.yticks([])  # hide the y-axis ticks
    # Fetch image and label in one access: every access re-runs the random
    # transforms, so indexing twice would needlessly transform the image twice.
    image, label = trainset[index[i]]
    image = np.array(image)  # c, h, w
    image = image * 0.5 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    # Permute to h, w, c for imshow; reshape would scramble the pixels.
    image = np.transpose(image, (1, 2, 0))
    plt.imshow(image, cmap='gray')
    plt.title(classes[label])
plt.show()

pytorch cifar-10 图像分类_第1张图片

模型的搭建

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Training-time preprocessing: random horizontal flips and random grayscale
# conversion act as light augmentation; Normalize then applies
# (x - 0.5) / 0.5 to every channel.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation-time preprocessing: same scaling but no random augmentation, so
# that test images are identical on every pass.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)

# Human-readable class names, indexed by the integer label of each sample.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class Net(nn.Module):
    """LeNet-style CNN that maps 3x32x32 images to 10 class scores.

    Spatial sizes for a 32x32 input: conv1 (5x5) -> 28, pool -> 14,
    conv2 (5x5) -> 10, pool -> 5, which is why ``fc1`` takes 16*5*5 features.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw (unnormalised) class scores for a batch ``x``.

        Args:
            x: image batch; the layer sizes above require 3 channels and
                32x32 pixels per image.

        Returns:
            Tensor with one row of 10 scores per image.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No sigmoid here: the training code feeds these scores to
        # nn.CrossEntropyLoss, which expects unnormalised logits, and the
        # other Net definitions in this file return the raw fc3 output too.
        return self.fc3(x)

# Smoke test: run one random 3x32x32 input through an untrained network and
# print both the scores and their shape.
model = Net()
image = torch.randn((1, 3, 32, 32))
output = model(image)
print('output', output)
print('输出维度', output.size())

 pytorch cifar-10 图像分类_第2张图片

 模型的训练

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import os,time
from tqdm import tqdm

class Net(nn.Module):
    """Small LeNet-style convolutional classifier for 3x32x32 images.

    Two conv -> ReLU -> 2x2 max-pool stages feed three fully connected
    layers; the last layer emits 10 unnormalised scores, one per class.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is so that parameter
        # names (conv1.weight, fc1.bias, ...) match saved state dicts.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw class scores for the image batch ``x``."""
        # Convolutional feature extractor: conv -> ReLU -> pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten to rows of 16*5*5 features for the fully connected head.
        x = x.view(-1, self.fc1.in_features)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        # The scores are returned without any final activation.
        return self.fc3(x)

def train(model, epochs=50, batch_size=100, lr=0.01, path='model.tar'):
    """Train ``model`` on the CIFAR-10 training split with plain SGD.

    A checkpoint is written to ``path`` after every epoch. If ``path`` already
    exists, model and optimizer state are restored from it and training
    continues with the epoch after the one stored in the checkpoint.

    Args:
        model: network to train; it is moved to the GPU when one is available
            and updated in place.
        epochs: total number of epochs (including epochs already completed by
            a resumed checkpoint).
        batch_size: mini-batch size of the training loader.
        lr: SGD learning rate.
        path: checkpoint file to resume from and to save to.

    Returns:
        None. Progress, per-epoch accuracy and timing are printed.
    """
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomGrayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
    print('训练集数量', len(trainset))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # The criterion is stateless, so it is always built fresh instead of being
    # read back from the checkpoint.
    criterion = nn.CrossEntropyLoss()
    initepoch = 0

    if os.path.exists(path):
        # Resume from the saved weights. map_location lets a checkpoint that
        # was written on a GPU machine load on a CPU-only one.
        # NOTE(review): checkpoints written by the earlier version of this
        # function contain a pickled nn.CrossEntropyLoss under 'loss'; recent
        # PyTorch releases may refuse to unpickle that by default - confirm
        # against the installed version.
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # The stored epoch has already been completed; start with the next one.
        initepoch = checkpoint['epoch'] + 1

    model.train()
    for epoch in range(initepoch, epochs):
        timestart = time.time()
        running_loss = 0.0
        correct = 0

        # The bar advances by the number of samples per step, so its total is
        # the dataset size.
        with tqdm(total=len(trainset), ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, epochs))

            for inputs, labels in trainloader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                batch_loss = criterion(outputs, labels)
                batch_loss.backward()
                optimizer.step()

                # Accuracy bookkeeping does not need gradients.
                predicted = outputs.detach().argmax(dim=1)
                correct += (predicted == labels).sum().item()

                # running_loss is the sum of the batch losses seen so far in
                # this epoch.
                running_loss += batch_loss.item()
                t.set_postfix(trainloss='{:.6f}'.format(running_loss))
                t.update(len(inputs))

        # 'loss' holds the mean batch loss of the epoch as a plain float, so
        # the checkpoint contains only tensors and primitive values.
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': running_loss / max(len(trainloader), 1)
                    }, path)
        print('准确率', correct / len(trainset))
        print('epoch %d cost %3f sec' % (epoch, time.time() - timestart))

    print('Finished Training')
if __name__ == '__main__':
    # Script entry point: build a freshly initialised network and train it.
    model = Net()
    train(model=model)

pytorch cifar-10 图像分类_第3张图片

交叉熵损失函数

# Cross-entropy criterion; the training code above calls it as
# loss(outputs, labels) on the network outputs and the batch labels.
loss = nn.CrossEntropyLoss()

参考:《交叉熵损失函数原理详解》(Cigar丶的博客,CSDN)

模型的测试

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import confusion_matrix
import numpy as np
class Net(nn.Module):
    """Convolutional classifier: two conv/pool stages plus three linear layers.

    The layer sizes fit 3-channel 32x32 inputs and 10 output classes.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (the single pooling layer is reused after each conv).
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # Classifier head.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute the 10 raw class scores for every image in ``x``."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # One row of 16*5*5 values per image for the linear layers.
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        # No activation is applied to the final layer's output.
        logits = self.fc3(hidden)
        return logits

def test(model, testloader):
    """Evaluate ``model`` on ``testloader`` and collect per-sample labels.

    The model is moved to the same device as the batches (GPU when available)
    and switched to eval mode for the duration of the call; its previous
    train/eval mode is restored afterwards. Batches of any size are supported.

    Args:
        model: network returning one row of class scores per input sample.
        testloader: iterable of ``(images, labels)`` tensor batches.

    Returns:
        Tuple ``(true_labels, predicted_labels)`` of 1-D numpy arrays with one
        entry per evaluated sample. The overall accuracy is printed.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Inputs are sent to `device` below, so the model has to live there too.
    model.to(device)
    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    true_label = []
    pre_label = []
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # extend() keeps one entry per sample regardless of batch size.
            pre_label.extend(predicted.cpu().tolist())
            true_label.extend(labels.cpu().tolist())

    model.train(was_training)

    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100.0 * correct / total if total else 0.0
    print('Accuracy of the network on the %d test images: %.3f %%' % (total, accuracy))

    return np.array(true_label), np.array(pre_label)


# Despite its name this is an evaluation helper, not a unit test; the flag
# keeps pytest from collecting it when the module is star-imported.
test.__test__ = False
def plot_confusion_matrix(y_true, y_pred, title="Confusion matrix",
                          cmap=plt.cm.Blues, save_flg=True):
    """Draw the 10-class confusion matrix as an annotated heat map.

    Args:
        y_true: true integer labels.
        y_pred: predicted integer labels, aligned with ``y_true``.
        title: figure title.
        cmap: matplotlib colormap used for the heat map.
        save_flg: when true, also write the figure to
            ``./confusion_matrix.png`` before showing it.
    """
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    tick_marks = np.arange(len(classes))
    # Label ids of the dataset classes, in the same order as `classes`.
    labels = range(10)
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    plt.figure(figsize=(14, 12))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=40)
    plt.colorbar()
    plt.xticks(tick_marks, classes, fontsize=20)
    plt.yticks(tick_marks, classes, fontsize=20)
    print('Confusion matrix, without normalization')

    # Write each count into its cell; cells above half the maximum count get
    # white text, the others black.
    thresh = cm.max() / 2.
    n_rows, n_cols = cm.shape
    for i in range(n_rows):
        for j in range(n_cols):
            count = cm[i, j]
            text_color = "white" if count > thresh else "black"
            plt.text(j, i, count,
                     horizontalalignment="center",
                     color=text_color)

    plt.ylabel('True label', fontsize=30)
    plt.xlabel('Predicted label', fontsize=30)
    if save_flg:
        plt.savefig("./confusion_matrix.png")
    plt.show()

if __name__ == '__main__':
    model = Net()
    # The training script in this file saves its checkpoint as 'model.tar'.
    # Load it onto the CPU first so that a checkpoint written on a GPU machine
    # also works on a CPU-only one.
    hh = torch.load('./model.tar', map_location='cpu')
    model.load_state_dict(hh['model_state_dict'])
    # test() sends every batch to the GPU when one is available, so the model
    # has to be on that device as well.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Evaluation uses deterministic preprocessing only: random flips or
    # grayscale conversion would make the reported accuracy vary between runs.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Only the test split is needed here; batch_size stays at 1.
    testset = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)

    true_label, pre_label = test(model, testloader)
    plot_confusion_matrix(true_label, pre_label)

pytorch cifar-10 图像分类_第4张图片

pytorch cifar-10 图像分类_第5张图片 

 

你可能感兴趣的:(pytorch,分类,深度学习)