最近开始学习fastai,感觉这是一个对于初学者比较友好的库,官方提供了详细的视频教程和代码。
这个笔记基于我对于课程内容和代码的理解,对一些代码的使用进行一些说明,也是为了帮助我更好的理解课程的内容。
代码笔记会按照课程来完成,每课一篇。
课程代码:https://course.fast.ai/start_kaggle.html
首先使用resNet34网络进行训练和测试
#下面的代码用于修改代码后自动重载%aimport下的模块
%reload_ext autoreload
%autoreload 2
#这句代码一般用于jupyter,是的matplotlib直接在python的console下显示图像
%matplotlib inline
from fastai import *
from fastai.vision import *
#这个参数用于在之后将数据分批的时候设置每个批次的大小,当内存不足时可以将这个值减小
bs = 64
# bs = 16 # uncomment this line if you run out of memory even after clicking Kernel->Restart
#help可以查看具体方法的参数和作用
help(untar_data)
#从url下载数据并解压
path = untar_data(URLs.PETS); path
path.ls()
#分别获取图片路径和注解路径
path_anno = path/'annotations'
path_img = path/'images'
#从图片路径读取图片文件
fnames = get_image_files(path_img)
fnames[:5]
#这里是设置随机种子数,应该是为了保证每次训练的参数一致
np.random.seed(2)
#从文件名获取图片标签的正则
pat = re.compile(r'/([^/]+)_\d+.jpg$')
# 将图片根据文件名标记,并通过get_transforms()函数进行变换,使得图片归一化为224大小的图片
# 正规化图片,使得图片的个个通道下的数值基于0到255且不过亮或过暗
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs, num_workers=0).normalize(imagenet_stats)
data.show_batch(rows=3, figsize=(7,6))
print(data.classes)
#data.c返回数据的类别数
len(data.classes),data.c
#创建resnet34模型的cnn网络
learn = create_cnn(data, models.resnet34, metrics=error_rate)
#fit_one_cycle是一种通过逐渐增大学习率进行学习的方式,这里传入的数字是学习的轮数epoch
learn.fit_one_cycle(4)
#保存模型
learn.save('stage-1')
#获取最多分类错误的数据的混淆矩阵
interp = ClassificationInterpretation.from_learner(learn)
#获取分错的最多的数据以及这些数据的
losses,idxs = interp.top_losses()
len(data.valid_ds)==len(losses)==len(idxs)
#展示预测错误最多的图片
interp.plot_top_losses(9, figsize=(15,11), heatmap=False)
#显示函数的文档
doc(interp.plot_top_losses)
#展示预测混淆矩阵
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)
#展示分错最多的混淆矩阵
interp.most_confused(min_val=2)
#将训练好的模型的参数解冻,使其可以再次训练
learn.unfreeze()
learn.fit_one_cycle(1)
#读取之前训练的模型
learn.load('stage-1');
#寻找学习率,展示学习率与误差的关系,用以寻找较好的学习率区间
learn.lr_find()
learn.recorder.plot()
#尝试使用误差较小的学习率区间
learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))
下面采用resNet50网络进行训练
#重新构建数据集,尺寸大小为299,每个批次的尺寸为原来的一般,并且正规化图像
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(),
size=299, bs=bs//2, num_workers=0).normalize(imagenet_stats)
#构建resNet50层的网络
learn = create_cnn(data, models.resnet50, metrics=error_rate)
#寻找学习率,并展示不同学习率下的结果
learn.lr_find()
learn.recorder.plot()
#进行8轮学习
learn.fit_one_cycle(8)
#保存当前模型
learn.save('stage-1-50')
#将参数解冻,并且使用调整过的学习率进行3轮训练
learn.unfreeze()
learn.fit_one_cycle(3, max_lr=slice(1e-6,1e-4))
#读取之前保存的模型
learn.load('stage-1-50');
#获取分类结果中错误程度最严重的结果的混淆矩阵,。展示混淆矩阵中错误超过2个的分错统计,比如将灰熊(tag=1)分成黑熊(tag=2)的有10张,就有[1,2,10]
interp = ClassificationInterpretation.from_learner(learn)
interp.most_confused(min_val=2)
最后还介绍了其他数据格式构建数据集的方法,课程中用了手写数字的数据集MNIST
#首先下载并解压数据集
path = untar_data(URLs.MNIST_SAMPLE); path
#从文件夹获取数据集,对数据集进行变换,尺寸为26
tfms = get_transforms(do_flip=False)
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=26, num_workers=0)
#显示变换后的数据
data.show_batch(rows=3, figsize=(5,5))
#创建resnet18网络,并且使用默认的学习率训练2轮
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.fit(2)
#读取标签csv文件,并显示前五行
df = pd.read_csv(path/'labels.csv')
df.head()
#从csv文件获取数据集,并且进行变换,尺寸为28
data = ImageDataBunch.from_csv(path, ds_tfms=tfms, size=28, num_workers=0)
#显示前三行数据,并显示所有类别
data.show_batch(rows=3, figsize=(5,5))
data.classes
#从df数据文件中获取数据集,变换尺寸为24,并展示类别
data = ImageDataBunch.from_df(path, df, ds_tfms=tfms, size=24, num_workers=0)
data.classes
#从df文件中查找文件路径,并且保存在fn_paths中
fn_paths = [path/name for name in df['name']]; fn_paths[:2]
#定义正则,并且通过文件名匹配正则获取数据集,并且变换为尺寸24,最后显示类别
pat = r"/(\d)/\d+\.png$"
data = ImageDataBunch.from_name_re(path, fn_paths, pat=pat, ds_tfms=tfms, size=24, num_workers=0)
data.classes
#通过lambda方程来匹配文件名构建数据集,尺寸变换为24,并显示类别
data = ImageDataBunch.from_name_func(path, fn_paths, ds_tfms=tfms, size=24,
label_func = lambda x: '3' if '/3/' in str(x) else '7', num_workers=0)
data.classes
#通过文件名来标记数据的标签
labels = [('3' if '/3/' in str(x) else '7') for x in fn_paths]
labels[:5]
#通过给定的标签列表和文件路径列表构建对应标签的数据集,尺寸变换为24,最后显示类别
data = ImageDataBunch.from_lists(path, fn_paths, labels=labels, ds_tfms=tfms, size=24, num_workers=0)
data.classes