lesson-06-image-classification-based-on-lenet

Part I 拆分数据

"""
# @file name  : 1_split_dataset.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 将数据集划分为训练集,验证集,测试集
"""
# for data split
import os
import random
import shutil
# for train_lenet
import numpy as np
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
def makedir(new_dir):
    """Create ``new_dir`` (and any missing parent directories) if absent.

    Uses ``exist_ok=True`` instead of a separate ``os.path.exists`` check.
    That removes the check-then-create race. It also means a regular file
    sitting at ``new_dir`` raises ``FileExistsError`` instead of being
    silently accepted as a "directory".

    Args:
        new_dir: path of the directory to create.
    """
    os.makedirs(new_dir, exist_ok=True)
# Seed both RNGs. The shuffle below uses the stdlib ``random`` module, so
# seeding only NumPy (as before) left the split non-reproducible.
random.seed(1)
np.random.seed(1)

# Root folder of the lesson materials. Raw images are read from
# data/RMB_data/<class>/*.jpg and copied into
# data/rmb_split/{train,valid,test}/<class>/.
lesson_dir = 'D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06'
dataset_dir = os.path.join(lesson_dir, 'data', 'RMB_data')
split_dir = os.path.join(lesson_dir, 'data', 'rmb_split')
train_dir = os.path.join(split_dir, 'train')
valid_dir = os.path.join(split_dir, 'valid')
test_dir = os.path.join(split_dir, 'test')

# Split ratios. The test subset is whatever remains after train + valid;
# test_pct is informational only.
train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1

print(dataset_dir)
for root, dirs, files in os.walk(dataset_dir):
    print(root, '\n')
    print(dirs, '\n')
    print(files, '\n')
    for sub_dir in dirs:
        # Build paths from the walk's current root rather than dataset_dir.
        class_dir = os.path.join(root, sub_dir)
        # Sort before shuffling. os.listdir order is filesystem-dependent,
        # so a seeded shuffle is only reproducible from a fixed order.
        imgs = sorted(name for name in os.listdir(class_dir) if name.endswith('.jpg'))
        random.shuffle(imgs)
        img_count = len(imgs)

        train_point = int(img_count * train_pct)
        valid_point = int(img_count * (train_pct + valid_pct))

        for i, img in enumerate(imgs):
            if i < train_point:
                out_dir = os.path.join(train_dir, sub_dir)
            elif i < valid_point:
                out_dir = os.path.join(valid_dir, sub_dir)
            else:
                out_dir = os.path.join(test_dir, sub_dir)

            makedir(out_dir)
            shutil.copy(os.path.join(class_dir, img), os.path.join(out_dir, img))

        # The test count is the remainder after valid_point, not valid_point
        # itself (the old code printed 90 test images out of 100).
        print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir,
                                                             train_point,
                                                             valid_point - train_point,
                                                             img_count - valid_point))

D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06\data\RMB_data
D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06\data\RMB_data 

['1', '100'] 

[] 

Class:1, train:80, valid:10, test:90
Class:100, train:80, valid:10, test:90

D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06\data\RMB_data\1 

[] 

['01B68AKT.jpg', '01EIM65B.jpg', '01LNYXO4.jpg', '01MF2W5S.jpg', '01NISKCG.jpg', '02C4V1SW.jpg', '03WGM2XG.jpg', '03WV6GFZ.jpg', '049I6MVB.jpg', '04A32I57.jpg', '04MGL637.jpg', '04QE2KHA.jpg', '04QGLB16.jpg', '04RWK2B5.jpg', '04VRAHK2.jpg', '04YVW9CN.jpg', '059GS728.jpg', '05MLGSGI.jpg', '05MO1N93.jpg', '067TZA8C.jpg', '069N3OK2.jpg', '073LW92O.jpg', '07GVXBMG.jpg', '07HTXU3W.jpg', '07IUEGQX.jpg', '07PVUGTB.jpg', '07R6PKIX.jpg', '08596RNG.jpg', '08C3EHPG.jpg', '09F2SGOT.jpg', '09PUM1HY.jpg', '0B89KOA3.jpg', '0BGHNV6P.jpg', '0BRO7XVG.jpg', '0C4UDH9S.jpg', '0CNU427V.jpg', '0CTO7MER.jpg', '0D29EFZO.jpg', '0D6HCAXL.jpg', '0D73KYGN.jpg', '0DLF8NU6.jpg', '0DLW9G7O.jpg', '0DRZXTK3.jpg', '0E4QRCTS.jpg', '0E5Q62TM.jpg', '0E6AGCOW.jpg', '0EBSK2GF.jpg', '0EMSWVIR.jpg', '0EP9R4N8.jpg', '0EZ7ND18.jpg', '0F9CEKGH.jpg', '0FA8LD9Z.jpg', '0FY3IOKC.jpg', '0G7ZDUOL.jpg', '0GE1UZT5.jpg', '0GFISMAH.jpg', '0GHKAWQX.jpg', '0GP1SABX.jpg', '0GPYRDQM.jpg', '0GQBKCAW.jpg', '0GRZFSDG.jpg', '0GZN8V26.jpg', '0H2X15GN.jpg', '0H5TYFCG.jpg', '0HBEG1TG.jpg', '0HK5MCIY.jpg', '0HO24UXQ.jpg', '0I1X9SDG.jpg', '0IDB8M67.jpg', '0IPXU5A9.jpg', '0KOIAHWT.jpg', '0KS8UVFH.jpg', '0KYWGVO5.jpg', '0L7G631P.jpg', '0LWI5TZA.jpg', '0M9RSHZX.jpg', '0MN9158I.jpg', '0MNY59BW.jpg', '0MOHTNXQ.jpg', '0MRGZVBU.jpg', '0MVSIP89.jpg', '0NAQUMVX.jpg', '0NARV1BG.jpg', '0NK2SDAY.jpg', '0NP96EGY.jpg', '0O1LZWHQ.jpg', '0OFE6MSI.jpg', '0ON7E9RU.jpg', '0OS6ZK8X.jpg', '0PNDS7OG.jpg', '0QDIG4BK.jpg', '0QS21LZ5.jpg', '0QX8O4K5.jpg', '0R2P4H1I.jpg', '0R5836GQ.jpg', '0R5C4V7X.jpg', '0RAYEIFB.jpg', '0RBDE8G9.jpg', '0RNYOPL5.jpg', '0RPZ5WDL.jpg'] 


