FC中的BN(伪代码)

'''
全连接层中的batch normalization
'''
import torch
import torch.nn as nn
import copy
class Net(nn.Module):
    """Minimal model wrapping a single ``nn.BatchNorm1d`` layer for FC features.

    Args:
        dim: number of features, forwarded as ``num_features`` of the layer.
        pretrained: if truthy, explicitly reset the affine parameters
            (weight=1, bias=0) via :meth:`pretrained`.
        eps: value added to the variance for numerical stability. Defaults to
            ``1e-5``, the same constant the hand-written BN functions in this
            file use.
        momentum: running-statistics momentum, forwarded to ``BatchNorm1d``.
    """

    def __init__(self, dim, pretrained, eps=1e-5, momentum=0.1):
        super().__init__()
        # Pass eps/momentum by keyword: the second positional parameter of
        # BatchNorm1d is ``eps``, so the former ``BatchNorm1d(dim, 1)``
        # silently set eps=1 instead of the intended default.
        self.bn = nn.BatchNorm1d(dim, eps=eps, momentum=momentum)
        if pretrained:
            self.pretrained()

    def forward(self, input):
        """Apply batch normalization to ``input`` (used here as (batch, dim))."""
        return self.bn(input)

    def pretrained(self):
        """Set the affine scale to 1 and the shift to 0, in place."""
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

def bn_train_fc(input, model, eps=1e-5, verbose=True):
    """Reproduce training-mode BatchNorm on a fully-connected input by hand.

    Each feature (column) of ``input`` is normalized with the mean and the
    *biased* variance of the current batch (statistics over dim 0). The
    affine parameters are then read from ``model.state_dict()`` and applied.

    Args:
        input: tensor whose dim 0 is the batch, e.g. (batch, num_features).
        model: module whose state_dict has ``bn.weight`` and ``bn.bias``.
        eps: stability constant added to the variance (BatchNorm default 1e-5).
        verbose: if true, print the state dict and the batch statistics.

    Returns:
        Tensor with the same shape as ``input``.

    Raises:
        KeyError: if ``model``'s state_dict lacks ``bn.weight``/``bn.bias``.
    """
    state_dict = model.state_dict()
    if verbose:
        print(state_dict)
    weight = state_dict['bn.weight']
    bias = state_dict['bn.bias']

    # Batch statistics over the batch dimension. Training-mode BatchNorm
    # normalizes with the biased variance (divide by N); torch.var defaults
    # to the unbiased N-1 estimator, which would not match nn.BatchNorm1d
    # and would give NaN for a single-sample batch.
    batch_mean = torch.mean(input, dim=0, keepdim=True)
    batch_var = torch.var(input, dim=0, unbiased=False, keepdim=True)

    # Statistics (1, C) and affine params (C,) broadcast against (N, C).
    output = (input - batch_mean) / (batch_var + eps).sqrt()
    output = output * weight + bias

    if verbose:
        print(batch_mean, batch_var)

    return output

def bn_test_fc(input, model, eps=1e-5):
    """Reproduce eval-mode (inference) BatchNorm on an FC input by hand.

    The input is normalized with the stored running statistics
    (``bn.running_mean`` / ``bn.running_var``) instead of batch statistics.
    The affine parameters ``bn.weight`` / ``bn.bias`` are then applied. All
    four tensors are read from ``model.state_dict()``.

    Args:
        input: tensor whose last dim is the feature dim, e.g. (batch, C).
        model: module whose state_dict has the four ``bn.*`` entries above.
        eps: stability constant added to the variance (BatchNorm default 1e-5).

    Returns:
        Tensor with the same shape as ``input``.

    Raises:
        KeyError: if any of the required ``bn.*`` entries is missing.
    """
    state_dict = model.state_dict()
    weight = state_dict['bn.weight']
    bias = state_dict['bn.bias']
    running_mean = state_dict['bn.running_mean']
    running_var = state_dict['bn.running_var']

    # Per-feature (C,) tensors broadcast against the (N, C) input.
    output = (input - running_mean) / (running_var + eps).sqrt()
    return output * weight + bias

if __name__ == '__main__':
    model = Net(dim=5, pretrained=False)
    # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
    x = torch.randn((3, 5))

    # A freshly constructed module is in training mode, so this forward pass
    # normalizes with batch statistics and updates the running statistics.
    output = model(x)
    print(model.state_dict()['bn.running_mean'])
    print(model.state_dict()['bn.running_var'])

    # The input of a fully-connected layer is [batch_size, num_dims].
    # BatchNorm is applied to each node (feature) separately: the mean and
    # variance are computed over that node's batch_size values.
    # Think of each FC node as the counterpart of one CNN convolution kernel
    # (i.e. one output channel of the CNN).
    # So FC BatchNorm works per node over the batch dimension, while CNN
    # BatchNorm works per kernel/channel over the batch, H and W dimensions.
    output2 = bn_train_fc(x, model)
    print(output, output2)
    print('train-mode match:', torch.allclose(output, output2, atol=1e-5))

    # Eval mode: the layer now uses the stored running statistics.
    model.eval()
    output3 = model(x)
    output4 = bn_test_fc(x, model)
    print('eval-mode match:', torch.allclose(output3, output4, atol=1e-5))



你可能感兴趣的:(pytorch)