现在,我们准备创建从基类派生的迁移学习类。由于所需的机制都已经到位,实现迁移学习将变得非常简单。
迁移学习基于torchvision.models模块,该模块支持下载和使用多个预训练的计算机视觉网络架构。我们将添加对以下三个模型的支持:DenseNet121、ResNet34和ResNet50。
我们可以选择使用模型的预训练版本(默认 pretrained=True),同时获得网络架构和权重;也可以只使用不带权重的架构,然后从头开始训练。
我们的新类“TransferNetworkImg”(派生自基类“Network”)的代码非常简单,你只需要关注其中的两个函数:
我们编写了一个函数来冻结权重,并利用PyTorch张量的requires_grad标志使头部在默认情况下保持不冻结(可训练)。所有张量都有这个标志;对于权重张量(可通过任何从nn.Module派生的模型的parameters()方法获得),我们根据需要将其设置为True或False。
在我们的基类中增加Freeze和Unfreeze支持
class Network(nn.Module):
    ...

    def _set_requires_grad(self, flag):
        """Assign ``requires_grad = flag`` to every parameter of ``self.model``."""
        for weight in self.model.parameters():
            weight.requires_grad = flag

    def freeze(self):
        """Freeze the wrapped model.

        Every tensor yielded by ``self.model.parameters()`` gets
        ``requires_grad = False``, so autograd stops computing gradients
        for it.
        """
        self._set_requires_grad(False)

    def unfreeze(self):
        """Unfreeze the wrapped model (``requires_grad = True`` on all parameters)."""
        self._set_requires_grad(True)
