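Copyright notice: This is an original article by the blogger. Reposting is welcome; please credit the source. Contact: [email protected]
fastai is a deep learning framework built on top of PyTorch that delivers strong results. This post walks through training an image classifier on CIFAR10 with it.
Import libraries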
from fastai import *
from fastai.vision import *
from fastai.callbacks import CSVLogger, SaveModelCallback
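Compute and display results on the validation set
def show_result(learn):
    # Get predictions and labels on the validation set, then compute accuracy and error rate
    probs, val_labels = learn.get_preds(ds_type=DatasetType.Valid)
    print('Accuracy', accuracy(probs, val_labels))
    print('Error Rate', error_rate(probs, val_labels))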
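For reference, here is a minimal pure-Python sketch of what `accuracy` and `error_rate` compute from the `probs` and `val_labels` returned by `get_preds`: the predicted class is the argmax of each probability row, accuracy is the fraction of predictions that match the label, and error rate is 1 minus accuracy. The helper names and toy data are hypothetical; fastai's own functions operate on PyTorch tensors.

```python
from typing import List, Sequence, Tuple

def argmax(row: Sequence[float]) -> int:
    # Index of the largest probability in one prediction row
    return max(range(len(row)), key=lambda i: row[i])

def accuracy_and_error_rate(probs: List[Sequence[float]], labels: List[int]) -> Tuple[float, float]:
    # Hypothetical helper mirroring fastai's accuracy / error_rate on plain lists
    correct = sum(1 for row, y in zip(probs, labels) if argmax(row) == y)
    acc = correct / len(labels)
    return acc, 1.0 - acc

probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4], [0.5, 0.4, 0.1]]
labels = [0, 1, 1, 0]
acc, err = accuracy_and_error_rate(probs, labels)
print(acc, err)  # 0.75 0.25
```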
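Confusion matrix and the most frequently confused classes
def show_matrix(learn):
    # Plot the confusion matrix of the trained model
    interp = ClassificationInterpretation.from_learner(learn)
    print(interp.confusion_matrix())
    interp.plot_confusion_matrix(dpi=120)
    # List the most confused class pairs; min_val is the minimum number of errors (default 1)
    # Each entry is (actual, predicted, number of occurrences)
    print(interp.most_confused(min_val=5))
    # Show the 9 samples the model found hardest (highest loss)
    # Each title reads: prediction / actual / loss / probability of the actual class
    interp.plot_top_losses(9, figsize=(10, 10))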
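The logic behind `confusion_matrix` and `most_confused` can be sketched on plain lists: count (actual, predicted) pairs into a matrix, ignore the diagonal (correct predictions), keep cells with at least `min_val` errors, and sort by count in descending order. The helper names and toy data below are hypothetical; fastai's `ClassificationInterpretation` does the same kind of counting on tensors.

```python
from typing import List, Tuple

def confusion_matrix(actual: List[int], predicted: List[int], n_classes: int) -> List[List[int]]:
    # cm[i][j] counts samples whose actual class is i and predicted class is j
    cm = [[0] * n_classes for _ in range(n_classes)]
    for a, p in zip(actual, predicted):
        cm[a][p] += 1
    return cm

def most_confused(cm: List[List[int]], classes: List[str], min_val: int = 1) -> List[Tuple[str, str, int]]:
    # Off-diagonal cells only, counts >= min_val, sorted by count (descending)
    pairs = [(classes[i], classes[j], cm[i][j])
             for i in range(len(classes)) for j in range(len(classes))
             if i != j and cm[i][j] >= min_val]
    return sorted(pairs, key=lambda t: t[2], reverse=True)

classes = ['bird', 'frog', 'plane']
actual = [0, 0, 0, 1, 1, 2, 2, 2]
predicted = [1, 1, 0, 1, 0, 2, 0, 2]
cm = confusion_matrix(actual, predicted, 3)
print(cm)                          # [[1, 2, 0], [1, 1, 0], [1, 0, 2]]
print(most_confused(cm, classes))  # [('bird', 'frog', 2), ('frog', 'bird', 1), ('plane', 'bird', 1)]
```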
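Download the dataset. In the fastai version used here, untar_data extracts the archive by calling Linux tar, which fails on Windows; in that case extract the archive manually and point path at the extracted directory:
# Download (and, on Linux/macOS, extract) the dataset; untar_data returns the extracted path
untar_data(URLs.CIFAR)
# Training data directory (the author's Windows path; adjust to your machine)
path = Path(r'C:\Users\Administrator\.fastai\data\cifar10')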
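If the automatic extraction fails, the archive can also be unpacked with Python's standard `tarfile` module, which behaves the same on Windows. This sketch assumes the download is a gzip-compressed tar named `cifar10.tgz` under `~/.fastai/data`; both the file name and the location are assumptions, so check where your download actually landed.

```python
import tarfile
from pathlib import Path

def extract_tgz(archive: Path, dest: Path) -> Path:
    # Unpack a .tgz archive into dest using only the standard library (no external tar needed)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, 'r:gz') as tar:
        tar.extractall(dest)
    return dest

if __name__ == '__main__':
    # Assumed location of the downloaded archive; adjust to your machine
    archive = Path.home() / '.fastai' / 'data' / 'cifar10.tgz'
    if archive.exists():
        extract_tgz(archive, archive.parent)
```

Extracting next to the archive (`archive.parent`) should yield the `cifar10` folder that `path` points to, assuming the archive's top-level directory is named `cifar10`.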
Define the data and on-the-fly data augmentation
# Define on-the-fly augmentation (random flips disabled here)
tfms = get_transforms(do_flip=False)
data = (ImageList.from_folder(path)      # Where to find the data? -> in path and its subfolders
        .split_by_rand_pct()             # How to split in train/valid? -> hold out a random 20% (default valid_pct=0.2)
        .label_from_folder()             # How to label? -> use the name of each file's parent folder
        .add_test_folder()               # Optionally add a test set (default folder name is 'test')
        .transform(tfms, size=(32, 32))  # Data augmentation? -> apply tfms and keep images at 32x32
        .databunch(bs=128)               # Finally? -> build an ImageDataBunch with batch size 128
        .normalize(imagenet_stats))      # Normalize with ImageNet statistics (resnet50 is ImageNet-pretrained)
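A sketch of what the labeling and splitting steps do, assuming fastai v1's behavior for `split_by_rand_pct` (randomly permute all items, then send the first `int(valid_pct * n)` to the validation set). It uses Python's `random` instead of NumPy, and the function names only mirror fastai's for illustration; paths are assumed to use forward slashes.

```python
import random
from pathlib import PurePosixPath
from typing import List, Optional, Tuple

def label_from_folder(paths: List[str]) -> List[str]:
    # The label is the name of the folder that directly contains each image
    return [PurePosixPath(p).parent.name for p in paths]

def split_by_rand_pct(n: int, valid_pct: float = 0.2,
                      seed: Optional[int] = None) -> Tuple[List[int], List[int]]:
    # Shuffle all indices and hold out the first int(valid_pct * n) for validation
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    cut = int(valid_pct * n)
    return idx[cut:], idx[:cut]  # (train indices, valid indices)

print(label_from_folder(['cifar10/train/cat/1.png', 'cifar10/train/dog/2.png']))  # ['cat', 'dog']
train_idx, valid_idx = split_by_rand_pct(48839, seed=0)
print(len(train_idx), len(valid_idx))  # 39072 9767
```

The 48839 used here is simply the sum of the Train (39072) and Valid (9767) counts printed below; int(0.2 * 48839) = 9767 reproduces that validation size.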
View the data
# Inspect the class names, the number of classes, and the DataBunch summary
data.classes, data.c, data
(['airplane',
'automobile',
'bird',
'cat',
'deer',
'dog',
'frog',
'horse',
'ship',
'truck'],
10,
ImageDataBunch;
Train: LabelList (39072 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
airplane,airplane,airplane,airplane,airplane
Path: C:\Users\Administrator\.fastai\data\cifar10;
Valid: LabelList (9767 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
airplane,deer,deer,deer,automobile
Path: C:\Users\Administrator\.fastai\data\cifar10;
Test: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: EmptyLabelList
,,,,
Path: C:\Users\Administrator\.fastai\data\cifar10)
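Create the learner
# Create the learner: ImageNet-pretrained resnet50 with accuracy and error rate as metrics; ShowGraph plots the losses after each epoch, SaveModelCallback saves the model whenever valid_loss improves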
learn = cnn_learner(data, models.resnet50, metrics=[accuracy, error_rate], callback_fns=[ShowGraph, SaveModelCallback])
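Stage 1 training (cnn_learner freezes the pretrained body, so mainly the new head is trained)
# Find a good learning rate (LR range test)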
learn.lr_find(end_lr=1)
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
# Plot the LR finder curve and print a suggested learning rate
learn.recorder.plot(suggestion=True)
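The LR finder raises the learning rate exponentially over a number of mini-batches (fastai v1 defaults to start_lr=1e-7 and num_it=100; end_lr=1 is passed above) while recording the smoothed loss, and `plot(suggestion=True)` reports the point where the loss falls most steeply ("Min numerical gradient"). This pure-Python sketch assumes fastai v1's default skip_start=10 and skip_end=5, re-implements `numpy.gradient` by hand, and approximates the exact step formula; the loss curve is synthetic.

```python
import math
from typing import List

def exp_lr_schedule(start_lr: float, end_lr: float, num_it: int) -> List[float]:
    # Learning rates grow geometrically from start_lr to end_lr over num_it mini-batches
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]

def numerical_gradient(ys: List[float]) -> List[float]:
    # Same scheme as numpy.gradient with unit spacing: one-sided at the ends, central inside
    grads = [ys[1] - ys[0]]
    grads += [(ys[i + 1] - ys[i - 1]) / 2 for i in range(1, len(ys) - 1)]
    grads.append(ys[-1] - ys[-2])
    return grads

def suggest_lr(lrs: List[float], losses: List[float], skip_start: int = 10, skip_end: int = 5) -> float:
    # Drop noisy points at both ends, then pick the LR where the loss decreases fastest
    lrs = lrs[skip_start:len(lrs) - skip_end]
    losses = losses[skip_start:len(losses) - skip_end]
    grads = numerical_gradient(losses)
    return lrs[min(range(len(grads)), key=grads.__getitem__)]

# Usage on a synthetic loss curve whose steepest drop is at iteration 60
lrs = exp_lr_schedule(1e-7, 1.0, 100)
losses = [2.3 - 1.5 / (1 + math.exp(-(i - 60) / 3)) for i in range(100)]
print(f"Min numerical gradient: {suggest_lr(lrs, losses):.2E}")
```

The suggestion is only a starting point; the max_lr used below was chosen by the author from the plotted curve.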
# Choose max_lr from the curve and start training
learn.fit_one_cycle(cyc_len=15, max_lr=1.78e-2)
epoch  train_loss  valid_loss  accuracy  error_rate  time
0      1.074162    0.882136    0.709225  0.290775    01:17
1      0.824112    0.766163    0.740453  0.259547    01:16
2      0.811090    0.938345    0.707792  0.292208    01:16
3      0.799450    0.790665    0.733797  0.266203    01:16
4      0.763200    1.364758    0.752636  0.247364    01:18
5      0.693490    0.683559    0.776902  0.223098    01:16
6      0.673621    0.611799    0.800655  0.199345    01:16
7      0.665126    0.630715    0.796150  0.203850    01:16
8      0.612187    0.874567    0.826149  0.173851    01:16
9      0.563634    0.785189    0.820723  0.179277    01:16
10     0.515540    1.286271    0.829835  0.170165    01:21
11     0.485959    0.524455    0.840688  0.159312    01:16
12     0.444417    0.759944    0.842736  0.157264    01:17
13     0.419838    0.830482    0.845500  0.154500    01:17
14     0.421095    0.550606    0.844783  0.155217    01:16
Better model found at epoch 0 with val_loss value: 0.8821364045143127.
Better model found at epoch 1 with val_loss value: 0.7661632299423218.
Better model found at epoch 5 with val_loss value: 0.6835585832595825.
Better model found at epoch 6 with val_loss value: 0.6117991805076599.
Better model found at epoch 11 with val_loss value: 0.5244545340538025.
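SaveModelCallback (monitoring valid_loss by default) saves the model every time the monitored value beats the best so far, and in fastai v1 it reloads the best saved weights when training ends. That is why the accuracy reported next (0.8407) matches epoch 11 rather than the last epoch (0.8448). The sketch below is a simplified, hypothetical re-implementation of the "improvement" rule, applied to the valid_loss column above.

```python
from typing import List

def improvement_epochs(val_losses: List[float]) -> List[int]:
    # Mimic SaveModelCallback(monitor='valid_loss'): save whenever the loss is strictly lower than the best so far
    best = float('inf')
    saved = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            saved.append(epoch)
    return saved

stage1 = [0.882136, 0.766163, 0.938345, 0.790665, 1.364758, 0.683559, 0.611799,
          0.630715, 0.874567, 0.785189, 1.286271, 0.524455, 0.759944, 0.830482, 0.550606]
print(improvement_epochs(stage1))  # [0, 1, 5, 6, 11]
```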
Training results
# Compute and display the results
show_result(learn)
Accuracy tensor(0.8407)
Error Rate tensor(0.1593)
# Save the trained model
learn.save('stg1')
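Stage 1 was trained with `fit_one_cycle`, which follows the one-cycle policy: the learning rate climbs from max_lr/div_factor to max_lr over the first pct_start of all iterations, then anneals to a very small value, while momentum moves the opposite way (0.95 to 0.85 and back in fastai v1; not modeled here). This sketch assumes fastai v1 defaults div_factor=25, pct_start=0.3 and final_div=div_factor*1e4, cosine annealing in both phases, and about 305 mini-batches per epoch (39072 / 128, assuming the last partial batch is dropped); the step bookkeeping is simplified.

```python
import math
from typing import List, Optional

def annealing_cos(start: float, end: float, pct: float) -> float:
    # Cosine interpolation from start (pct=0) to end (pct=1)
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lrs(max_lr: float, total_steps: int, pct_start: float = 0.3,
                  div_factor: float = 25.0, final_div: Optional[float] = None) -> List[float]:
    # Assumed fastai v1 defaults: start at max_lr/25, peak after 30% of the steps,
    # then anneal down to max_lr/(25 * 1e4)
    if final_div is None:
        final_div = div_factor * 1e4
    low, final = max_lr / div_factor, max_lr / final_div
    warm = int(total_steps * pct_start)
    lrs = []
    for i in range(total_steps):
        if i < warm:
            lrs.append(annealing_cos(low, max_lr, i / warm))
        else:
            lrs.append(annealing_cos(max_lr, final, (i - warm) / max(1, total_steps - warm - 1)))
    return lrs

# Stage 1 above: max_lr=1.78e-2, 15 epochs of about 305 mini-batches each (assumed)
lrs = one_cycle_lrs(1.78e-2, 15 * 305)
print(f"start {lrs[0]:.2e}, peak {max(lrs):.2e}, end {lrs[-1]:.2e}")  # start 7.12e-04, peak 1.78e-02, end 7.12e-08
```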
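Stage 2 training (fine-tuning the whole network)
# Reload the stage-1 weights
learn.load('stg1')
# Unfreeze the pretrained body so that all layers are trained
learn.unfreeze()
# Run the LR finder again for the unfrozen model
learn.lr_find(end_lr=1)
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot(suggestion=True)
# Train 15 more epochs with discriminative learning rates, from 1e-6 for the earliest layers up to 5e-5 for the head
learn.fit_one_cycle(15, slice(1e-6, 5e-5))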
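Passing `slice(1e-6, 5e-5)` as max_lr gives each layer group its own learning rate, spaced geometrically from the smallest (earliest, most generic layers) to the largest (the new head). Below is a re-implementation without NumPy of the `even_mults` helper fastai v1 uses for this, assuming 3 layer groups, as `cnn_learner` creates for ResNets.

```python
from typing import List

def even_mults(start: float, stop: float, n: int) -> List[float]:
    # n learning rates spaced geometrically from start to stop (one per layer group)
    if n == 1:  # guard added for this sketch
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

# Assumption: resnet50 under cnn_learner has 3 layer groups (two body groups + the head)
lrs = even_mults(1e-6, 5e-5, 3)
print(['%.2e' % lr for lr in lrs])  # ['1.00e-06', '7.07e-06', '5.00e-05']
```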
epoch  train_loss  valid_loss  accuracy  error_rate  time
0      0.444569    0.521828    0.840381  0.159619    01:26
1      0.427062    0.513434    0.840483  0.159517    01:27
2      0.430344    0.514867    0.846524  0.153476    01:23
3      0.421480    0.550527    0.845295  0.154705    01:23
4      0.410170    0.506949    0.847855  0.152145    01:23
5      0.402150    0.542091    0.849186  0.150814    01:26
6      0.387639    0.491120    0.850927  0.149073    01:27
7      0.373022    0.511580    0.852155  0.147845    01:28
8      0.375497    0.505493    0.854101  0.145899    01:28
9      0.355466    0.585425    0.852462  0.147538    01:28
10     0.355327    0.506402    0.855534  0.144466    01:28
11     0.341208    0.498502    0.855944  0.144057    01:29
12     0.347057    0.549146    0.851746  0.148254    01:28
13     0.345185    0.533962    0.852155  0.147845    01:28
14     0.334336    0.504231    0.855432  0.144568    01:29
Better model found at epoch 0 with val_loss value: 0.5218283534049988.
Better model found at epoch 1 with val_loss value: 0.5134344696998596.
Better model found at epoch 4 with val_loss value: 0.5069490671157837.
Better model found at epoch 6 with val_loss value: 0.491120308637619.
Training results
# Compute and display the results
show_result(learn)
Accuracy tensor(0.8509)
Error Rate tensor(0.1491)
Save the model
learn.save('stg2')
# Plot the confusion matrix of the trained model
interp = ClassificationInterpretation.from_learner(learn)
interp.confusion_matrix()
interp.plot_confusion_matrix(dpi=120)
Show the most confused class pairs with at least 5 errors (min_val=5, passed positionally); each entry is (actual, predicted, number of occurrences).
interp.most_confused(5)
[('bird', 'frog', 86),
('truck', 'automobile', 71),
('deer', 'frog', 66),
('dog', 'bird', 59),
('airplane', 'ship', 57),
('bird', 'airplane', 54),
('dog', 'frog', 54),
('bird', 'deer', 53),
('dog', 'deer', 50),
('cat', 'dog', 47),
('deer', 'bird', 47),
('automobile', 'truck', 45),
('ship', 'airplane', 45),
('cat', 'frog', 44),
('bird', 'dog', 37),
('ship', 'automobile', 34),
('ship', 'truck', 32),
('airplane', 'bird', 31),
('deer', 'dog', 26),
('frog', 'bird', 25),
('dog', 'cat', 24),
('dog', 'horse', 24),
('airplane', 'automobile', 23),
('horse', 'deer', 23),
('airplane', 'truck', 22),
('airplane', 'deer', 20),
('frog', 'deer', 17),
('cat', 'deer', 16),
('horse', 'dog', 14),
('automobile', 'ship', 13),
('deer', 'horse', 13),
('truck', 'ship', 13),
('bird', 'ship', 12),
('cat', 'bird', 12),
('deer', 'airplane', 12),
('dog', 'truck', 12),
('truck', 'airplane', 12),
('frog', 'dog', 11),
('airplane', 'frog', 10),
('deer', 'ship', 10),
('dog', 'airplane', 9),
('frog', 'automobile', 8),
('horse', 'frog', 8),
('ship', 'bird', 8),
('cat', 'truck', 7),
('horse', 'airplane', 7),
('horse', 'bird', 7),
('ship', 'deer', 7),
('dog', 'automobile', 6),
('truck', 'frog', 6),
('automobile', 'frog', 5),
('bird', 'cat', 5),
('bird', 'truck', 5),
('cat', 'ship', 5),
('dog', 'ship', 5),
('frog', 'airplane', 5)]
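The 9 hardest samples
# Show the 9 samples the model found hardest (highest loss)
# Each title reads: prediction / actual / loss / probability of the actual class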
interp.plot_top_losses(9, figsize=(10, 10))
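`plot_top_losses` ranks validation images by their loss. With cross-entropy, a sample's loss is -log of the probability the model assigned to its actual class, so a high loss always comes with a low probability in the plot titles. This pure-Python sketch uses a hypothetical helper name and toy data; fastai's version works on the predictions and losses it already stored.

```python
import math
from typing import List, Sequence, Tuple

def top_losses(probs: List[Sequence[float]], labels: List[int],
               k: int) -> List[Tuple[int, int, int, float, float]]:
    # Per-sample cross-entropy is -log(probability of the actual class).
    # Returns (index, predicted, actual, loss, prob_actual) for the k largest losses.
    rows = []
    for i, (row, y) in enumerate(zip(probs, labels)):
        pred = max(range(len(row)), key=lambda c: row[c])
        rows.append((i, pred, y, -math.log(row[y]), row[y]))
    rows.sort(key=lambda r: r[3], reverse=True)
    return rows[:k]

probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
labels = [0, 0, 1, 1]
for idx, pred, actual, loss, p in top_losses(probs, labels, k=2):
    print(f"#{idx}: {pred}/{actual}/{loss:.2f}/{p:.2f}")  # #1: 1/0/1.61/0.20, then #2: 0/1/0.92/0.40
```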