6. Pytorch入门教程——创建一个派生自基类的迁移学习类

现在,我们准备创建从基类派生的迁移学习类。实现迁移学习现在会变得非常简单,因为所有需要的机制都已经到位了。

迁移学习给予torchvision.models模块,包含支持下载和使用几个预先训练的计算机视觉网络架构。我们将添加三个模型的支持:

  • Densenet121,简称Densenet
  • Resent34和ResNet50

我们可以选择使用模型的预训练版本(默认 pretrained = True),获得架构+权重,或者仅仅是没有权重的架构,然后从头开始训练它们。

  • 大多数预训练版本在torchvision.models上,已经在ImageNet上训练了1000个输出类;
  • 我们希望使所选的模型适用我们的用例。例如,对于CIFAR10,我们只需要10个类,因此我们的输出应该设置为10而不是1000;
  • 每个模型由两部分组成:
  1. 卷积网络模型(一个CNN架构,包含不同数量的filters, non-linearities, max或average pooling layers, batch normalizations, dropout层等的含几个块);
  2. 在输出端有全连接的分类器的头。
  • 在大多数情况下,输出层没有任何全连接的隐藏层;
  • 但是,我们可以选择用我们自己的分类器层替换分类器层,并通过用我们自己的输出层替换输出层来添加更多的隐藏层。我们可以很容易地使用我们自己的FC类(在本教程前面定义)来实现这个目的;
  • 另一方面,我们可以选择只改变输出的数量,而不添加任何额外的隐藏层;
  • 今后,我们将使用我们自己的FC类,并将原始模型的输出层替换为FC对象。这将使我们能够灵活地传递任何额外的隐藏层。

我们的新类“TransferNetworkImg”(从基类“Network”派生而来)的代码非常简单。你只需要注意两个功能:

  • set_transfer_model设置来自torchvision.models的迁移模型;
  • set_model_head,在删除原始分类器或FC层后,在模型上设置FC层。

一、设置分类器

  • 注意,每个模型头部的分类器在每个torchvision模型中命名不同。不过,还有更好的方法来处理它,比如在文件中使用预定义的字典并加载它,然后查找每个模型类型的classifier字段以获得输出层的名称;
  • 但是,在下面的代码中,我们只是使用了简单的if else语句。也可以通过创建这样一个字典来创建您自己版本的calss;
  • 在set_model_head中设置FC模型。因为我们需要调用FC构造函数,所以我们需要传递成功创建FC类对象所需的任何东西。我们通过传递“head”字典到我们的迁移学习类来做到这一点;
  • 为了成功地创建一个FC模型,我们需要向其构造函数传递固定数量的输入,因为这是用于FC网络nn.Linear的要求。幸运的是,Pytorch中的nn.Linear类将它的输入数量存储在一个名为“in_features”的属性中。我们可以从迁移的模型(Densenet、Resnet等)中的原始分类器层中获取它,并将其作为参数传递给我们的FC构造函数。

二、Freezing和un-freezing层

  • 在使用迁移学习模型时,重要的是要决定是否从头开始对所有层(包括卷积和完全连接)进行重新训练。对于像CIFAR10这样相当大的数据集,对整个网络进行再训练是有意义的;
  • 但是,请注意,重新训练所有层并不意味着我们将从随机权重开始。我们仍然会从每一层的预训练权值开始,然后继续,但是我们会计算所有层的梯度并更新所有权值。因此,换句话说,模型开始学习的同时,保留它在以前的数据集(大多数情况下是ImageNet)上训练时获得的识别图像的知识;
  • 所以这就像一个孩子,在一个特定的事情上训练,我们不希望扔掉所有的知识,继续学习新的数据;
  • 另一方面,我们可能想要保持主干网络中冻结的权重,而只再训练头。这是一种常见的场景,当我们在新的数据上对网络进行了一段时间的训练,现在我们的主干既知道ImageNet,也知道我们的新数据集(在我们的示例中是CIFAR10);
  • 在最后一种情况下,我们可能只想做预测,并希望包括主干和头部所有的权重保持冻结状态。这只适用于预测和评估,而不适用于训练,因为如果我们不做反向传播和更新,那么训练就没有意义。

我们编写了一个函数来冻结权值,同时使用Pytorch张量的requires_grad标志使头部在默认情况下保持不冻结。这个标志在所有张量中都可用,对于权重张量,我们希望将其设置为真或假(可以通过从nn.Module派生的任何模型的parameters()方法获得)。