from torchvision import models
class TransferNetworkImg(Network):
    """Image classifier built on a torchvision backbone for transfer learning.

    Supported backbones (see ``set_transfer_model``, matched
    case-insensitively): ``'DenseNet'`` (densenet121), ``'ResNet34'`` and
    ``'ResNet50'``. The backbone's final layer -- ``model.classifier`` for
    DenseNet, ``model.fc`` for ResNet -- can be replaced by an ``FC`` head
    described by the ``head`` dict.
    """

    # Lower-cased backbone names whose final layer lives in ``model.fc``.
    _RESNET_NAMES = ('resnet34', 'resnet50')

    def __init__(self,
                 model_name='DenseNet',
                 lr=0.003,
                 criterion_name='NLLLoss',
                 optimizer_name='Adam',
                 dropout_p=0.2,
                 pretrained=True,
                 device=None,
                 best_accuracy=0.,
                 best_accuracy_file='best_accuracy.pth',
                 chkpoint_file='chkpoint_file',
                 head=None):
        """Create the backbone, optionally attach an FC head, and store params.

        Args:
            model_name: backbone name, see ``set_transfer_model``.
            lr, criterion_name, optimizer_name, dropout_p: forwarded to the
                ``FC`` head and to the base ``set_model_params``.
            pretrained: forwarded to the torchvision model constructor.
            device: forwarded to the base constructor and to the ``FC`` head.
            best_accuracy, best_accuracy_file, chkpoint_file: forwarded to
                the base ``set_model_params``.
            head: dict describing the ``FC`` head (keys are read in
                ``set_model_head``), or ``None`` to keep the backbone's
                original final layer. Defaults to ``None`` instead of a
                shared mutable ``{}``.
        """
        super().__init__(device=device)
        self.model_type = 'transfer'
        self.set_transfer_model(model_name, pretrained=pretrained)
        if head is not None:
            self.set_model_head(model_name=model_name,
                                head=head,
                                optimizer_name=optimizer_name,
                                criterion_name=criterion_name,
                                lr=lr,
                                dropout_p=dropout_p,
                                device=device)
        self.set_model_params(criterion_name,
                              optimizer_name,
                              lr,
                              dropout_p,
                              model_name,
                              best_accuracy,
                              best_accuracy_file,
                              chkpoint_file,
                              head)

    def set_model_params(self, criterion_name,
                         optimizer_name,
                         lr,
                         dropout_p,
                         model_name,
                         best_accuracy,
                         best_accuracy_file,
                         chkpoint_file,
                         head):
        """Run the base ``set_model_params``, then record transfer-specific state.

        ``head`` is stored so the FC head can be rebuilt later (for example
        by ``load_chkpoint``), and ``model_type`` is set to ``'transfer'``
        rather than ``'classifier'``.
        """
        print('Transfer: best accuracy = {:.3f}'.format(best_accuracy))
        super().set_model_params(
            criterion_name,
            optimizer_name,
            lr,
            dropout_p,
            model_name,
            best_accuracy,
            best_accuracy_file,
            chkpoint_file
        )
        self.head = head
        self.model_type = 'transfer'

    def forward(self, x):
        """Run *x* through the wrapped torchvision model."""
        return self.model(x)

    def get_model_params(self):
        """Return the base params extended with head, model_type and device."""
        params = super().get_model_params()
        params['head'] = self.head
        params['model_type'] = self.model_type
        params['device'] = self.device
        return params

    def _head_attr(self, model_name):
        """Return the attribute name of the backbone's final layer, or None if unsupported."""
        name = model_name.lower()
        if name == 'densenet':
            return 'classifier'
        if name in self._RESNET_NAMES:
            return 'fc'
        return None

    def _get_head(self):
        """Return the module acting as the head (classifier / fc) of ``self.model``."""
        attr = self._head_attr(self.model_name)
        if attr is None:
            raise ValueError('Model {} not supported'.format(self.model_name))
        return getattr(self.model, attr)

    def freeze(self, train_classifier=True):
        """Freeze all backbone parameters, optionally keeping the head trainable.

        The base ``freeze`` sets ``requires_grad = False`` on every parameter.
        If ``train_classifier`` is true, the head's parameters are then set
        back to ``requires_grad = True``. The head is ``model.classifier``
        for DenseNet and ``model.fc`` for ResNet (the original code always
        used ``classifier`` and failed on ResNet). Only classification heads
        are considered; regression would need extra handling.
        """
        super().freeze()
        if train_classifier:
            for param in self._get_head().parameters():
                param.requires_grad = True

    def set_transfer_model(self, mname, pretrained=True):
        """Instantiate the torchvision backbone named *mname* into ``self.model``.

        Unsupported names leave ``self.model`` as ``None`` and print a message.
        NOTE(review): newer torchvision versions deprecate ``pretrained`` in
        favour of ``weights``; it is kept here for compatibility.
        """
        constructors = {
            'densenet': models.densenet121,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
        }
        builder = constructors.get(mname.lower())
        self.model = builder(pretrained=pretrained) if builder is not None else None
        if self.model is not None:
            print('set_transfer_model: self.Model set to {}'.format(mname))
        else:
            print('set_transfer_model:Model {} not supported'.format(mname))

    def set_model_head(self,
                       model_name='DenseNet',
                       head=None,
                       optimizer_name='Adam',
                       criterion_name='NLLLoss',
                       lr=0.003,
                       dropout_p=0.2,
                       device=None):
        """Replace the backbone's final layer with an ``FC`` head built from *head*.

        ``head`` must provide ``num_outputs``, ``layers``, ``class_names``,
        ``non_linearity``, ``model_type`` and ``model_name``. The FC input
        size comes from the layer being replaced. A fresh torchvision layer
        exposes ``in_features``. A head that was already replaced by our
        ``FC`` (e.g. when restoring a checkpoint) exposes ``num_inputs``
        instead.

        Raises:
            ValueError: if *model_name* is not a supported backbone.
        """
        if head is None:
            # Fresh dict per call (no shared mutable default).
            # NOTE(review): this fallback lacks 'non_linearity', 'model_type'
            # and 'model_name', so it still fails below, as the original did.
            head = {'num_inputs': 128,
                    'num_outputs': 10,
                    'layers': [],
                    'class_names': {}}
        self.num_outputs = head['num_outputs']
        attr = self._head_attr(model_name)
        if attr is None:
            raise ValueError('set_model_head: Model {} not supported'.format(model_name))
        current = getattr(self.model, attr)
        # Inspect the layer itself, not the model: the model always has the
        # attribute, but only an unmodified torchvision layer has in_features.
        if hasattr(current, 'in_features'):
            in_features = current.in_features
        else:
            in_features = current.num_inputs
        new_head = FC(num_inputs=in_features,
                      num_outputs=head['num_outputs'],
                      layers=head['layers'],
                      class_names=head['class_names'],
                      non_linearity=head['non_linearity'],
                      model_type=head['model_type'],
                      model_name=head['model_name'],
                      dropout_p=dropout_p,
                      optimizer_name=optimizer_name,
                      lr=lr,
                      criterion_name=criterion_name,
                      device=device)
        # nn.Module.__setattr__ registers the new head as a submodule.
        setattr(self.model, attr, new_head)
        self.head = head
        print('{}: setting head: inputs: {} hidden:{} outputs: {}'.format(model_name,
                                                                          in_features,
                                                                          head['layers'],
                                                                          head['num_outputs']))

    def _get_dropout(self):
        """Return the head's dropout probability, or None for unsupported models."""
        attr = self._head_attr(self.model_name)
        if attr is None:
            return None
        return getattr(self.model, attr)._get_dropout()

    def _set_dropout(self, p=0.2):
        """Set the FC head's dropout probability to *p* (no-op for unsupported models)."""
        attr = self._head_attr(self.model_name)
        if attr is None:
            return
        head_module = getattr(self.model, attr)
        if head_module is None:
            return
        family = 'DenseNet' if attr == 'classifier' else 'ResNet'
        print('{}: setting head (FC) dropout prob to {:.3f}'.format(family, p))
        head_module._set_dropout(p=p)
