pytorch代码规范:加载预训练模型

1 加载预训练模型,并去除需要再次训练的层

model=resnet()#自己构建的模型,以resnet为例, 需要重新训练的层的名字要和之前的不同。
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

2 固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体

#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:

for k,v in model.named_parameters():
if k!='xxx.weight' and k!='xxx.bias' :
v.requires_grad=False#固定参数

3 训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

4 检查是否固定

for k,v in model.named_parameters():
if k!='xxx.weight' and k!='xxx.bias' :
print(v.requires_grad)#理想状态下,所有值都是False

你可能感兴趣的:(代码,pytorch)