【pytorch】如何更新模型父类参数

更新方法例子如下。注意:`Module.parameters()` 会递归返回所有已注册子模块(包括父类 `__init__` 中注册的层)的参数,因此无需额外把子模块的参数再加入优化器;未参与 `forward` 计算的参数不会产生梯度,也就不会被更新:

import torch
import torch.nn as nn
class A(nn.Module):
    """Base module: a single 2 -> 2 linear layer."""

    def __init__(self):
        super(A, self).__init__()
        self.l1 = nn.Linear(2, 2)

    def forward(self, x):
        # Pass the input straight through the only layer.
        out = self.l1(x)
        return out

class B(A):
    """Subclass of A that also *composes* a separate A instance.

    NOTE(review): ``super().__init__()`` registers an inherited ``l1``
    layer, but ``forward`` never touches it — only the nested ``self.a``
    and ``self.l2`` take part in the computation, which is exactly what
    the surrounding demo illustrates.
    """

    def __init__(self):
        super(B, self).__init__()
        self.a = A()               # composed copy of A, used in forward
        self.l2 = nn.Linear(2, 1)  # 2 -> 1 head

    def forward(self, x):
        hidden = self.a(x)
        return self.l2(hidden)

class C(A):
    """Subclass of A that overrides the inherited ``l1`` and adds ``l2``."""

    def __init__(self):
        super(C, self).__init__()
        # Rebinding l1 replaces the layer registered by A.__init__.
        self.l1 = nn.Linear(2, 2)
        self.l2 = nn.Linear(2, 1)

    def forward(self, x):
        hidden = self.l1(x)
        return self.l2(hidden)

from torch.optim import Adam
from torch.nn import MSELoss

# Toy regression sample: 2-dim input, scalar target.
x = torch.FloatTensor([1, 2])
y = torch.FloatTensor([0.5])

b = B()
c = C()

# b.parameters() already yields every registered parameter *recursively*,
# including those of the composed submodule b.a.  Extending the list with
# b.a.parameters() a second time (as the original did) hands Adam duplicate
# tensor references, so those tensors would be stepped twice per
# optimizer.step() and PyTorch warns about the duplicated param group.
optimizer1 = Adam(b.parameters(), lr=0.1)
optimizer2 = Adam(c.parameters(), lr=0.1)
loss_func = MSELoss()

print('======= b origin  ========')
print(b.a.l1.weight, '\n', b.l1.weight, '\n', b.l2.weight)

logit = b(x)
loss = loss_func(logit, y)
loss.backward()
optimizer1.step()

# Only b.a.l1 and b.l2 change: the inherited b.l1 never appears in
# forward, so it gets no gradient and Adam leaves it untouched.
print('======= b after  ========')
print(b.a.l1.weight, '\n', b.l1.weight, '\n', b.l2.weight)

输出:

======= b origin  ========
Parameter containing:
tensor([[ 0.0655,  0.6051],
        [-0.4282, -0.2897]], requires_grad=True) 
 Parameter containing:
tensor([[ 0.6422, -0.1793],
        [-0.6748,  0.3853]], requires_grad=True) 
 Parameter containing:
tensor([[-0.0885,  0.1957]], requires_grad=True)
======= b after  ========
Parameter containing:
tensor([[-0.1345,  0.4051],
        [-0.2282, -0.0897]], requires_grad=True) 
 Parameter containing:
tensor([[ 0.6422, -0.1793],
        [-0.6748,  0.3853]], requires_grad=True) 
 Parameter containing:
tensor([[0.0115, 0.0957]], requires_grad=True)

你可能感兴趣的:(【pytorch】如何更新模型父类参数)