向load_chkpoint实用程序添加对迁移学习模型的支持
def load_chkpoint(chkpoint_file):
    """Rebuild a network from a checkpoint file.

    The checkpoint must contain a ``'params'`` dict. Its ``model_type``
    selects the class: ``'classifier'`` builds an ``FC`` and ``'transfer'``
    builds a ``TransferNetworkImg``. The weights are then loaded from
    ``params['best_accuracy_file']``, and the network is moved to
    ``params['device']``.

    Args:
        chkpoint_file: path passed to ``torch.load``.

    Returns:
        The reconstructed network with its best-accuracy weights loaded.

    Raises:
        ValueError: if ``params['model_type']`` is not supported.
    """
    # NOTE: torch.load unpickles data -- only load checkpoints from trusted sources.
    restored_data = torch.load(chkpoint_file)
    params = restored_data['params']
    print('load_chkpoint: best accuracy = {:.3f}'.format(params['best_accuracy']))
    model_type = params['model_type'].lower()
    if model_type == 'classifier':
        net = FC(num_inputs=params['num_inputs'],
                 num_outputs=params['num_outputs'],
                 layers=params['layers'],
                 device=params['device'],
                 criterion_name=params['criterion_name'],
                 optimizer_name=params['optimizer_name'],
                 model_name=params['model_name'],
                 lr=params['lr'],
                 dropout_p=params['dropout_p'],
                 best_accuracy=params['best_accuracy'],
                 best_accuracy_file=params['best_accuracy_file'],
                 chkpoint_file=params['chkpoint_file'],
                 class_names=params['class_names'])
    elif model_type == 'transfer':
        # pretrained=False: every weight is overwritten by load_state_dict
        # below, so downloading the ImageNet weights would be wasted work.
        net = TransferNetworkImg(criterion_name=params['criterion_name'],
                                 optimizer_name=params['optimizer_name'],
                                 model_name=params['model_name'],
                                 lr=params['lr'],
                                 device=params['device'],
                                 dropout_p=params['dropout_p'],
                                 best_accuracy=params['best_accuracy'],
                                 best_accuracy_file=params['best_accuracy_file'],
                                 chkpoint_file=params['chkpoint_file'],
                                 head=params['head'],
                                 pretrained=False)
    else:
        # Previously this fell through to an UnboundLocalError on ``net``.
        raise ValueError('load_chkpoint: unsupported model_type {!r}'.format(params['model_type']))
    net.load_state_dict(torch.load(params['best_accuracy_file']))
    net.to(params['device'])
    return net