在我们的基类中增加Freeze和Unfreeze支持

class Network(nn.Module):
    ...
    '''
    增加模型所有层中的freeze and unfreeze参数
    '''
    def freeze(self):
        for param in self.model.parameters():
            param.requires_grad = False
        
        
    def unfreeze(self):
        for param in self.model.parameters():
            param.requires_grad = True

三、迁移学习类

from torchvision import models

class TransferNetworkImg(Network):
    def __init__(self,
                 model_name='DenseNet',
                 lr=0.003,
                 criterion_name ='NLLLoss',
                 optimizer_name = 'Adam',
                 dropout_p=0.2,
                 pretrained=True,
                 device=None,
                 best_accuracy=0.,
                 best_accuracy_file ='best_accuracy.pth',
                 chkpoint_file ='chkpoint_file',
                 head={}):

        
        super().__init__(device=device)
        
        self.model_type = 'transfer'
        
        self.set_transfer_model(model_name,pretrained=pretrained)    
        
        if head is not None:
            self.set_model_head(model_name = model_name,
                                 head = head,
                                 optimizer_name = optimizer_name,
                                 criterion_name = criterion_name,
                                 lr = lr,
                                 dropout_p = dropout_p,
                                 device = device
                                )
            
        self.set_model_params(criterion_name,
                              optimizer_name,
                              lr,
                              dropout_p,
                              model_name,
                              best_accuracy,
                              best_accuracy_file,
                              chkpoint_file,
                              head)
            
    
    '''
    set_model_params调用之前的父类,设置这个类另外的属性
    (head和model_type设置为'transfer')
    '''
    def set_model_params(self,criterion_name,
                         optimizer_name,
                         lr,
                         dropout_p,
                         model_name,
                         best_accuracy,
                         best_accuracy_file,
                         chkpoint_file,
                         head):
        
        print('Transfer: best accuracy = {:.3f}'.format(best_accuracy))
        
        super(TransferNetworkImg, self).set_model_params(
                                              criterion_name,
                                              optimizer_name,
                                              lr,
                                              dropout_p,
                                              model_name,
                                              best_accuracy,
                                              best_accuracy_file,
                                              chkpoint_file
                                              )

        '''
        我们还设置了模型的头端,以便在需要时创建FC层
        '''
        self.head = head
        
        '''
        model_type是transfer而不是classifier
        '''
        self.model_type = 'transfer'
        

    def forward(self,x):
        return self.model(x)
        
    def get_model_params(self):
        params = super(TransferNetworkImg, self).get_model_params()
        params['head'] = self.head
        params['model_type'] = self.model_type
        params['device'] = self.device
        return params
    
    '''
    Freeze首先调用基类的freeze()方法来冻结模型的所有参数,我们已经将这个方法添加到网络类中(参见下面的内容),然后根据传递的标志解除冻结头部的(classifier属性)参     数。注意,我们调用头部作为分类器。如果我们想要在将来处理回归的情况,还得添加更多的代码。我们已经向基类适当地添加了两个方法,freeze和unfreeze
    '''
    def freeze(self,train_classifier=True):
        super(TransferNetworkImg, self).freeze()
        if train_classifier:
            for param in self.model.classifier.parameters():
                param.requires_grad = True
            
                
    def set_transfer_model(self,mname,pretrained=True):   
        self.model = None
        if mname.lower() == 'densenet':
            self.model = models.densenet121(pretrained=pretrained)
            
        elif mname.lower() == 'resnet34':
            self.model = models.resnet34(pretrained=pretrained)
            
        elif mname.lower() == 'resnet50':
            self.model = models.resnet50(pretrained=pretrained)
              
        if self.model is not None:
            print('set_transfer_model: self.Model set to {}'.format(mname))
        else:
            print('set_transfer_model:Model {} not supported'.format(mname))
            
    
    '''
    set_model_head使用head字典调用FC,它从原始模型的分类器或fc层的适当属性中获取in_features。
    
    我们需要检查模型是否保存并从检查点加载,因为在后一种情况下,模型的头端对象应该具有num_input属性,而不是原始的in_features,因为它将包含我们的FC对象,而该对象
    使用num_input替代in_features。
    '''
    def set_model_head(self,
                        model_name = 'DenseNet',
                        head = {'num_inputs':128,
                                'num_outputs':10,
                                'layers':[],
                                'class_names':{}
                               },
                         optimizer_name = 'Adam',
                         criterion_name = 'NLLLoss',
                         lr = 0.003,
                         dropout_p = 0.2,
                         device = None):
        
        self.num_outputs = head['num_outputs']
        
        if model_name.lower() == 'densenet':
            if hasattr(self.model,'classifier'):
                in_features =  self.model.classifier.in_features
            else:
                in_features = self.model.classifier.num_inputs
                
            self.model.classifier = FC(num_inputs=in_features,
                                       num_outputs=head['num_outputs'],
                                       layers = head['layers'],
                                       class_names = head['class_names'],
                                       non_linearity = head['non_linearity'],
                                       model_type = head['model_type'],
                                       model_name = head['model_name'],
                                       dropout_p = dropout_p,
                                       optimizer_name = optimizer_name,
                                       lr = lr,
                                       criterion_name = criterion_name,
                                       device=device
                                      )
            
        elif model_name.lower() == 'resnet50' or model_name.lower() == 'resnet34':
            if hasattr(self.model,'fc'):
                in_features =  self.model.fc.in_features
            else:
                in_features = self.model.fc.num_inputs
                
            self.model.fc = FC(num_inputs=in_features,
                               num_outputs=head['num_outputs'],
                               layers = head['layers'],
                               class_names = head['class_names'],
                               non_linearity = head['non_linearity'],
                               model_type = head['model_type'],
                               model_name = head['model_name'],
                               dropout_p = dropout_p,
                               optimizer_name = optimizer_name,
                               lr = lr,
                               criterion_name = self.criterion_name,
                               device=device
                              )
         
        self.head = head
        
        print('{}: setting head: inputs: {} hidden:{} outputs: {}'.format(model_name,
                                                                   in_features,
                                                                   head['layers'],
                                                                   head['num_outputs']))
    
    def _get_dropout(self):
        if self.model_name.lower() == 'densenet':
            return self.model.classifier._get_dropout()
        
        elif self.model_name.lower() == 'resnet50' or self.model_name.lower() == 'resnet34':
            return self.model.fc._get_dropout()
        
            
    def _set_dropout(self,p=0.2):
        
        if self.model_name.lower() == 'densenet':
            if self.model.classifier is not None:
                print('DenseNet: setting head (FC) dropout prob to {:.3f}'.format(p))
                self.model.classifier._set_dropout(p=p)
                
        elif self.model_name.lower() == 'resnet50' or self.model_name.lower() == 'resnet34':
            if self.model.fc is not None:
                print('ResNet: setting head (FC) dropout prob to {:.3f}'.format(p))
                self.model.fc._set_dropout(p=p)

