第七篇 FastAI模型训练

前面一篇博客介绍了在Fast AI框架下训练模型所需的轮子——回调系统,本篇博客将介绍网络训练的方方面面,包括并不限于基础的训练方法(fit系列方法),指标(metrics)监控、网络推理与性能评估、模型保存等。

一、Learner对象的fit()fit_one_cycle()(文档链接)

其中fit_one_cycle()函数已在前一篇博客中介绍过,实际上,该函数就是在基础的训练流程上添加了OneCycleScheduler功能。而对Learner对象,基础的训练流程是由fit()函数定义的,其接口如下,具体源码见fastai.basic_train.py

fit(epochs:int,
    lr:Union[float, Collection[float], slice]=slice(None, 0.003, None), 
    wd:Floats=None, # 如果为None,则使用default.wd
    callbacks:Collection[Callback]=None)

其中对参数lr会做进一步处理,处理后会返回一个浮点数组,数组长度与Learner.layer_groups一致,用于不同深度的网络层的差异化训练。根据传入的lr的数据类型,会按照不同的方式生成lr数组:

  • 如果lr为一个数值,那么返回的就是一个值全为lr的数组。
  • 如果lr为指定了startslice对象(构建slice对象需要3S参数:startstopstep。如果只传入一个参数,那么指定的是stop。),那么就会返回一个等比序列,序列的起始值为slice.start,终止值为slice.stop
  • 如果lr为只指定了stopslice对象,那么除最后一个的lrstop外,其余值均设为stop/10

Fast AI使用优化器的包装类OptimWrapper的对象进行迭代时(即调用step()时),会按照lr数组和Learner.layer_groups的对应关系进行差异化的训练。

经常和fit()fit_one_cycle()配合使用的是Learner对象的freeze()unfreeze()。这两个函数都是通过Learnerfreeze_to(n)函数实现的,该函数可按Learner.layer_groups的分组冻结前n层的网络参数。若n=0,则表示要调整所有的网络参数,即unfreeze()的效果。

二、Fast AImetrics(文档链接)

metrics接受模型的输出outputs和以及数据标签targets为参数,计算用于评估模型性能的指标。在使用时,可在构建Learner对象时,以metrics参数传入。而metrics的调用则是由Learnerfit系列函数中使用CallbackHandler进行统一管理的(类似于对Callback的管理)。具体而言,CallbackHandler会将各个metrics函数封装成AverageMetric回调类(如果已经是Callback类了,则不进行这一封装)。该回调类主要涉及三个回调槽功能:

  • (1) on_epoch_begin: 每个epoch开始时进行初始化,主要是初始化valcount参数,分别记录metric的值和样本数。
  • (2) on_batch_end: 调用所封装的metric计算val,对valcount进行累积。
  • (3) on_epoch_end: 返回平均后的metric的值,并更新CallbackHandlerstate_dict中的last_metrics项。

而以Callback形式提供的metric,可用于计算不是平均意义上的观测指标,如监测每个epoch使用的内存:

import tracemalloc
class TraceMallocMetric(Callback):
    def __init__(self):
        super().__init__()
        self.name = "peak RAM"

    def on_epoch_begin(self, **kwargs):
        tracemalloc.start()
        
    def on_epoch_end(self, last_metrics, **kwargs):
        current, peak =  tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return add_metrics(last_metrics, torch.tensor(peak))

CallbackHandler通过参数call_mets来控制是否计算metrics。当处于train过程时,call_mets=False,不计算metrics。另外,每个metric只输出一个性能指标,并且是以6位浮点数的形式存在的。为绕过这些限制,可通过定义LearnerCallback(而非Callback对象),然后以callback_fn参数传给Learner。示例如下:

import tracemalloc
class TraceMallocMultiColMetric(LearnerCallback):
    _order=-20 # Needs to run before the recorder
    def __init__(self, learn):
        super().__init__(learn)
        self.train_max = 0

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['used', 'max_used', 'peak'])
            
    def on_batch_end(self, train, **kwargs):
        # track max memory usage during the train phase
        if train:
            current, peak =  tracemalloc.get_traced_memory()
            self.train_max = max(self.train_max, current)
        
    def on_epoch_begin(self, **kwargs):
        tracemalloc.start()

    def on_epoch_end(self, last_metrics, **kwargs):
        current, peak =  tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return add_metrics(last_metrics, [current, self.train_max, peak])

learn = cnn_learner(data, model, metrics=[accuracy], callback_fns=TraceMallocMultiColMetric)

如上述函数用于监视train过程中网络的内存开销,并且添加了三列观测值。使用Learnercallback_fns参数而非callbacks参数的原因是:对于那些需要Learner对象做参数的Callback(如LearnerCallback类都需要依附于Learner对象),当作为初始化参数传入Learner构造函数时,Learner对象还未创建。而Learner对以callback_fns传入的Callback,会采用部分初始化的方法(即使用python中的partial函数),无需直接提供Learner实例,而是在需要用到这些Callback的地方(即在fit()系列中),进行完整的初始化,之后,会将Learner.callback_fnsLearner.callbacks合并。
另一种方法是在Learner实例构建之后,再将TraceMallocMultiColMetric实例添加到Learner.callbacks中:

learn = cnn_learner(data, model, metrics=[accuracy])
learn.callbacks.append(TraceMallocMultiColMetric(learn))

