代码示例:
import torch
import torch.nn as nn
class a1(torch.nn.Module):
    """Base network holding a single linear layer mapping 3 features to 2."""

    def __init__(self):
        # Python-3 zero-argument super(); equivalent to super(a1, self).__init__().
        super().__init__()
        # First linear layer: 3 input features -> 2 output features.
        self.l1 = nn.Linear(3, 2)
class aa(a1):
    """Three-layer MLP extending a1: l1 (3->2) -> l2 (2->2) -> l3 (2->1)."""

    def __init__(self):
        super().__init__()
        # Plain (non-module) attributes, kept exactly as in the base example.
        self.a = ''
        self.b = 'c'
        # Additional linear layers stacked on top of the inherited l1.
        self.l2 = nn.Linear(2, 2)
        self.l3 = nn.Linear(2, 1)

    def forward(self, x):
        """Run x through l1, l2 and l3 in sequence and return the result."""
        h = self.l1(x)
        h = self.l2(h)
        return self.l3(h)
# --- Demo: freeze layers l2 and l3, then train only l1 for one step. ---
a = aa()
freeze_layers = ['l2', 'l3']

# Collect parameters of non-frozen child modules for the optimizer and
# switch off gradient tracking for the frozen ones.  named_children() is
# the public API (the original iterated the private `a._modules` dict).
opt_param = []
for name, module in a.named_children():
    if name in freeze_layers:
        # Frozen: no gradients will be computed for these parameters.
        for p in module.parameters():
            p.requires_grad = False
    else:
        opt_param.extend(module.parameters())

print('original parameters:\n', list(a.parameters()))

# Toy regression data: 4 samples with 3 features each, targets shaped (4, 1).
x = torch.tensor([[1, 2, 3], [4, 5, 6], [-1, -2, -3], [-2, -4, -5]],
                 dtype=torch.float)
y = torch.tensor([1, 1, -1, -1], dtype=torch.float)
y = y.view(4, -1)

y_ = a(x)
print('y_', y_)

# NOTE: the name `celoss` is historical -- this is an MSE criterion.
# nn.MSELoss expects (input, target); the original passed (target, input),
# which yields the same value for MSE but violates the convention.
celoss = nn.MSELoss()
loss = celoss(y_, y)

opt = torch.optim.Adam(opt_param)
opt.zero_grad()  # clear any stale gradients before the backward pass
loss.backward()
opt.step()

# Only l1's parameters should have changed; l2/l3 stay frozen.
print('new parameters:\n', list(a.parameters()))
输出:
original parameters:
[Parameter containing:
tensor([[-0.0965, -0.3446, 0.0866],
[-0.1677, -0.2664, 0.5007]], requires_grad=True), Parameter containing:
tensor([ 0.2607, -0.3101], requires_grad=True), Parameter containing:
tensor([[0.5237, 0.1328],
[0.6994, 0.1014]]), Parameter containing:
tensor([ 0.2662, -0.1830]), Parameter containing:
tensor([[-0.0303, 0.1869]]), Parameter containing:
tensor([0.1692])]
y_ tensor([[ 0.1037],
[-0.0155],
[ 0.2007],
[ 0.2665]], grad_fn=<AddmmBackward>)
new parameters:
[Parameter containing:
tensor([[-0.0955, -0.3436, 0.0876],
[-0.1667, -0.2654, 0.5017]], requires_grad=True), Parameter containing:
tensor([ 0.2597, -0.3111], requires_grad=True), Parameter containing:
tensor([[0.5237, 0.1328],
[0.6994, 0.1014]]), Parameter containing:
tensor([ 0.2662, -0.1830]), Parameter containing:
tensor([[-0.0303, 0.1869]]), Parameter containing:
tensor([0.1692])]
可以看到l2和l3两层网络被冻结了