resnet152训练_用fastai ResNet50训练CIFAR10,85%准确度

版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:[email protected]

Fastai是在pytorch上封装的深度学习框架,效果出众,以下是训练CIFAR10的过程。

导入库

from fastai import *

from fastai.vision import *

from fastai.callbacks import CSVLogger, SaveModelCallback

验证集上训练结果计算和显示

def show_result(learn):

# 得到验证集上的准确度

probs, val_labels = learn.get_preds(ds_type=DatasetType.Valid)

print('Accuracy', accuracy(probs, val_labels)),

print('Error Rate', error_rate(probs, val_labels))

训练结果混淆矩阵及预测错误最多的类型显示

def show_matrix(learn):

# 画训练结果的混合矩阵

interp = ClassificationInterpretation.from_learner(learn)

interp.confusion_matrix()

interp.plot_confusion_matrix(dpi=120)

# 显示判断错误最多的类型,min_val指定错误次数,默认1

# 打印顺序为actual, predicted, number of occurrences.

interp.most_confused(min_val=5)

# 模型预测最困难的9个样本显示

# 显示顺序为预测值、实际值、损失值、预测对的概率

interp.plot_top_losses(9, figsize=(10, 10))

下载数据集,因调用linux的tar进行解压,在windows下会出错,可手动解压,解压后的目录:

# 下载数据集

untar_data(URLs.CIFAR)

# 训练数据目录

path = Path(r'C:\Users\Administrator\.fastai\data\cifar10')

定义数据及数据在线增强方式

# 数据在线增强方式定义

tfms = get_transforms(do_flip=False)

data = (ImageList.from_folder(path) # Where to find the data? -> in path and its subfolders

.split_by_rand_pct() # How to split in train/valid? -> use the folders

.label_from_folder() # How to label? -> depending on the folder of the filenames

.add_test_folder() # Optionally add a test set (here default name is test)

.transform(tfms, size=(32, 32)) # Data augmentation? -> use tfms with a size of 164

.databunch(bs=128) # Finally? -> use the defaults for conversion to ImageDataBunch

.normalize(imagenet_stats))

查看数据

# 查看数据信息

data.classes, data.c, data

(['airplane',

'automobile',

'bird',

'cat',

'deer',

'dog',

'frog',

'horse',

'ship',

'truck'],

10,

ImageDataBunch;

Train: LabelList (39072 items)

x: ImageList

Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)

y: CategoryList

airplane,airplane,airplane,airplane,airplane

Path: C:\Users\Administrator\.fastai\data\cifar10;

Valid: LabelList (9767 items)

x: ImageList

Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)

y: CategoryList

airplane,deer,deer,deer,automobile

Path: C:\Users\Administrator\.fastai\data\cifar10;

Test: LabelList (10000 items)

x: ImageList

Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)

y: EmptyLabelList

,,,,

Path: C:\Users\Administrator\.fastai\data\cifar10)

创建训练器

# 创建learn

learn = cnn_learner(data, models.resnet50, metrics=[accuracy, error_rate], callback_fns=[ShowGraph, SaveModelCallback])

第一阶段训练

# 最佳学习率寻找

learn.lr_find(end_lr=1)

LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

# 画出学习率寻找曲线,给出建议学习率

learn.recorder.plot(suggestion=True)

# 根据学习率曲线得到max_lr,开始训练

learn.fit_one_cycle(cyc_len=15, max_lr=1.78e-2)

epochtrain_lossvalid_lossaccuracyerror_ratetime

0

1.074162

0.882136

0.709225

0.290775

01:17

1

0.824112

0.766163

0.740453

0.259547

01:16

2

0.811090

0.938345

0.707792

0.292208

01:16

3

0.799450

0.790665

0.733797

0.266203

01:16

4

0.763200

1.364758

0.752636

0.247364

01:18

5

0.693490

0.683559

0.776902

0.223098

01:16

6

0.673621

0.611799

0.800655

0.199345

01:16

7

0.665126

0.630715

0.796150

0.203850

01:16

8

0.612187

0.874567

0.826149

0.173851

01:16

9

0.563634

0.785189

0.820723

0.179277

01:16

10

0.515540

1.286271

0.829835

0.170165

01:21

11

0.485959

0.524455

0.840688

0.159312

01:16

12

0.444417

0.759944

0.842736

0.157264

01:17

13

0.419838

0.830482

0.845500

0.154500

01:17

14

0.421095

0.550606

0.844783

0.155217

01:16

Better model found at epoch 0 with val_loss value: 0.8821364045143127.

Better model found at epoch 1 with val_loss value: 0.7661632299423218.

