"""
# @file name : 1_split_dataset.py
# @author : tingsongyu
# @date : 2019-09-07 10:08:00
# @brief : 将数据集划分为训练集,验证集,测试集
"""
# for data split
import os
import random
import shutil
# for train_lenet
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)
random.seed(1)     # random.shuffle below draws from the random module, so seed it for a reproducible split
np.random.seed(1)
dataset_dir = os.path.join('D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06','data','RMB_data')
split_dir = os.path.join('D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06','data','rmb_split')
train_dir = os.path.join(split_dir, 'train')
valid_dir = os.path.join(split_dir, 'valid')
test_dir = os.path.join(split_dir, 'test')
train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1
print(dataset_dir)
for root, dirs, files in os.walk(dataset_dir):
    print(root, '\n')
    print(dirs, '\n')
    print(files, '\n')
    for sub_dir in dirs:
        imgs = os.listdir(os.path.join(root, sub_dir))
        imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
        random.shuffle(imgs)
        img_count = len(imgs)

        train_point = int(img_count * train_pct)
        valid_point = int(img_count * (train_pct + valid_pct))

        for i in range(img_count):
            if i < train_point:
                out_dir = os.path.join(train_dir, sub_dir)
            elif i < valid_point:
                out_dir = os.path.join(valid_dir, sub_dir)
            else:
                out_dir = os.path.join(test_dir, sub_dir)
            makedir(out_dir)

            target_path = os.path.join(out_dir, imgs[i])
            src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
            shutil.copy(src_path, target_path)

        print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir,
                                                             train_point,
                                                             valid_point - train_point,
                                                             img_count - valid_point))
D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06\data\RMB_data
D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06\data\RMB_data
['1', '100']
[]
Class:1, train:80, valid:10, test:10
Class:100, train:80, valid:10, test:10
D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06\data\RMB_data\1
[]
['01B68AKT.jpg', '01EIM65B.jpg', '01LNYXO4.jpg', '01MF2W5S.jpg', '01NISKCG.jpg', '02C4V1SW.jpg', '03WGM2XG.jpg', '03WV6GFZ.jpg', '049I6MVB.jpg', '04A32I57.jpg', '04MGL637.jpg', '04QE2KHA.jpg', '04QGLB16.jpg', '04RWK2B5.jpg', '04VRAHK2.jpg', '04YVW9CN.jpg', '059GS728.jpg', '05MLGSGI.jpg', '05MO1N93.jpg', '067TZA8C.jpg', '069N3OK2.jpg', '073LW92O.jpg', '07GVXBMG.jpg', '07HTXU3W.jpg', '07IUEGQX.jpg', '07PVUGTB.jpg', '07R6PKIX.jpg', '08596RNG.jpg', '08C3EHPG.jpg', '09F2SGOT.jpg', '09PUM1HY.jpg', '0B89KOA3.jpg', '0BGHNV6P.jpg', '0BRO7XVG.jpg', '0C4UDH9S.jpg', '0CNU427V.jpg', '0CTO7MER.jpg', '0D29EFZO.jpg', '0D6HCAXL.jpg', '0D73KYGN.jpg', '0DLF8NU6.jpg', '0DLW9G7O.jpg', '0DRZXTK3.jpg', '0E4QRCTS.jpg', '0E5Q62TM.jpg', '0E6AGCOW.jpg', '0EBSK2GF.jpg', '0EMSWVIR.jpg', '0EP9R4N8.jpg', '0EZ7ND18.jpg', '0F9CEKGH.jpg', '0FA8LD9Z.jpg', '0FY3IOKC.jpg', '0G7ZDUOL.jpg', '0GE1UZT5.jpg', '0GFISMAH.jpg', '0GHKAWQX.jpg', '0GP1SABX.jpg', '0GPYRDQM.jpg', '0GQBKCAW.jpg', '0GRZFSDG.jpg', '0GZN8V26.jpg', '0H2X15GN.jpg', '0H5TYFCG.jpg', '0HBEG1TG.jpg', '0HK5MCIY.jpg', '0HO24UXQ.jpg', '0I1X9SDG.jpg', '0IDB8M67.jpg', '0IPXU5A9.jpg', '0KOIAHWT.jpg', '0KS8UVFH.jpg', '0KYWGVO5.jpg', '0L7G631P.jpg', '0LWI5TZA.jpg', '0M9RSHZX.jpg', '0MN9158I.jpg', '0MNY59BW.jpg', '0MOHTNXQ.jpg', '0MRGZVBU.jpg', '0MVSIP89.jpg', '0NAQUMVX.jpg', '0NARV1BG.jpg', '0NK2SDAY.jpg', '0NP96EGY.jpg', '0O1LZWHQ.jpg', '0OFE6MSI.jpg', '0ON7E9RU.jpg', '0OS6ZK8X.jpg', '0PNDS7OG.jpg', '0QDIG4BK.jpg', '0QS21LZ5.jpg', '0QX8O4K5.jpg', '0R2P4H1I.jpg', '0R5836GQ.jpg', '0R5C4V7X.jpg', '0RAYEIFB.jpg', '0RBDE8G9.jpg', '0RNYOPL5.jpg', '0RPZ5WDL.jpg']
D://片儿//1910pt课程//lesson-06-DataLoader与Dataset//lesson//lesson-06\data\RMB_data\100
[]
['013MNV9B.jpg', '01953EH7.jpg', '01GUGTQ4.jpg', '0237YRPB.jpg', '027AXFQE.jpg', '02GMCEUY.jpg', '02OE5LH4.jpg', '02U7GMR4.jpg', '04PKGVRH.jpg', '04VA2NX7.jpg', '04W3GHSB.jpg', '04XUW3YA.jpg', '04Y816LH.jpg', '05CGTWNF.jpg', '05IDEW2M.jpg', '05NY9E4L.jpg', '062RVGPX.jpg', '069EUOGR.jpg', '06DCY1X7.jpg', '06NEIRC4.jpg', '06WXAH5B.jpg', '07C64BY2.jpg', '07EVDRNY.jpg', '07G9OTZ8.jpg', '07GG9EL5.jpg', '07UHGSGR.jpg', '0845ZXHV.jpg', '08FB4P92.jpg', '08KCVAP1.jpg', '094HN2DV.jpg', '0A4DSPGE.jpg', '0A8IHWYD.jpg', '0AGN4YMI.jpg', '0AXUR9N7.jpg', '0AYIPVK9.jpg', '0B4S7DIX.jpg', '0B6G4MGL.jpg', '0BEHP27M.jpg', '0BOVSMYN.jpg', '0BT9RL7V.jpg', '0E6MQZRX.jpg', '0F9X81GD.jpg', '0FBU7PYL.jpg', '0FGZ2O94.jpg', '0FI3Q5G6.jpg', '0FLQ2NM8.jpg', '0FVUA72W.jpg', '0GBP7VOD.jpg', '0GK7QWHV.jpg', '0GO4CF9X.jpg', '0H3XGENW.jpg', '0I18Y4DC.jpg', '0I376P29.jpg', '0ICF2DMA.jpg', '0IPRGGO8.jpg', '0IUO1B5C.jpg', '0K3UG2A7.jpg', '0K9UF7PZ.jpg', '0KR1M3IQ.jpg', '0KVUR7IY.jpg', '0L6X2HUZ.jpg', '0LIKD2CE.jpg', '0M2XL8TP.jpg', '0M9ELBI8.jpg', '0MEDOLBT.jpg', '0MEG4GXO.jpg', '0MLDWG4I.jpg', '0N6GK7DV.jpg', '0NOGYHMV.jpg', '0NR17KZY.jpg', '0NVG5T3I.jpg', '0NVLGX81.jpg', '0OMXR67T.jpg', '0OWAK5B7.jpg', '0P1HGRT2.jpg', '0P3IQFN1.jpg', '0P3LXCWK.jpg', '0PDHIO85.jpg', '0PGDL742.jpg', '0PLFQ4UI.jpg', '0PQXSWVG.jpg', '0QDUHBWO.jpg', '0R6X4SO8.jpg', '0S4TL8YH.jpg', '0TPMVXD2.jpg', '0TW2YELM.jpg', '0U47M9CX.jpg', '0UB46RM1.jpg', '0W462YMN.jpg', '0WLNF8AG.jpg', '0WLO5FN2.jpg', '0WMICN9B.jpg', '0WRGMZ5Y.jpg', '0WV65B8Z.jpg', '0X8MOP1K.jpg', '0YSHMFA2.jpg', '0YU1K84V.jpg', '0Z1DMA2S.jpg', '0Z85G1SR.jpg', '0ZA9M8E2.jpg']
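# After the split, rmb_split contains one subfolder per class under each of
# train/valid/test (a sketch of the resulting layout; which files land where
# depends on the shuffle):
#
#   rmb_split/
#       train/1, train/100      # 80 images per class
#       valid/1, valid/100      # 10 images per class
#       test/1,  test/100       # 10 images per class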
"""
# @file name : train_lenet.py
# @author : tingsongyu
# @date : 2019-09-07 10:08:00
# @brief : 人民币分类模型训练
"""
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

set_seed()
rmb_label = {"1": 0, "100": 1}

# parameter setting
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
# ============================ step 1/5 Data ============================
norm_mean = [0.485, 0.456, 0.406]   # ImageNet channel statistics
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),   # pad by 4 then crop back to 32x32: light augmentation
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
# build the Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
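# tools/my_dataset.py is not shown here; below is a minimal sketch of what
# RMBDataset could look like, assuming it maps folder names to labels via
# rmb_label and returns (transformed image, label) pairs:
from PIL import Image
from torch.utils.data import Dataset

class RMBDatasetSketch(Dataset):             # hypothetical stand-in for RMBDataset
    def __init__(self, data_dir, transform=None):
        self.transform = transform
        self.data_info = []                  # list of (image path, label) tuples
        for sub_dir in os.listdir(data_dir):
            label = rmb_label[sub_dir]       # "1" -> 0, "100" -> 1
            img_dir = os.path.join(data_dir, sub_dir)
            for name in os.listdir(img_dir):
                if name.endswith('.jpg'):
                    self.data_info.append((os.path.join(img_dir, name), label))

    def __getitem__(self, index):
        path, label = self.data_info[index]
        img = Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_info)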
# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
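# optional sanity check: pull one batch and confirm the shapes are what the
# 32x32 RGB pipeline above should produce
inputs, labels = next(iter(train_loader))
print(inputs.shape)   # torch.Size([16, 3, 32, 32])
print(labels.shape)   # torch.Size([16])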
# ============================ step 2/5 Model ============================
net = LeNet(classes=2)
net.initialize_weights()
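# model/lenet.py is not shown here; a sketch of the classic LeNet-5 layout this
# script assumes (3-channel 32x32 input, two output classes). The course's
# actual model may differ in detail:
class LeNetSketch(nn.Module):                # hypothetical stand-in for LeNet
    def __init__(self, classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)      # 32x32 -> 28x28
        self.conv2 = nn.Conv2d(6, 16, 5)     # 14x14 -> 10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        x = torch.max_pool2d(torch.relu(self.conv1(x)), 2)
        x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)            # flatten to (batch, 400)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)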
# ============================ step 3/5 Loss function ============================
criterion = nn.CrossEntropyLoss()                             # choose the loss function

# ============================ step 4/5 Optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # learning-rate decay schedule
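# StepLR multiplies the learning rate by gamma every step_size epochs, so with
# step_size=10 the decay never actually fires within this 10-epoch run. A
# standalone sketch (not part of the training run) showing the schedule:
demo_opt = optim.SGD([torch.zeros(1, requires_grad=True)], lr=LR)
demo_sched = torch.optim.lr_scheduler.StepLR(demo_opt, step_size=10, gamma=0.1)
for e in range(30):
    if e % 10 == 0:
        print(e, demo_sched.get_last_lr())   # [0.01], then [0.001], then [0.0001]
    demo_opt.step()
    demo_sched.step()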
# ============================ step 5/5 Training ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # track classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().item()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().item()

                loss_val += loss.item()

            valid_curve.append(loss_val)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val, correct_val / total_val))
train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# valid_curve holds one epoch-level loss per validation pass, so rescale its
# x-coordinates from epochs to iterations to share an axis with train_curve
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
Training:Epoch[000/010] Iteration[010/010] Loss: 0.1794 Acc:95.62%
Valid: Epoch[000/010] Iteration[002/002] Loss: 0.0104 Acc:95.62%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.5612 Acc:85.00%
Valid: Epoch[001/010] Iteration[002/002] Loss: 2.4378 Acc:85.00%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.2957 Acc:87.50%
Valid: Epoch[002/010] Iteration[002/002] Loss: 0.0667 Acc:87.50%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.0508 Acc:98.12%
Valid: Epoch[003/010] Iteration[002/002] Loss: 0.0001 Acc:98.12%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.0089 Acc:100.00%
Valid: Epoch[004/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.0015 Acc:100.00%
Valid: Epoch[005/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.0363 Acc:98.75%
Valid: Epoch[006/010] Iteration[002/002] Loss: 0.0000 Acc:98.75%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0220 Acc:98.75%
Valid: Epoch[007/010] Iteration[002/002] Loss: 0.0000 Acc:98.75%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0005 Acc:100.00%
Valid: Epoch[008/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0373 Acc:98.75%
Valid: Epoch[009/010] Iteration[002/002] Loss: 0.0000 Acc:98.75%
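# Note: the loss collapses to near zero within a few epochs; with only 160
# training images, this two-class task is easy for LeNet to fit.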
# ============================ inference ============================
__file__ = r"D:\片儿\1910pt课程\lesson-06-DataLoader与Dataset\lesson\lesson-06"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
test_loader = DataLoader(dataset=test_data, batch_size=1)

net.eval()
with torch.no_grad():
    for i, data in enumerate(test_loader):
        # forward
        inputs, labels = data
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)

        rmb = 1 if predicted.item() == 0 else 100
        print("Model prediction: {} yuan".format(rmb))