八度卷积 Octave Convolution(OctConv)

文末的参考文献讲解得非常详细,可以看看。

核心原理是利用尺度空间(scale-space)理论,把特征图分成高频和低频两部分,并对低频部分下采样(以一半的分辨率存储和计算)。这样降低了低频信息的空间冗余,可以大大减少计算量和显存占用(参数量与普通卷积相同)。同时,它可以直接替换普通卷积,无缝嵌入到现有神经网络中。

很大的创新。

很多人不理解这个操作。我看了 Facebook 的官方源码(基于 MXNet 实现,项目仍在开发中),大家可以自己看看:

https://github.com/facebookresearch/OctConv

首先看传参:传入的 data 可以分为高频和低频两部分,与返回结果相对应——传入 tuple(高频, 低频)则返回 tuple,否则默认只有高频部分。更准确地说,只要存在低频输出就返回 tuple,否则只返回高频输出。

其次,普通卷积只有一组卷积核,而 OctConv 要训练四组卷积核,分别对应 h→h、h→l、l→h、l→l 四条路径。

至于每组卷积核的具体大小,大家可以自己计算一下。

def _gather(operator, srcs, name, **kwargs):
    if type(srcs[0]) is not tuple:
        return operator(*srcs, **kwargs)
    output = []
    appxs = {1: ['-h'], 2: ['-h', '-l']}[len(srcs[0])]
    for i in range(len(srcs[0])):
        data = [src[i] for src in srcs if src[i] is not None]
        output.append(None if len(data) == 0
                      else data[0] if len(data) == 1
                      else operator(*data, name=(name+appxs[i]),**kwargs))
    return tuple(output)

def Convolution(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), num_group=1, no_bias=True, name=None):
    """Octave convolution assembled from MXNet symbols.

    Input and output are each split into a high-frequency part and a
    low-frequency part. Up to four ordinary convolutions are built:
    high->high (h2h), high->low (h2l), low->low (l2l) and low->high (l2h).
    Results targeting the same output frequency are summed via
    ``ElementWiseSum``.

    Args:
        data: a single symbol (treated as high-frequency only) or a tuple
            ``(data_high, data_low)``; either element may be ``None``.
        num_filter: int (all output channels high-frequency) or a tuple
            ``(num_high, num_low)``.
        kernel: forwarded unchanged to every ``mx.sym.Convolution``.
        pad: forwarded unchanged to every ``mx.sym.Convolution``.
        stride: ``(1, 1)`` or ``(2, 2)``; a list such as ``[2, 2]`` is also
            accepted. A stride of 2 is realised by 2x2 average pooling before a
            stride-1 convolution, not by the convolution itself.
        num_group: group count; each path uses ``min(path_filters, num_group)``.
        no_bias: forwarded to the h2h and h2l convolutions only. l2l and l2h
            always use ``no_bias=True``.
        name: prefix for sub-operator names (``<name>-h2h`` etc.). When
            ``None``, sub-operators are left unnamed so MXNet generates unique
            names.

    Returns:
        The high-frequency output alone when there is no low-frequency
        output; otherwise the tuple returned by ``ElementWiseSum``.

    Raises:
        AssertionError: on negative filter counts or an unsupported stride.
    """
    data_h, data_l = data if isinstance(data, tuple) else (data, None)
    num_high, num_low = num_filter if isinstance(num_filter, tuple) else (num_filter, 0)
    # Allow list strides; the comparisons below are against tuples.
    if isinstance(stride, list):
        stride = tuple(stride)

    assert num_high >= 0 and num_low >= 0
    assert stride == (1, 1) or stride == (2, 2), "stride = {} is not supported yet".format(stride)

    # NOTE(review): when num_filter is a tuple this is only True if num_group
    # is the same tuple, and then min(filters, num_group) below compares an int
    # with a tuple, which raises TypeError on Python 3 -- confirm intended usage.
    depthwise = num_filter == num_group

    def _sub_name(path):
        # '%s-h2h' % None would give the literal 'None-h2h' for every unnamed
        # layer, so separate layers would collide on op/parameter names.
        return None if name is None else '%s-%s' % (name, path)

    def _avg_pool2x2(sym):
        # Halve the spatial resolution with 2x2 average pooling.
        return mx.sym.Pooling(data=sym, pool_type="avg", kernel=(2, 2), pad=(0, 0), stride=(2, 2))

    def _conv(sym, filters, path, bias_off):
        # Stride-1 convolution for one of the four frequency paths.
        return mx.sym.Convolution(data=sym, num_filter=filters, kernel=kernel, stride=(1, 1), pad=pad,
                                  num_group=min(filters, num_group), no_bias=bias_off, name=_sub_name(path))

    data_h2l, data_h2h, data_l2l, data_l2h = None, None, None, None

    # High-frequency input.
    if data_h is not None:
        if stride == (2, 2):
            data_h = _avg_pool2x2(data_h)
        # High -> High
        if num_high > 0:
            data_h2h = _conv(data_h, num_high, 'h2h', no_bias)
        # High -> Low: pool once more to reach the low-frequency resolution.
        if not depthwise and num_low > 0:
            data_h2l = _conv(_avg_pool2x2(data_h), num_low, 'h2l', no_bias)

    # Low-frequency input. Bias is always off here, so each output receives a
    # bias only from the high-input paths (none at all when data_h is None).
    if data_l is not None:
        # Low -> Low
        if num_low > 0:
            l2l_src = _avg_pool2x2(data_l) if stride == (2, 2) else data_l
            data_l2l = _conv(l2l_src, num_low, 'l2l', True)
        # Low -> High: with stride 1 the high output keeps the input's high
        # resolution, so upsample 2x; with stride 2 it already matches.
        if not depthwise and num_high > 0:
            data_l2h = _conv(data_l, num_high, 'l2h', True)
            if stride == (1, 1):
                data_l2h = mx.sym.UpSampling(data_l2h, scale=2, sample_type="nearest", num_args=1)

    # To disable the cross-frequency paths whenever both direct paths exist,
    # set data_h2l / data_l2h to None here.

    output = ElementWiseSum(*[(data_h2h, data_h2l), (data_l2h, data_l2l)], name=name)

    # Squeeze the output for backward compatibility with plain convolution.
    return output[0] if output[1] is None else output

参考文献:https://www.cnblogs.com/RyanXing/p/10720182.html

你可能感兴趣的:(深度学习)