High-order feature-interaction block (PyTorch)

class highorder_block(nn.Module):
    """High-order feature-interaction block.

    Forms a second-order term by elementwise multiplication of two 1x1-conv
    projections of the input, concatenates it with the raw input (the
    first-order term) along the channel axis, and fuses the result back to
    ``out_nc`` channels with a final 1x1 convolution, optionally followed by
    normalization and activation.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        kernel_size: only used to compute an (unused) padding value; every
            convolution in this block has kernel size 1.
        alpha: unused; kept for interface compatibility.
        stride: unused; all convolutions use stride 1.
        dilation, groups, bias: forwarded to each ``nn.Conv2d``.
        pad_type: 'zero' enables computed padding, anything else uses 0.
        norm_type: if set, a ``norm(norm_type, out_nc)`` layer is applied
            after the fusing conv.
        act_type: if set, an ``act(act_type)`` layer is applied last.
        mode: one of 'CNA', 'NAC', 'CNAC' (validated only; not otherwise
            used in this block).
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1, \
                    bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(highorder_block, self).__init__()
        # Fixed typo in the error message: 'Wong' -> 'Wrong'.
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        # NOTE(review): `padding` is computed but never used — all convs below
        # are 1x1 and take `padding1`. Kept for parity with the original code;
        # confirm before deleting.
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        padding1 = get_valid_padding(1, dilation) if pad_type == 'zero' else 0

        # 1x1 projections whose elementwise product forms the second-order term.
        self.con1 = nn.Conv2d(in_nc, 8*out_nc, 1, 1, padding1, dilation, groups, bias)
        self.con2 = nn.Conv2d(in_nc, 8*out_nc, 1, 1, padding1, dilation, groups, bias)

        # Fuses [input, second-order term] (in_nc + 8*out_nc channels) down to out_nc.
        self.con_out = nn.Conv2d(in_nc+8*out_nc, out_nc, 1, 1, padding1, dilation, groups, bias)

        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        """Apply the high-order interaction to ``x``.

        Args:
            x: input tensor; assumed shape (N, in_nc, H, W) — TODO confirm.

        Returns:
            Tensor of shape (N, out_nc, H, W) after fusion, optional norm,
            and optional activation.
        """
        proj_a = self.con1(x)
        proj_b = self.con2(x)

        first_order = x                     # first-order term: the input itself
        second_order = proj_a.mul(proj_b)   # second-order term: pairwise product

        out = torch.cat((first_order, second_order), dim=1)
        out = self.con_out(out)

        if self.n_h:
            out = self.n_h(out)

        if self.a:
            out = self.a(out)

        return out

 

Related topics: (high order)