# sepconv (Separable Convolution) 代码复现 — code reproduction of a depthwise-separable convolution in PyTorch

import torch.nn as nn


class SP_conv(nn.Module):
    """Depthwise-separable 2D convolution.

    A depthwise ``kernel x kernel`` convolution (one filter per input
    channel, via ``groups=in_channels``) followed by a pointwise ``1x1``
    convolution that mixes the channels into ``out_channels``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels produced by the pointwise
            (1x1) convolution.
        kernel: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        bias: Whether both convolutions learn an additive bias.
        padding: Padding of the depthwise convolution, forwarded unchanged
            to ``nn.Conv2d``. Defaults to ``0``, which matches the original
            behaviour. With ``0``, the spatial size shrinks by
            ``dilation * (kernel - 1)`` whenever ``kernel > 1``. To keep
            H and W unchanged for an odd kernel at ``stride=1``, pass
            ``dilation * (kernel - 1) // 2``.
    """

    def __init__(self, in_channels, out_channels, kernel=3, stride=1,
                 dilation=1, bias=False, padding=0):
        super().__init__()
        # Depthwise step: groups=in_channels gives every input channel its
        # own single-channel filter, so no cross-channel mixing happens here.
        self.conv = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel, stride=stride, padding=padding,
            dilation=dilation, groups=in_channels, bias=bias,
        )
        # Pointwise step: a 1x1 convolution that mixes channels and maps
        # in_channels -> out_channels without changing the spatial size.
        self.pixelwise = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0,
            dilation=1, groups=1, bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution.

        ``x`` is handed directly to ``nn.Conv2d``. It should therefore be
        ``(N, in_channels, H, W)`` (or unbatched ``(in_channels, H, W)``).
        """
        x = self.conv(x)
        return self.pixelwise(x)

# Related topics (你可能感兴趣的): deep learning, PyTorch, artificial intelligence, Python