【365计划-2】pytorch实现彩色图片识别

本文为365天深度学习训练营 中的学习记录博客
参考文章地址: 365天深度学习训练营-第P2周:彩色图片识别
作者:K同学啊

###本项目来自K同学在线指导###

import torch
from torchsummary import summary
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import os,time,warnings

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

ROOT_FOLDER = "data"
# Folder that torchvision extracts the CIFAR-10 python batches into.
CIFAR10_FOLDER = os.path.join(ROOT_FOLDER, "cifar-10-batches-py")

# The message only reports whether the extracted folder is already present.
# download=True is passed for BOTH splits either way: torchvision verifies the
# existing files first and skips the download when they are intact, so this is
# cheap when the data exists and also repairs an interrupted download.
# (Previously the "already downloaded" branch mixed download=True for the
# training split with download=False for the test split.)
if os.path.isdir(CIFAR10_FOLDER):
    print("数据集已下载")
else:
    print("开始下载")

# ToTensor converts each PIL image to a float tensor in channel-first layout.
_to_tensor = torchvision.transforms.ToTensor()
train_ds = torchvision.datasets.CIFAR10(ROOT_FOLDER, train=True, transform=_to_tensor, download=True)
test_ds = torchvision.datasets.CIFAR10(ROOT_FOLDER, train=False, transform=_to_tensor, download=True)

batch_size = 128
# Shuffle only the training split; evaluation order does not matter.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)

# Peek at one training batch to check tensor shapes and eyeball the images.
imgs, labels = next(iter(train_dl))
print("imges shape:", imgs.shape)

plt.figure(figsize=(20, 5))
for i, img in enumerate(imgs[:20]):
    # Tensors are channel-first (3, 32, 32); imshow wants channel-last (32, 32, 3).
    npimg = img.numpy().transpose(1, 2, 0)
    plt.subplot(2, 10, i + 1)
    # No cmap: matplotlib ignores colormaps for RGB data, so the former
    # cmap=plt.cm.binary had no effect and only suggested grayscale output.
    plt.imshow(npimg)
    plt.axis('off')

num_classes = 10  # CIFAR-10 has ten object categories


class Model(nn.Module):
    """Small three-stage CNN for 3x32x32 colour images.

    Each stage is a 3x3 convolution (no padding) -> ReLU -> 2x2 max-pool ->
    dropout(p=0.15). The spatial size shrinks 32 -> 30 -> 15 -> 13 -> 6 -> 4 -> 2,
    so the flattened feature vector has 128 * 2 * 2 = 512 entries, which is
    the input width of ``fc1``. ``forward`` returns raw logits, one per class.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (attribute names kept stable for state_dict keys).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.drop1 = nn.Dropout(p=0.15)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.drop2 = nn.Dropout(p=0.15)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.drop3 = nn.Dropout(p=0.15)
        # Classifier head.
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        stages = (
            (self.conv1, self.pool1, self.drop1),
            (self.conv2, self.pool2, self.drop2),
            (self.conv3, self.pool3, self.drop3),
        )
        for conv, pool, drop in stages:
            x = drop(pool(F.relu(conv(x))))
        # Keep the batch axis; merge channels and spatial dims into one vector.
        features = x.flatten(start_dim=1)
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)

# Hyperparameters.
learn_rate = 1e-3
loss_fn = nn.CrossEntropyLoss()  # takes raw logits and integer class labels

# Instantiate the network on the selected device and print a layer summary.
model = Model().to(device)
summary(model, input_size=(3, 32, 32))

# Plain SGD over every model parameter (only lr is set; other options default).
optimizer = torch.optim.SGD(params=model.parameters(), lr=learn_rate)

def train(dataloader, model, loss_fn, opt, device=None):
    """Run one training epoch over ``dataloader`` and update ``model`` in place.

    The caller is responsible for putting ``model`` in training mode
    (``model.train()``) beforehand; this function does not change the mode.

    Args:
        dataloader: iterable of ``(inputs, targets)`` batches with a
            ``.dataset`` attribute (e.g. a ``torch.utils.data.DataLoader``).
        model: network whose parameters ``opt`` updates.
        loss_fn: criterion returning the MEAN loss over a batch, e.g.
            ``nn.CrossEntropyLoss()`` with its default reduction.
        opt: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so the
            function does not rely on a module-level ``device`` global.

    Returns:
        ``(accuracy, loss)``: the fraction of correctly classified samples and
        the per-sample mean loss over the epoch. Both are measured on each
        batch's predictions made before that batch's optimizer step.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    n_samples = len(dataloader.dataset)
    correct, loss_sum = 0.0, 0.0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        correct += (pred.argmax(1) == y).sum().item()
        # Weight each batch's mean loss by its size. The last batch is usually
        # smaller, and a plain average over batches would over-weight it.
        loss_sum += loss.item() * len(y)
    return correct / n_samples, loss_sum / n_samples

