【LayerNorm 2d】 LayerNorm Illustrated, PyTorch Implementation, and Usage

1. LayerNorm vs. BatchNorm

[Figure 1: comparison of the normalization dimensions of LayerNorm and BatchNorm]
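Since the figure itself is not reproduced here, below is a minimal sketch of the difference in normalization axes (the tensor shapes and values are illustrative): BatchNorm computes one mean/variance per channel across the batch and spatial dimensions, while the channels-first LayerNorm implemented later computes one mean/variance per sample and spatial position across the channels.

import torch

x = torch.randn(4, 8, 16, 16)  # N, C, H, W

# BatchNorm: statistics per channel, computed over (N, H, W)
bn_mean = x.mean(dim=(0, 2, 3), keepdim=True)                # (1, C, 1, 1)
bn_var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_bn = (x - bn_mean) / torch.sqrt(bn_var + 1e-6)

# LayerNorm (channels_first): statistics per sample and position, over C
ln_mean = x.mean(dim=1, keepdim=True)                        # (N, 1, H, W)
ln_var = x.var(dim=1, unbiased=False, keepdim=True)
x_ln = (x - ln_mean) / torch.sqrt(ln_var + 1e-6)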

2. LayerNorm2d

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2d(nn.Module):
    def __init__(self,
                 embed_dim,
                 eps=1e-6,
                 data_format="channels_last") -> None:
        super().__init__()
        self.embed_dim = embed_dim
        # Learnable per-channel affine parameters (gamma and beta)
        self.weight = nn.Parameter(torch.ones(embed_dim))
        self.bias = nn.Parameter(torch.zeros(embed_dim))

        self.eps = eps
        self.data_format = data_format

        assert self.data_format in ["channels_last", "channels_first"]
        self.normalized_shape = (embed_dim, )

    def forward(self, x):
        if self.data_format == "channels_last":  # x: (N, H, W, C)
            # F.layer_norm expects the shape tuple of the normalized
            # dimensions, not the int embed_dim
            return F.layer_norm(x, self.normalized_shape, self.weight,
                                self.bias, self.eps)
        elif self.data_format == "channels_first":  # x: (N, C, H, W)
            # Normalize over the channel dimension by hand
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)  # biased variance
            x = (x - u) / torch.sqrt(s + self.eps)
            # Broadcast the per-channel affine parameters to (C, 1, 1)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
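A quick sanity check (a sketch, not part of the original post): with the default affine parameters, the channels_first path should match nn.LayerNorm applied over C after permuting channels to the last dimension.

torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4)  # N, C, H, W

ln2d = LayerNorm2d(8, data_format="channels_first")
ref = nn.LayerNorm(8, eps=1e-6)

out = ln2d(x)
out_ref = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out, out_ref, atol=1e-5))  # True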

3. Usage

            if self.use_layer_norm:
                N, C, H, W = x.shape
                # (N, C, H, W) -> (N, C, H*W) -> (N, H*W, C)
                x = x.flatten(2).transpose(1, 2)
                hw_shape = (H, W)
                x = self.norm(x)  # e.g. nn.LayerNorm(C) over the last dim
                # (N, H*W, C) -> (N, H, W, C) -> (N, C, H, W)
                x = x.reshape(N, *hw_shape, C).permute(0, 3, 1, 2).contiguous()
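For context, here is a self-contained sketch of the same pattern inside a small module (ConvBlock and its layers are illustrative, not from the original post): flatten the spatial dimensions, apply a standard nn.LayerNorm over the channel dimension, then restore the (N, C, H, W) layout.

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, channels, use_layer_norm=True):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.use_layer_norm = use_layer_norm
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_layer_norm:
            N, C, H, W = x.shape
            x = x.flatten(2).transpose(1, 2)  # (N, H*W, C)
            x = self.norm(x)
            x = x.reshape(N, H, W, C).permute(0, 3, 1, 2).contiguous()
        return x

y = ConvBlock(8)(torch.randn(2, 8, 16, 16))
print(y.shape)  # torch.Size([2, 8, 16, 16])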
