class highorder_block(nn.Module):
    """Fuse an input feature map with a learned second-order (multiplicative) term.

    Two parallel 1x1 convolutions map ``x`` to ``expansion * out_nc`` channels
    each, and their element-wise product forms the second-order term. The
    input itself (the first-order term) is concatenated with that product
    along dim 1, and a final 1x1 convolution projects the result to ``out_nc``
    channels. Optional normalization and then activation are applied last.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        kernel_size: Accepted for signature compatibility with other conv
            blocks; every convolution here is 1x1, so it is not used.
        alpha: Unused; kept for signature compatibility.
        stride: Unused; every convolution uses stride 1.
        dilation: Dilation passed to each 1x1 convolution (and to
            ``get_valid_padding`` when ``pad_type == 'zero'``).
        groups: ``groups`` argument passed to every convolution.
        bias: Whether the convolutions have a bias term.
        pad_type: When ``'zero'``, padding is ``get_valid_padding(1, dilation)``;
            otherwise padding is 0.
        norm_type: If truthy, ``norm(norm_type, out_nc)`` is applied after the
            output projection.
        act_type: If truthy, ``act(act_type)`` is applied after normalization.
        mode: Validated against ``'CNA'``, ``'NAC'`` and ``'CNAC'``, but
            ``forward`` always applies conv -> norm -> act regardless of it.
        expansion: Channel multiplier for the two branch convolutions (hidden
            width is ``expansion * out_nc``). Defaults to 8.

    Raises:
        AssertionError: If ``mode`` is not one of the supported strings.
    """

    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA',
                 expansion=8):
        super().__init__()
        # '{}' instead of '{:s}' so a non-str mode (e.g. None) still yields the
        # intended AssertionError rather than a TypeError from str.format.
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{}]'.format(mode)
        # Every convolution in this block is 1x1, so only 1x1 padding is needed.
        padding1 = get_valid_padding(1, dilation) if pad_type == 'zero' else 0
        hidden_nc = expansion * out_nc
        # Two branch projections; their element-wise product is the second-order term.
        self.con1 = nn.Conv2d(in_nc, hidden_nc, 1, 1, padding1, dilation, groups, bias)
        self.con2 = nn.Conv2d(in_nc, hidden_nc, 1, 1, padding1, dilation, groups, bias)
        # Projects [x, second-order term] (in_nc + hidden_nc channels) to out_nc.
        self.con_out = nn.Conv2d(in_nc + hidden_nc, out_nc, 1, 1, padding1, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        """Return ``con_out(cat([x, con1(x) * con2(x)], dim=1))``, then optional norm and act."""
        first_order = x
        second_order = self.con1(x) * self.con2(x)
        out = self.con_out(torch.cat((first_order, second_order), dim=1))
        # Compare against None instead of relying on truthiness: container
        # modules such as nn.Sequential define __len__, so bool() can be False.
        if self.n_h is not None:
            out = self.n_h(out)
        if self.a is not None:
            out = self.a(out)
        return out