向load_chkpoint实用程序添加对迁移学习模型的支持

  • 我们需要在load_chkpoint函数中添加TransferNetworkImg;
  • 主要的新增功能是头和其他参数的存储和检索,还增加了对将检索到的头传递到构造函数的支持。
def load_chkpoint(chkpoint_file):
        
    restored_data = torch.load(chkpoint_file)

    params = restored_data['params']
    print('load_chkpoint: best accuracy = {:.3f}'.format(params['best_accuracy']))  
    
    if params['model_type'].lower() == 'classifier':
        net = FC( num_inputs=params['num_inputs'],
                  num_outputs=params['num_outputs'],
                  layers=params['layers'],
                  device=params['device'],
                  criterion_name = params['criterion_name'],
                  optimizer_name = params['optimizer_name'],
                  model_name = params['model_name'],
                  lr = params['lr'],
                  dropout_p = params['dropout_p'],
                  best_accuracy = params['best_accuracy'],
                  best_accuracy_file = params['best_accuracy_file'],
                  chkpoint_file = params['chkpoint_file'],
                  class_names =  params['class_names']
          )
    elif params['model_type'].lower() == 'transfer':
        net = TransferNetworkImg(criterion_name = params['criterion_name'],
                                 optimizer_name = params['optimizer_name'],
                                 model_name = params['model_name'],
                                 lr = params['lr'],
                                 device=params['device'],
                                 dropout_p = params['dropout_p'],
                                 best_accuracy = params['best_accuracy'],
                                 best_accuracy_file = params['best_accuracy_file'],
                                 chkpoint_file = params['chkpoint_file'],
                                 head = params['head']
                               )

    net.load_state_dict(torch.load(params['best_accuracy_file']))

    net.to(params['device'])
    
    return net

你可能感兴趣的:(Pytorch入门教程,神经网络,机器学习,pytorch,深度学习)