def test(testloader, model, loss_fn, device=None):
    """Evaluate ``model`` on ``testloader`` without tracking gradients.

    The caller is responsible for switching ``model`` to evaluation mode
    (``model.eval()``) first so dropout is disabled; this function does not
    change the mode.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches with a
            ``.dataset`` attribute (e.g. a ``torch.utils.data.DataLoader``).
        model: network to evaluate.
        loss_fn: criterion returning the MEAN loss over a batch, e.g.
            ``nn.CrossEntropyLoss()`` with its default reduction.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so the
            function does not rely on a module-level ``device`` global.

    Returns:
        ``(accuracy, loss)``: the fraction of correctly classified samples and
        the per-sample mean loss over the whole loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    n_samples = len(testloader.dataset)
    correct, loss_sum = 0.0, 0.0
    with torch.no_grad():
        for imgs, targets in testloader:
            imgs, targets = imgs.to(device), targets.to(device)
            logits = model(imgs)
            correct += (logits.argmax(1) == targets).sum().item()
            # Weight by batch size so a short final batch is not over-counted.
            loss_sum += loss_fn(logits, targets).item() * len(targets)
    return correct / n_samples, loss_sum / n_samples


# A function literally named ``test`` matches pytest's default collection
# pattern. Opt it out so that importing this module from a test file does not
# make pytest try to run it (it would fail on the non-fixture arguments).
test.__test__ = False

epoches = 200  # total number of epochs to reach (a resumed run counts from start_epoch)
# Per-epoch metrics for the epochs actually run in this session.
train_acc = []
train_loss = []
test_acc = []
test_loss = []

output = "./output"
# Set start_epoch to N > 0 to resume from output/epochN.pkl, i.e. the weights
# saved after N completed epochs (see the save step below).
start_epoch = 0
os.makedirs(output, exist_ok=True)
if start_epoch > 0:
    resumeFile = os.path.join(output, 'epoch' + str(start_epoch) + '.pkl')
    if os.path.isfile(resumeFile):
        # map_location lets a checkpoint written on a GPU be loaded on a
        # CPU-only machine (and vice versa).
        model.load_state_dict(torch.load(resumeFile, map_location=device))
    else:
        # Requested checkpoint is missing: train from scratch instead.
        start_epoch = 0

template = ('Epoch:{:2d},Train_acc:{:.1f}%,Tran_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f}')
print("\n开始训练")
# Start at start_epoch so a resumed run continues where the checkpoint left
# off, instead of repeating every epoch from 0.
for epoch in range(start_epoch, epoches):
    model.train()  # dropout active for the update pass
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    model.eval()   # dropout disabled for evaluation
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    print(time.strftime('[%Y-%m-%d %H:%M:%S]'),
          template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss))
print("Done")

### Save parameters
# Name the file after the number of COMPLETED epochs, so setting start_epoch to
# that number resumes from it. Naming it after the last 0-based loop index
# (epoch199.pkl after 200 epochs) never matched the resume lookup, and raised
# NameError when the loop did not run at all.
completed_epochs = max(start_epoch, epoches)
saveFile = os.path.join(output, 'epoch' + str(completed_epochs) + '.pkl')
torch.save(model.state_dict(), saveFile)
warnings.filterwarnings("ignore")  # silence warnings from here on

# Plot against the 1-based numbers of the epochs run in this session, so the
# x-axis always matches the length of the metric lists (also when resuming).
epoch_range = range(start_epoch + 1, start_epoch + len(train_acc) + 1)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epoch_range, train_acc, label="Training Acc")
plt.plot(epoch_range, test_acc, label="Test Acc")
plt.legend(loc='lower right')
plt.title("Training and Test Accuracy")

plt.subplot(1, 2, 2)
plt.plot(epoch_range, train_loss, label="Training Loss")
plt.plot(epoch_range, test_loss, label="Test Loss")
plt.legend(loc="upper right")
plt.title("Training and Test Loss")
plt.show()

打印结果

网络信息


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 30, 30]           1,792
         MaxPool2d-2           [-1, 64, 15, 15]               0
           Dropout-3           [-1, 64, 15, 15]               0
            Conv2d-4           [-1, 64, 13, 13]          36,928
         MaxPool2d-5             [-1, 64, 6, 6]               0
           Dropout-6             [-1, 64, 6, 6]               0
            Conv2d-7            [-1, 128, 4, 4]          73,856
         MaxPool2d-8            [-1, 128, 2, 2]               0
           Dropout-9            [-1, 128, 2, 2]               0
           Linear-10                  [-1, 256]         131,328
           Linear-11                   [-1, 10]           2,570
================================================================
Total params: 246,474
Trainable params: 246,474
Non-trainable params: 0
----------------------------------------------------------------

开始训练
[2022-12-25 20:38:35] Epoch: 1,Train_acc:10.7%,Tran_loss:2.302,Test_acc:14.1%,Test_loss:2.301
[2022-12-25 20:38:43] Epoch: 2,Train_acc:10.9%,Tran_loss:2.298,Test_acc:13.4%,Test_loss:2.295
[2022-12-25 20:38:51] Epoch: 3,Train_acc:12.1%,Tran_loss:2.285,Test_acc:18.3%,Test_loss:2.269
[2022-12-25 20:38:59] Epoch: 4,Train_acc:18.5%,Tran_loss:2.207,Test_acc:22.2%,Test_loss:2.117
[2022-12-25 20:39:07] Epoch: 5,Train_acc:22.8%,Tran_loss:2.071,Test_acc:24.8%,Test_loss:2.036
[2022-12-25 20:39:15] Epoch: 6,Train_acc:24.6%,Tran_loss:2.026,Test_acc:26.7%,Test_loss:1.992
[2022-12-25 20:39:23] Epoch: 7,Train_acc:26.2%,Tran_loss:1.993,Test_acc:27.6%,Test_loss:1.967
[2022-12-25 20:39:31] Epoch: 8,Train_acc:27.6%,Tran_loss:1.959,Test_acc:29.3%,Test_loss:1.927
[2022-12-25 20:39:39] Epoch: 9,Train_acc:29.1%,Tran_loss:1.918,Test_acc:31.6%,Test_loss:1.865
[2022-12-25 20:39:47] Epoch:10,Train_acc:31.2%,Tran_loss:1.868,Test_acc:33.9%,Test_loss:1.810
[2022-12-25 20:39:55] Epoch:11,Train_acc:33.1%,Tran_loss:1.817,Test_acc:35.7%,Test_loss:1.765
[2022-12-25 20:40:03] Epoch:12,Train_acc:34.9%,Tran_loss:1.771,Test_acc:37.5%,Test_loss:1.724
[2022-12-25 20:40:11] Epoch:13,Train_acc:36.7%,Tran_loss:1.728,Test_acc:38.4%,Test_loss:1.692
[2022-12-25 20:40:19] Epoch:14,Train_acc:38.1%,Tran_loss:1.687,Test_acc:40.8%,Test_loss:1.634
[2022-12-25 20:40:27] Epoch:15,Train_acc:39.7%,Tran_loss:1.652,Test_acc:41.9%,Test_loss:1.597
[2022-12-25 20:40:35] Epoch:16,Train_acc:40.7%,Tran_loss:1.620,Test_acc:43.1%,Test_loss:1.567
[2022-12-25 20:40:43] Epoch:17,Train_acc:42.0%,Tran_loss:1.594,Test_acc:44.2%,Test_loss:1.542
[2022-12-25 20:40:51] Epoch:18,Train_acc:42.7%,Tran_loss:1.563,Test_acc:44.4%,Test_loss:1.530
[2022-12-25 20:40:59] Epoch:19,Train_acc:43.9%,Tran_loss:1.540,Test_acc:45.6%,Test_loss:1.504
[2022-12-25 20:41:07] Epoch:20,Train_acc:44.9%,Tran_loss:1.511,Test_acc:47.3%,Test_loss:1.454
[2022-12-25 20:41:15] Epoch:21,Train_acc:45.9%,Tran_loss:1.489,Test_acc:48.8%,Test_loss:1.429
[2022-12-25 20:41:23] Epoch:22,Train_acc:46.9%,Tran_loss:1.466,Test_acc:48.9%,Test_loss:1.415
[2022-12-25 20:41:31] Epoch:23,Train_acc:47.8%,Tran_loss:1.446,Test_acc:49.6%,Test_loss:1.396
[2022-12-25 20:41:40] Epoch:24,Train_acc:48.4%,Tran_loss:1.424,Test_acc:49.9%,Test_loss:1.386
[2022-12-25 20:41:48] Epoch:25,Train_acc:49.1%,Tran_loss:1.406,Test_acc:51.4%,Test_loss:1.356
[2022-12-25 20:41:56] Epoch:26,Train_acc:50.0%,Tran_loss:1.385,Test_acc:52.4%,Test_loss:1.351
[2022-12-25 20:42:04] Epoch:27,Train_acc:50.9%,Tran_loss:1.371,Test_acc:52.4%,Test_loss:1.325
[2022-12-25 20:42:12] Epoch:28,Train_acc:51.7%,Tran_loss:1.350,Test_acc:54.0%,Test_loss:1.295
[2022-12-25 20:42:20] Epoch:29,Train_acc:52.0%,Tran_loss:1.338,Test_acc:54.2%,Test_loss:1.281
[2022-12-25 20:42:28] Epoch:30,Train_acc:53.1%,Tran_loss:1.318,Test_acc:54.7%,Test_loss:1.271
[2022-12-25 20:42:36] Epoch:31,Train_acc:53.3%,Tran_loss:1.303,Test_acc:54.9%,Test_loss:1.266
[2022-12-25 20:42:44] Epoch:32,Train_acc:54.0%,Tran_loss:1.291,Test_acc:56.1%,Test_loss:1.230
[2022-12-25 20:42:52] Epoch:33,Train_acc:54.5%,Tran_loss:1.278,Test_acc:56.0%,Test_loss:1.236
[2022-12-25 20:43:00] Epoch:34,Train_acc:55.2%,Tran_loss:1.263,Test_acc:56.3%,Test_loss:1.220
[2022-12-25 20:43:09] Epoch:35,Train_acc:55.4%,Tran_loss:1.248,Test_acc:56.3%,Test_loss:1.217
[2022-12-25 20:43:16] Epoch:36,Train_acc:56.2%,Tran_loss:1.235,Test_acc:56.3%,Test_loss:1.232
[2022-12-25 20:43:24] Epoch:37,Train_acc:56.7%,Tran_loss:1.220,Test_acc:58.5%,Test_loss:1.174
[2022-12-25 20:43:32] Epoch:38,Train_acc:57.1%,Tran_loss:1.209,Test_acc:59.1%,Test_loss:1.153
[2022-12-25 20:43:40] Epoch:39,Train_acc:57.7%,Tran_loss:1.196,Test_acc:59.8%,Test_loss:1.139
[2022-12-25 20:43:48] Epoch:40,Train_acc:58.2%,Tran_loss:1.182,Test_acc:60.5%,Test_loss:1.130
[2022-12-25 20:43:56] Epoch:41,Train_acc:58.7%,Tran_loss:1.170,Test_acc:60.7%,Test_loss:1.119
[2022-12-25 20:44:04] Epoch:42,Train_acc:59.1%,Tran_loss:1.157,Test_acc:60.5%,Test_loss:1.118
[2022-12-25 20:44:12] Epoch:43,Train_acc:59.5%,Tran_loss:1.146,Test_acc:60.8%,Test_loss:1.114
[2022-12-25 20:44:20] Epoch:44,Train_acc:60.1%,Tran_loss:1.136,Test_acc:62.4%,Test_loss:1.087
[2022-12-25 20:44:28] Epoch:45,Train_acc:60.6%,Tran_loss:1.123,Test_acc:61.8%,Test_loss:1.100
[2022-12-25 20:44:36] Epoch:46,Train_acc:60.8%,Tran_loss:1.117,Test_acc:62.4%,Test_loss:1.076
[2022-12-25 20:44:44] Epoch:47,Train_acc:61.0%,Tran_loss:1.105,Test_acc:62.2%,Test_loss:1.080
[2022-12-25 20:44:52] Epoch:48,Train_acc:61.6%,Tran_loss:1.095,Test_acc:63.5%,Test_loss:1.052
[2022-12-25 20:45:00] Epoch:49,Train_acc:61.9%,Tran_loss:1.085,Test_acc:64.0%,Test_loss:1.044
[2022-12-25 20:45:08] Epoch:50,Train_acc:62.4%,Tran_loss:1.073,Test_acc:64.1%,Test_loss:1.030
[2022-12-25 20:45:16] Epoch:51,Train_acc:62.6%,Tran_loss:1.067,Test_acc:64.8%,Test_loss:1.022
[2022-12-25 20:45:24] Epoch:52,Train_acc:62.8%,Tran_loss:1.057,Test_acc:64.7%,Test_loss:1.022
[2022-12-25 20:45:32] Epoch:53,Train_acc:63.4%,Tran_loss:1.046,Test_acc:63.9%,Test_loss:1.047
[2022-12-25 20:45:40] Epoch:54,Train_acc:63.6%,Tran_loss:1.041,Test_acc:65.4%,Test_loss:1.011
[2022-12-25 20:45:48] Epoch:55,Train_acc:63.7%,Tran_loss:1.033,Test_acc:65.4%,Test_loss:0.997
[2022-12-25 20:45:56] Epoch:56,Train_acc:64.2%,Tran_loss:1.024,Test_acc:64.7%,Test_loss:1.013
[2022-12-25 20:46:04] Epoch:57,Train_acc:64.5%,Tran_loss:1.016,Test_acc:65.6%,Test_loss:0.987
[2022-12-25 20:46:12] Epoch:58,Train_acc:65.2%,Tran_loss:1.004,Test_acc:65.9%,Test_loss:0.979
[2022-12-25 20:46:20] Epoch:59,Train_acc:65.2%,Tran_loss:0.998,Test_acc:66.7%,Test_loss:0.969
[2022-12-25 20:46:28] Epoch:60,Train_acc:65.4%,Tran_loss:0.993,Test_acc:64.8%,Test_loss:1.015
[2022-12-25 20:46:36] Epoch:61,Train_acc:65.7%,Tran_loss:0.983,Test_acc:67.1%,Test_loss:0.950
[2022-12-25 20:46:44] Epoch:62,Train_acc:65.8%,Tran_loss:0.977,Test_acc:66.8%,Test_loss:0.960
[2022-12-25 20:46:52] Epoch:63,Train_acc:66.1%,Tran_loss:0.965,Test_acc:67.3%,Test_loss:0.947
[2022-12-25 20:47:00] Epoch:64,Train_acc:66.4%,Tran_loss:0.964,Test_acc:66.4%,Test_loss:0.975
[2022-12-25 20:47:08] Epoch:65,Train_acc:66.6%,Tran_loss:0.955,Test_acc:68.2%,Test_loss:0.924
[2022-12-25 20:47:16] Epoch:66,Train_acc:66.9%,Tran_loss:0.943,Test_acc:68.1%,Test_loss:0.920
[2022-12-25 20:47:24] Epoch:67,Train_acc:67.1%,Tran_loss:0.940,Test_acc:67.9%,Test_loss:0.926
[2022-12-25 20:47:32] Epoch:68,Train_acc:67.3%,Tran_loss:0.933,Test_acc:67.8%,Test_loss:0.932
[2022-12-25 20:47:40] Epoch:69,Train_acc:67.5%,Tran_loss:0.926,Test_acc:69.0%,Test_loss:0.904
[2022-12-25 20:47:48] Epoch:70,Train_acc:68.0%,Tran_loss:0.922,Test_acc:68.9%,Test_loss:0.899
[2022-12-25 20:47:56] Epoch:71,Train_acc:68.1%,Tran_loss:0.918,Test_acc:68.6%,Test_loss:0.911
[2022-12-25 20:48:04] Epoch:72,Train_acc:68.3%,Tran_loss:0.907,Test_acc:69.7%,Test_loss:0.885
[2022-12-25 20:48:12] Epoch:73,Train_acc:68.5%,Tran_loss:0.901,Test_acc:69.7%,Test_loss:0.880
[2022-12-25 20:48:20] Epoch:74,Train_acc:68.9%,Tran_loss:0.895,Test_acc:69.6%,Test_loss:0.880
[2022-12-25 20:48:28] Epoch:75,Train_acc:68.9%,Tran_loss:0.888,Test_acc:69.3%,Test_loss:0.889
[2022-12-25 20:48:36] Epoch:76,Train_acc:69.3%,Tran_loss:0.880,Test_acc:70.5%,Test_loss:0.864
[2022-12-25 20:48:45] Epoch:77,Train_acc:69.5%,Tran_loss:0.877,Test_acc:69.9%,Test_loss:0.879
[2022-12-25 20:48:53] Epoch:78,Train_acc:69.6%,Tran_loss:0.871,Test_acc:70.5%,Test_loss:0.857
[2022-12-25 20:49:01] Epoch:79,Train_acc:70.1%,Tran_loss:0.860,Test_acc:70.0%,Test_loss:0.872
[2022-12-25 20:49:09] Epoch:80,Train_acc:70.2%,Tran_loss:0.855,Test_acc:70.5%,Test_loss:0.857
[2022-12-25 20:49:18] Epoch:81,Train_acc:70.4%,Tran_loss:0.850,Test_acc:71.4%,Test_loss:0.838
[2022-12-25 20:49:28] Epoch:82,Train_acc:70.5%,Tran_loss:0.848,Test_acc:71.2%,Test_loss:0.832
[2022-12-25 20:49:37] Epoch:83,Train_acc:70.5%,Tran_loss:0.843,Test_acc:71.4%,Test_loss:0.833
[2022-12-25 20:49:45] Epoch:84,Train_acc:70.7%,Tran_loss:0.838,Test_acc:70.8%,Test_loss:0.842
[2022-12-25 20:49:53] Epoch:85,Train_acc:71.0%,Tran_loss:0.828,Test_acc:71.3%,Test_loss:0.834
[2022-12-25 20:50:01] Epoch:86,Train_acc:71.3%,Tran_loss:0.824,Test_acc:71.3%,Test_loss:0.823
[2022-12-25 20:50:09] Epoch:87,Train_acc:71.5%,Tran_loss:0.821,Test_acc:71.5%,Test_loss:0.823
[2022-12-25 20:50:17] Epoch:88,Train_acc:71.4%,Tran_loss:0.820,Test_acc:70.4%,Test_loss:0.848
[2022-12-25 20:50:25] Epoch:89,Train_acc:71.8%,Tran_loss:0.808,Test_acc:71.6%,Test_loss:0.822
[2022-12-25 20:50:33] Epoch:90,Train_acc:72.2%,Tran_loss:0.802,Test_acc:71.4%,Test_loss:0.824
[2022-12-25 20:50:41] Epoch:91,Train_acc:72.2%,Tran_loss:0.799,Test_acc:72.4%,Test_loss:0.805
[2022-12-25 20:50:49] Epoch:92,Train_acc:72.2%,Tran_loss:0.796,Test_acc:72.6%,Test_loss:0.804
[2022-12-25 20:50:57] Epoch:93,Train_acc:72.6%,Tran_loss:0.786,Test_acc:72.7%,Test_loss:0.796
[2022-12-25 20:51:05] Epoch:94,Train_acc:72.8%,Tran_loss:0.785,Test_acc:72.7%,Test_loss:0.795
[2022-12-25 20:51:13] Epoch:95,Train_acc:73.0%,Tran_loss:0.779,Test_acc:72.4%,Test_loss:0.798
[2022-12-25 20:51:21] Epoch:96,Train_acc:73.0%,Tran_loss:0.771,Test_acc:72.5%,Test_loss:0.801
[2022-12-25 20:51:30] Epoch:97,Train_acc:73.2%,Tran_loss:0.772,Test_acc:72.0%,Test_loss:0.813
[2022-12-25 20:51:38] Epoch:98,Train_acc:73.4%,Tran_loss:0.764,Test_acc:72.5%,Test_loss:0.803
[2022-12-25 20:51:46] Epoch:99,Train_acc:73.7%,Tran_loss:0.759,Test_acc:73.4%,Test_loss:0.773
[2022-12-25 20:51:54] Epoch:100,Train_acc:73.7%,Tran_loss:0.757,Test_acc:73.4%,Test_loss:0.775

结果可视化

【365计划-2】pytorch实现彩色图片识别_第1张图片
训练集与测试集的准确率和损失曲线几乎重合,没有出现明显的过拟合。上方日志展示了100轮训练,测试识别率约为73%(第100轮为73.4%);继续增加训练轮数有望进一步提升识别率。
参考项目使用的学习率为0.01;本文一开始改用了0.001(即代码中的 learn_rate=1e-3),结果收敛非常慢。

你可能感兴趣的:(pytorch项目实践,pytorch,深度学习,python)