Fast AI内置的metrics定义在fastai.metrics.py文件中。

三、网络推理与性能评估

1.对单一数据的推理

Fast AI中的数据可从Learner.data中获取,如Learner.data.train_ds[0]会返回一个两元素的元组,第一个元素为图像数据,第二个元素为其标签。如要对其中的一个元素进行推理,可采用Learner.predict()方法。

y = learn.predict(learn.data.train_ds[0][0])

y为一个三元组,第一个元素为模型输出的分类,第二个元素为该类对应的索引,第三个元素即为网络的输出。

2.对一个batch的数据的推理

可使用Learnerpred_batch()函数,其接口定义如下:

pred_batch(
    ds_type:DatasetType=<DatasetType.Valid: 2>, # 默认对valid数据进行处理
    batch:Tuple=None, reconstruct:bool=False,
    with_dropout:bool=False, activ:Module=None) → List[Tensor]

将一个数据封装为batch,进行预测:

batch = learn.data.one_item(item)
learn.pred_batch(batch=batch)
3.对某一数据集(trainvalid)的推理

可以使用Learnerget_preds()函数,其接口定义如下:

get_preds(
    ds_type:DatasetType=<DatasetType.Valid: 2>, # 对train数据:DatasetType.Train
    activ:Module=None, 
    with_loss:bool=False, # 是否同时返回在每个样本上的loss
    n_batch:Optional[int]=None, # 如果为None,则使用构建数据集时设定的batch size
    pbar:Union[MasterBar, ProgressBar, NoneType]=None
) → List[Tensor]

当不设置with_loss时,该函数的返回值为一个两元素元组:第一个元素是模型在所有样本上的输出,第二个是每个样本所属的类别的索引。

4.对某一数据集计算metrics

可以使用Learnervalidate()函数,其接口定义如下:

validate(dl=None, # 数据加载器,默认使用Learner的valid_dl
    callbacks=None, # 会和Learner.callbacks进行拼接
    metrics=None # 若非空,则会替换掉Learner.metrics
)
# 示例
learn.validate(learn.data.valid_dl)
5.随机抽取图像进行结果的可视化

可使用Learner对象的show_results()函数:

show_results(ds_type=<DatasetType.Valid: 2>, rows:int=5, **kwargs)

默认显示5x5的图像阵列。

6.创建结果解析器ClassificationInterpretation(文档链接)

可使用Learnerinterpret()方法,接口如下:

interpret(learn:Learner, ds_type:DatasetType=<DatasetType.Valid: 2>, tta=False)
# 示例
interp = learn.interpret()

ClassificationInterpretation提供了很多有用的功能(具体参见fastai.train.py文件):

  • 按照损失值排序,返回排序后的损失值及索引(即返回两个数组组成的元组)。
    top_losses(k:int=None, largest=True) # 若不指定k,则对所有样本进行排序。
    
  • 将损失值最大的k张图像可视化:
    plot_top_losses(k, largest=True, figsize=(12, 12), 
        heatmap:bool=False, heatmap_thresh:int=16, # 是否显示热力图
        alpha:float=0.6, cmap:str='magma', # 热力图的透明度
        show_text:bool=True, return_fig:bool=None) → Optional[Figure]
    
  • 计算混淆矩阵:
    confusion_matrix(slice_size:int=1)
    
  • 绘制混淆矩阵:
    plot_confusion_matrix(
        normalize:bool=False, # 是否进行归一化
        title:str='Confusion matrix', cmap:Any='Blues', slice_size:int=1, norm_dec:int=2, plot_txt:bool=True, return_fig:bool=None, **kwargs) → Optional[Figure]
    
  • 找出最容易混淆的类:
    most_confused(min_val:int=1, slice_size:int=1) → Collection[Tuple[str, str, int]]
    # 其中返回的元组中的元素为:actual class, predicted class, misclassified_no
    

此外,Fast AISegmentationDetection等任务都定义了相应的Interpertation类,在具体任务时再做论述。

四、模型的保存与重载

1.save()load()
save(file:PathLikeOrBinaryStream=None, return_path:bool=False, with_opt:bool=True)
  • 如果file为一个绝对路径,那么就不会使用Learner.path作为存储路径。
  • 所保存的只是模型的参数,并不能保存模型的结构。
  • 对应的重载函数是load
load(file:PathLikeOrBinaryStream=None, device:device=None, strict:bool=True, with_opt:bool=None, purge:bool=False, remove_module:bool=False) → Learner
2.export()load_learner()
export(file:PathLikeOrBinaryStream='export.pkl', destroy=False)
  • 如果file为一个绝对路径,那么就不会使用Learner.path作为存储路径。
  • 如果设置destroy=True,那么在存储模型的时候就会释放掉Learner对象所占用的显存。
  • 对应的重载函数为load_learner
load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, tfm_y=None, **db_kwargs)

注意:与save()不同的是,export()会保存网络结构,因此在使用load_learner()进行重载的时候,不需要显式生成Learner对象。另外,如果原Learner对象如果使用了自定义的模块,那么在重新加载的时候,要首先定义这些模块。

一些有用的链接

  • 关于热力图的论文。

你可能感兴趣的:(Fast,AI文档,深度学习,计算机视觉)