前面一篇博客介绍了在Fast AI
框架下训练模型所需的轮子——回调系统,本篇博客将介绍网络训练的方方面面,包括并不限于基础的训练方法(fit
系列方法),指标(metrics
)监控、网络推理与性能评估、模型保存等。
Learner
对象的fit()
与fit_one_cycle()
(文档链接)其中fit_one_cycle()
函数已在前一篇博客中介绍过,实际上,该函数就是在基础的训练流程上添加了OneCycleScheduler
功能。而对Learner
对象,基础的训练流程是由fit()
函数定义的,其接口如下,具体源码见fastai.basic_train.py
。
fit(epochs:int,
lr:Union[float, Collection[float], slice]=slice(None, 0.003, None),
wd:Floats=None, # 如果为None,则使用default.wd
callbacks:Collection[Callback]=None)
其中对参数lr
会做进一步处理,处理后会返回一个浮点数组,数组长度与Learner.layer_groups
一致,用于不同深度的网络层的差异化训练。根据传入的lr
的数据类型,会按照不同的方式生成lr
数组:
lr
为一个数值,那么返回的就是一个值全为lr
的数组。lr
为指定了start
的slice
对象(构建slice
对象需要3S
参数:start
、stop
、step
。如果只传入一个参数,那么指定的是stop
。),那么就会返回一个等比序列,序列的起始值为slice.start
,终止值为slice.stop
。lr
为只指定了stop
的slice
对象,那么除最后一个的lr
为stop
外,其余值均设为stop/10
。当Fast AI
使用优化器的包装类OptimWrapper
的对象进行迭代时(即调用step()
时),会按照lr
数组和Learner.layer_groups
的对应关系进行差异化的训练。
经常和fit()
或fit_one_cycle()
配合使用的是Learner
对象的freeze()
和unfreeze()
。这两个函数都是通过Learner
的freeze_to(n)
函数实现的,该函数可按Learner.layer_groups
的分组冻结前n
层的网络参数。若n=0
,则表示要调整所有的网络参数,即unfreeze()
的效果。
Fast AI
的metrics
(文档链接)metrics
接受模型的输出outputs
和以及数据标签targets
为参数,计算用于评估模型性能的指标。在使用时,可在构建Learner
对象时,以metrics
参数传入。而metrics
的调用则是由Learner
在fit
系列函数中使用CallbackHandler
进行统一管理的(类似于对Callback
的管理)。具体而言,CallbackHandler
会将各个metrics
函数封装成AverageMetric
回调类(如果已经是Callback
类了,则不进行这一封装)。该回调类主要涉及三个回调槽功能:
on_epoch_begin
: 每个epoch
开始时进行初始化,主要是初始化val
和count
参数,分别记录metric
的值和样本数。on_batch_end
: 调用所封装的metric
计算val
,对val
和count
进行累积。on_epoch_end
: 返回平均后的metric
的值,并更新CallbackHandler
的state_dict
中的last_metrics
项。而以Callback
形式提供的metric
,可用于计算不是平均意义上的观测指标,如监测每个epoch
使用的内存:
import tracemalloc
class TraceMallocMetric(Callback):
def __init__(self):
super().__init__()
self.name = "peak RAM"
def on_epoch_begin(self, **kwargs):
tracemalloc.start()
def on_epoch_end(self, last_metrics, **kwargs):
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
return add_metrics(last_metrics, torch.tensor(peak))
CallbackHandler
通过参数call_mets
来控制是否计算metrics
。当处于train
过程时,call_mets=False
,不计算metrics
。另外,每个metric
只输出一个性能指标,并且是以6
位浮点数的形式存在的。为绕过这些限制,可通过定义LearnerCallback
(而非Callback
对象),然后以callback_fn
参数传给Learner
。示例如下:
import tracemalloc
class TraceMallocMultiColMetric(LearnerCallback):
_order=-20 # Needs to run before the recorder
def __init__(self, learn):
super().__init__(learn)
self.train_max = 0
def on_train_begin(self, **kwargs):
self.learn.recorder.add_metric_names(['used', 'max_used', 'peak'])
def on_batch_end(self, train, **kwargs):
# track max memory usage during the train phase
if train:
current, peak = tracemalloc.get_traced_memory()
self.train_max = max(self.train_max, current)
def on_epoch_begin(self, **kwargs):
tracemalloc.start()
def on_epoch_end(self, last_metrics, **kwargs):
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
return add_metrics(last_metrics, [current, self.train_max, peak])
learn = cnn_learner(data, model, metrics=[accuracy], callback_fns=TraceMallocMultiColMetric)
如上述函数用于监视train
过程中网络的内存开销,并且添加了三列观测值。使用Learner
的callback_fns
参数而非callbacks
参数的原因是:对于那些需要Learner
对象做参数的Callback
(如LearnerCallback
类都需要依附于Learner
对象),当作为初始化参数传入Learner
构造函数时,Learner
对象还未创建。而Learner
对以callback_fns
传入的Callback
,会采用部分初始化的方法(即使用python
中的partial
函数),无需直接提供Learner
实例,而是在需要用到这些Callback
的地方(即在fit()
系列中),进行完整的初始化,之后,会将Learner.callback_fns
与Learner.callbacks
合并。
另一种方法是在Learner
实例构建之后,再将TraceMallocMultiColMetric
实例添加到Learner.callbacks
中:
learn = cnn_learner(data, model, metrics=[accuracy])
learn.callbacks.append(TraceMallocMultiColMetric(learn))
Fast AI
内置的metrics
定义在fastai.metrics.py
文件中。
Fast AI
中的数据可从Learner.data
中获取,如Learner.data.train_ds[0]
会返回一个两元素的元组,第一个元素为图像数据,第二个元素为其标签。如要对其中的一个元素进行推理,可采用Learner.predict()
方法。
y = learn.predict(learn.data.train_ds[0][0])
y
为一个三元组,第一个元素为模型输出的分类,第二个元素为该类对应的索引,第三个元素即为网络的输出。
可使用Learner
的pred_batch()
函数,其接口定义如下:
pred_batch(
ds_type:DatasetType=<DatasetType.Valid: 2>, # 默认对valid数据进行处理
batch:Tuple=None, reconstruct:bool=False,
with_dropout:bool=False, activ:Module=None) → List[Tensor]
将一个数据封装为batch
,进行预测:
batch = learn.data.one_item(item)
learn.pred_batch(batch=batch)
train
或valid
)的推理可以使用Learner
的get_preds()
函数,其接口定义如下:
get_preds(
ds_type:DatasetType=<DatasetType.Valid: 2>, # 对train数据:DatasetType.Train
activ:Module=None,
with_loss:bool=False, # 是否同时返回在每个样本上的loss
n_batch:Optional[int]=None, # 如果为None,则使用构建数据集时设定的batch size
pbar:Union[MasterBar, ProgressBar, NoneType]=None
) → List[Tensor]
当不设置with_loss
时,该函数的返回值为一个两元素元组:第一个元素是模型在所有样本上的输出,第二个是每个样本所属的类别的索引。
metrics
可以使用Learner
的validate()
函数,其接口定义如下:
validate(dl=None, # 数据加载器,默认使用Learner的valid_dl
callbacks=None, # 会和Learner.callbacks进行拼接
metrics=None # 若非空,则会替换掉Learner.metrics
)
# 示例
learn.validate(learn.data.valid_dl)
可使用Learner
对象的show_results()
函数:
show_results(ds_type=<DatasetType.Valid: 2>, rows:int=5, **kwargs)
默认显示5x5
的图像阵列。
ClassificationInterpretation
(文档链接)可使用Learner
的interpret()
方法,接口如下:
interpret(learn:Learner, ds_type:DatasetType=<DatasetType.Valid: 2>, tta=False)
# 示例
interp = learn.interpret()
ClassificationInterpretation
提供了很多有用的功能(具体参见fastai.train.py
文件):
top_losses(k:int=None, largest=True) # 若不指定k,则对所有样本进行排序。
k
张图像可视化:plot_top_losses(k, largest=True, figsize=(12, 12),
heatmap:bool=False, heatmap_thresh:int=16, # 是否显示热力图
alpha:float=0.6, cmap:str='magma', # 热力图的透明度
show_text:bool=True, return_fig:bool=None) → Optional[Figure]
confusion_matrix(slice_size:int=1)
plot_confusion_matrix(
normalize:bool=False, # 是否进行归一化
title:str='Confusion matrix', cmap:Any='Blues', slice_size:int=1, norm_dec:int=2, plot_txt:bool=True, return_fig:bool=None, **kwargs) → Optional[Figure]
most_confused(min_val:int=1, slice_size:int=1) → Collection[Tuple[str, str, int]]
# 其中返回的元组中的元素为:actual class, predicted class, misclassified_no
此外,Fast AI
对Segmentation
和Detection
等任务都定义了相应的Interpertation
类,在具体任务时再做论述。
save()
与load()
save(file:PathLikeOrBinaryStream=None, return_path:bool=False, with_opt:bool=True)
file
为一个绝对路径,那么就不会使用Learner.path
作为存储路径。load
:load(file:PathLikeOrBinaryStream=None, device:device=None, strict:bool=True, with_opt:bool=None, purge:bool=False, remove_module:bool=False) → Learner
export()
与load_learner()
export(file:PathLikeOrBinaryStream='export.pkl', destroy=False)
file
为一个绝对路径,那么就不会使用Learner.path
作为存储路径。destroy=True
,那么在存储模型的时候就会释放掉Learner
对象所占用的显存。load_learner
:load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, tfm_y=None, **db_kwargs)
注意:与save()
不同的是,export()
会保存网络结构,因此在使用load_learner()
进行重载的时候,不需要显式生成Learner
对象。另外,如果原Learner
对象如果使用了自定义的模块,那么在重新加载的时候,要首先定义这些模块。