D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06\data\RMB_data\100 

[] 

['013MNV9B.jpg', '01953EH7.jpg', '01GUGTQ4.jpg', '0237YRPB.jpg', '027AXFQE.jpg', '02GMCEUY.jpg', '02OE5LH4.jpg', '02U7GMR4.jpg', '04PKGVRH.jpg', '04VA2NX7.jpg', '04W3GHSB.jpg', '04XUW3YA.jpg', '04Y816LH.jpg', '05CGTWNF.jpg', '05IDEW2M.jpg', '05NY9E4L.jpg', '062RVGPX.jpg', '069EUOGR.jpg', '06DCY1X7.jpg', '06NEIRC4.jpg', '06WXAH5B.jpg', '07C64BY2.jpg', '07EVDRNY.jpg', '07G9OTZ8.jpg', '07GG9EL5.jpg', '07UHGSGR.jpg', '0845ZXHV.jpg', '08FB4P92.jpg', '08KCVAP1.jpg', '094HN2DV.jpg', '0A4DSPGE.jpg', '0A8IHWYD.jpg', '0AGN4YMI.jpg', '0AXUR9N7.jpg', '0AYIPVK9.jpg', '0B4S7DIX.jpg', '0B6G4MGL.jpg', '0BEHP27M.jpg', '0BOVSMYN.jpg', '0BT9RL7V.jpg', '0E6MQZRX.jpg', '0F9X81GD.jpg', '0FBU7PYL.jpg', '0FGZ2O94.jpg', '0FI3Q5G6.jpg', '0FLQ2NM8.jpg', '0FVUA72W.jpg', '0GBP7VOD.jpg', '0GK7QWHV.jpg', '0GO4CF9X.jpg', '0H3XGENW.jpg', '0I18Y4DC.jpg', '0I376P29.jpg', '0ICF2DMA.jpg', '0IPRGGO8.jpg', '0IUO1B5C.jpg', '0K3UG2A7.jpg', '0K9UF7PZ.jpg', '0KR1M3IQ.jpg', '0KVUR7IY.jpg', '0L6X2HUZ.jpg', '0LIKD2CE.jpg', '0M2XL8TP.jpg', '0M9ELBI8.jpg', '0MEDOLBT.jpg', '0MEG4GXO.jpg', '0MLDWG4I.jpg', '0N6GK7DV.jpg', '0NOGYHMV.jpg', '0NR17KZY.jpg', '0NVG5T3I.jpg', '0NVLGX81.jpg', '0OMXR67T.jpg', '0OWAK5B7.jpg', '0P1HGRT2.jpg', '0P3IQFN1.jpg', '0P3LXCWK.jpg', '0PDHIO85.jpg', '0PGDL742.jpg', '0PLFQ4UI.jpg', '0PQXSWVG.jpg', '0QDUHBWO.jpg', '0R6X4SO8.jpg', '0S4TL8YH.jpg', '0TPMVXD2.jpg', '0TW2YELM.jpg', '0U47M9CX.jpg', '0UB46RM1.jpg', '0W462YMN.jpg', '0WLNF8AG.jpg', '0WLO5FN2.jpg', '0WMICN9B.jpg', '0WRGMZ5Y.jpg', '0WV65B8Z.jpg', '0X8MOP1K.jpg', '0YSHMFA2.jpg', '0YU1K84V.jpg', '0Z1DMA2S.jpg', '0Z85G1SR.jpg', '0ZA9M8E2.jpg'] 

Part II train_lenet

"""
# @file name  : train_lenet.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 人民币分类模型训练
"""
def set_seed(seed=1):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    Args:
        seed: integer seed applied to all three generators (default 1).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
# Fix the random / numpy / torch RNGs (set_seed defaults to seed 1).
set_seed()