Better model found at epoch 5 with val_loss value: 0.6835585832595825.

Better model found at epoch 6 with val_loss value: 0.6117991805076599.

Better model found at epoch 11 with val_loss value: 0.5244545340538025.

训练结果

# 计算和显示训练结果

show_result(learn)

Accuracy tensor(0.8407)

Error Rate tensor(0.1593)

# 保存训练模型

learn.save('stg1')

第二阶段训练

learn.load('stg1')

learn.unfreeze()

learn.lr_find(end_lr=1)

LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

learn.recorder.plot(suggestion=True)

learn.fit_one_cycle(15, slice(1e-6, 5e-5))

epochtrain_lossvalid_lossaccuracyerror_ratetime

0

0.444569

0.521828

0.840381

0.159619

01:26

1

0.427062

0.513434

0.840483

0.159517

01:27

2

0.430344

0.514867

0.846524

0.153476

01:23

3

0.421480

0.550527

0.845295

0.154705

01:23

4

0.410170

0.506949

0.847855

0.152145

01:23

5

0.402150

0.542091

0.849186

0.150814

01:26

6

0.387639

0.491120

0.850927

0.149073

01:27

7

0.373022

0.511580

0.852155

0.147845

01:28

8

0.375497

0.505493

0.854101

0.145899

01:28

9

0.355466

0.585425

0.852462

0.147538

01:28

10

0.355327

0.506402

0.855534

0.144466

01:28

11

0.341208

0.498502

0.855944

0.144057

01:29

12

0.347057

0.549146

0.851746

0.148254

01:28

13

0.345185

0.533962

0.852155

0.147845

01:28

14

0.334336

0.504231

0.855432

0.144568

01:29

Better model found at epoch 0 with val_loss value: 0.5218283534049988.

Better model found at epoch 1 with val_loss value: 0.5134344696998596.

Better model found at epoch 4 with val_loss value: 0.5069490671157837.

Better model found at epoch 6 with val_loss value: 0.491120308637619.

训练结果

# 计算和显示训练结果

show_result(learn)

Accuracy tensor(0.8509)

Error Rate tensor(0.1491)

保存模型

learn.save('stg2')

# 画训练结果的混合矩阵

interp = ClassificationInterpretation.from_learner(learn)

interp.confusion_matrix()

interp.plot_confusion_matrix(dpi=120)

显示预测错误次数最多的类型,错误次数大于5,输出顺序为actual, predicted, number of occurrences.

interp.most_confused(5)

[('bird', 'frog', 86),

('truck', 'automobile', 71),

('deer', 'frog', 66),

('dog', 'bird', 59),

('airplane', 'ship', 57),

('bird', 'airplane', 54),

('dog', 'frog', 54),

('bird', 'deer', 53),

('dog', 'deer', 50),

('cat', 'dog', 47),

('deer', 'bird', 47),

('automobile', 'truck', 45),

('ship', 'airplane', 45),

('cat', 'frog', 44),

('bird', 'dog', 37),

('ship', 'automobile', 34),

('ship', 'truck', 32),

('airplane', 'bird', 31),

('deer', 'dog', 26),

('frog', 'bird', 25),

('dog', 'cat', 24),

('dog', 'horse', 24),

('airplane', 'automobile', 23),

('horse', 'deer', 23),

('airplane', 'truck', 22),

('airplane', 'deer', 20),

('frog', 'deer', 17),

('cat', 'deer', 16),

('horse', 'dog', 14),

('automobile', 'ship', 13),

('deer', 'horse', 13),

('truck', 'ship', 13),

('bird', 'ship', 12),

('cat', 'bird', 12),

('deer', 'airplane', 12),

('dog', 'truck', 12),

('truck', 'airplane', 12),

('frog', 'dog', 11),

('airplane', 'frog', 10),

('deer', 'ship', 10),

('dog', 'airplane', 9),

('frog', 'automobile', 8),

('horse', 'frog', 8),

('ship', 'bird', 8),

('cat', 'truck', 7),

('horse', 'airplane', 7),

('horse', 'bird', 7),

('ship', 'deer', 7),

('dog', 'automobile', 6),

('truck', 'frog', 6),

('automobile', 'frog', 5),

('bird', 'cat', 5),

('bird', 'truck', 5),

('cat', 'ship', 5),

('dog', 'ship', 5),

('frog', 'airplane', 5)]

预测最困难的9个样本

# 模型预测最困难的9个样本显示

# 显示顺序为预测值、实际值、损失值、预测对的概率

interp.plot_top_losses(9, figsize=(10, 10))

你可能感兴趣的:(resnet152训练)