pytorch模型

假设有一个模型为conv + bn + relu :

class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )

来看一下下面的代码输出:

    model = ConvBNReLU(3, 32)  # 卷积 + bn + relu
    i = 0
    for key, weight in model.named_parameters():  #
        print(key, weight.shape) 
        # weight.requires_grad = False  #可以用来设置是否更新参数
        i += 1
    print('数量',i)
    '''
    输出:
    0.weight torch.Size([32, 3, 3, 3])
    1.weight torch.Size([32])
    1.bias torch.Size([32])
    数量 3
    '''


    state = model.state_dict()  
    j = 0
    for key in state:
        print(key,state[key].shape)
        j += 1
    print('数量',j)
    '''
    输出:
    0.weight torch.Size([32, 3, 3, 3])
    1.weight torch.Size([32])
    1.bias torch.Size([32])
    1.running_mean torch.Size([32])
    1.running_var torch.Size([32])
    1.num_batches_tracked torch.Size([])
    数量 6
    '''
    
    k = 0
    for content in list(model.parameters()):  #
        print(content.shape)
        k += 1
    print('数量', k)
    '''
    torch.Size([32, 3, 3, 3])
    torch.Size([32])
    torch.Size([32])
    数量 3
    '''

可以看到:

  • model.named_parameters():是对应的字典形式, key是参数名称,储存网络需要反向传播的参数
  • model.state_dict():储存网络整体参数,包括:需要反向传播训练的参数、仅仅需要向前向前传播的参数。这里仅仅需要向前传播的参数的参数主要是bn层的均值方差等:1.running_mean torch.Size([32])、1.running_var torch.Size([32])、1.num_batches_tracked torch.Size([])
  • model.parameters():储存网络需要训练(反向传播)的参数,一般会在定义optimizer的时候用到。可以看到它与model.named_parameters()所表示的是一样的。没有bn层的均值、方差等参数,因为这些参数只需要向前传播,无需反向传播。

补充一下,上面的参数中:

1.running_mean torch.Size([32])、1.running_var torch.Size([32])、1.num_batches_tracked torch.Size([])表示的是bn层均值方差,只需要向前传播

1.weight torch.Size([32]),1.bias torch.Size([32])表示的是bn层的scale和bias参数,这个是需要训练的

你可能感兴趣的:(pytorch,pytorch,python,深度学习)