# Class-folder name -> integer label.
# NOTE(review): not referenced in this file; presumably mirrors the mapping
# inside RMBDataset. Confirm against tools/my_dataset.py.
rmb_label = {"1": 0, "100": 1}

# ---------------------------- hyper-parameters ----------------------------
MAX_EXPOCH = 10      # number of training epochs (misspelled name kept; referenced below)
BATCH_SIZE = 16
LR = 0.01
log_interval = 10    # print training stats every `log_interval` iterations
val_interval = 1     # run validation every `val_interval` epochs

# ============================ step 1/5 data ============================
# Per-channel normalisation statistics (the commonly used ImageNet mean/std).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize to 32x32, then pad by 4 px and take a random
# 32x32 crop as light augmentation.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation pipeline: same preprocessing without the random crop.
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Dataset instances over the split produced in Part I.
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Loaders: only the training set is shuffled.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================
# Two output classes: the 1-yuan and 100-yuan notes.
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# Multiply the LR by 0.1 every 10 scheduler steps. scheduler.step() is called
# once per epoch, so with MAX_EXPOCH = 10 the decay only takes effect after
# the final epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# ============================ step 5/5 training ============================
train_curve = []   # one training-loss value per iteration
valid_curve = []   # one validation-loss value per validation pass

for epoch in range(MAX_EXPOCH):
    loss_mean = 0.
    correct = 0.
    total = 0.
    net.train()
    for i, data in enumerate(train_loader):
        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics for this epoch
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # log training progress: mean loss over the last `log_interval` iterations
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch,
                                                                                                          MAX_EXPOCH,
                                                                                                          i+1,
                                                                                                          len(train_loader),
                                                                                                          loss_mean,
                                                                                                          correct / total))
            loss_mean = 0.

    scheduler.step()  # advance the learning-rate schedule once per epoch

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                loss_val += loss.item()

            # Average over batches so the value is comparable with the
            # per-iteration training losses plotted alongside it.
            loss_val_epoch = loss_val / len(valid_loader)
            valid_curve.append(loss_val_epoch)
            # Report validation accuracy (previously the training accuracy
            # `correct / total` was printed here by mistake).
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch,
                                                                                                          MAX_EXPOCH,
                                                                                                          j+1,
                                                                                                          len(valid_loader),
                                                                                                          loss_val_epoch,
                                                                                                          correct_val / total_val))
# Training loss was recorded once per iteration, so its x-axis is simply the
# iteration index.
train_x = range(len(train_curve))
train_y = train_curve

# Validation loss was recorded once per validation pass (every `val_interval`
# epochs). Map the k-th record to the iteration count at which it was taken
# so both curves share the same x-axis.
train_iters = len(train_loader)
valid_x = train_iters * val_interval * np.arange(1, len(valid_curve) + 1)
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.xlabel('Iteration')
plt.ylabel('loss value')
plt.legend(loc='upper right')
plt.show()

Training:Epoch[000/010] Iteration[010/010] Loss: 0.1794 Acc:95.62%
Valid:	 Epoch[000/010] Iteration[002/002] Loss: 0.0104 Acc:95.62%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.5612 Acc:85.00%
Valid:	 Epoch[001/010] Iteration[002/002] Loss: 2.4378 Acc:85.00%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.2957 Acc:87.50%
Valid:	 Epoch[002/010] Iteration[002/002] Loss: 0.0667 Acc:87.50%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.0508 Acc:98.12%
Valid:	 Epoch[003/010] Iteration[002/002] Loss: 0.0001 Acc:98.12%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.0089 Acc:100.00%
Valid:	 Epoch[004/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.0015 Acc:100.00%
Valid:	 Epoch[005/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.0363 Acc:98.75%
Valid:	 Epoch[006/010] Iteration[002/002] Loss: 0.0000 Acc:98.75%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0220 Acc:98.75%
Valid:	 Epoch[007/010] Iteration[002/002] Loss: 0.0000 Acc:98.75%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0005 Acc:100.00%
Valid:	 Epoch[008/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0373 Acc:98.75%
Valid:	 Epoch[009/010] Iteration[002/002] Loss: 0.0000 Acc:98.75%

lesson-06-image-classification-based-on-lenet_第1张图片

# ============================ inference ============================
# Folder holding the lesson-06 materials (same base as Part I).
# The old code assigned "D:\片儿\1910pt课程\..." to __file__ as a normal string
# literal. There "\1" is an octal escape (chr(1)) and "\l" is an invalid escape,
# so the path was corrupted. dirname() of that directory path also pointed one
# level too high.
# NOTE(review): assumes test_data/ sits directly inside lesson-06 — confirm.
BASE_DIR = 'D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06'
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
# Separate name so the validation loader built earlier is not overwritten.
test_loader = DataLoader(dataset=test_data, batch_size=1)

# Inference only: use eval-mode layers and skip autograd bookkeeping.
net.eval()
with torch.no_grad():
    for inputs, _ in test_loader:
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)

        # Label 0 -> 1 yuan, label 1 -> 100 yuan (mirrors rmb_label).
        rmb = 1 if predicted.item() == 0 else 100
        print("模型获得{}元".format(rmb))

你可能感兴趣的:(pytorch)