mxnet - reshape explained in full (understanding 0, -1, -2, -3, -4)

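In general, MXNet offers the same operation in both its ndarray and symbol APIs. These correspond to the dynamic (imperative) graph and the static (symbolic) graph respectively. reshape, for example, can be called as mx.nd.reshape or as mx.sym.reshape. Below we take the reshape operation apart, using mx.nd.reshape as the reference.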

The reshape docstring

reshape(data=None, shape=_Null, reverse=_Null, target_shape=_Null, keep_highest=_Null, out=None, name=None, **kwargs)
    Reshapes the input array.

    .. note:: ``Reshape`` is deprecated, use ``reshape``

    Given an array and a shape, this function returns a copy of the array in the new shape.
    The shape is a tuple of integers such as (2,3,4). The size of the new shape should be same as the size of the input array.

    Example::

      reshape([1,2,3,4], shape=(2,2)) = [[1,2], [3,4]]

    Some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}. The significance of each is explained below:

    - ``0``  copy this dimension from the input to the output shape.

      Example::

      - input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2)
      - input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4)

    - ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions
      keeping the size of the new array same as that of the input array.
      At most one dimension of shape can be -1.

      Example::

      - input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
      - input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
      - input shape = (2,3,4), shape=(-1,), output shape = (24,)

    - ``-2`` copy all/remainder of the input dimensions to the output shape.

      Example::

      - input shape = (2,3,4), shape = (-2,), output shape = (2,3,4)
      - input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4)
      - input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1)

    - ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension.

      Example::

      - input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
      - input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
      - input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
      - input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)

    - ``-4`` split one dimension of the input into two dimensions passed subsequent to -4 in shape (can contain -1).

      Example::

      - input shape = (2,3,4), shape = (-4,1,2,-2), output shape =(1,2,3,4)
      - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)

    If the argument `reverse` is set to 1, then the special values are inferred from right to left.

      Example::

      - without reverse=1, for input shape = (10,5,4), shape = (-1,0), output shape would be (40,5)
      - with reverse=1, output shape will be (50,4).

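reshape takes a shape tuple as an argument. Its entries can be positive integers or one of the unusual special values 0, -1, -2, -3, -4. The meaning of each is explained below.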

0

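0 acts as a placeholder that keeps the input's dimension at that position in the output. Positions are matched from left to right by default, or from right to left when reverse=1 is passed. More precisely, every positive value, 0 and -1 in shape consumes one input dimension. A 0 therefore copies whichever input dimension the walk has reached.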

  • input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2) # the middle dimension is kept
  • input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4) # the last two dimensions are kept
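Here is a minimal pure-Python sketch of this rule. It assumes `shape` contains only positive values and 0, and `resolve_zero` is a hypothetical helper, not an MXNet function. Because each positive value or 0 consumes exactly one input dimension, a 0 at position i copies input dimension i:

```python
from math import prod


def resolve_zero(in_shape, shape):
    """Hypothetical helper: resolve 0 entries when `shape` holds only 0 and positive ints."""
    out = []
    for i, s in enumerate(shape):
        if s == 0:
            out.append(in_shape[i])  # copy the input dimension at the same position
        elif s > 0:
            out.append(s)  # explicit size, used as-is
        else:
            raise ValueError("this sketch only handles 0 and positive values")
    if prod(out) != prod(in_shape):
        raise ValueError(f"cannot reshape {tuple(in_shape)} into {tuple(out)}")
    return tuple(out)


print(resolve_zero((2, 3, 4), (4, 0, 2)))  # (4, 3, 2)
print(resolve_zero((2, 3, 4), (2, 0, 0)))  # (2, 3, 4)
```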

-1

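-1 is resolved last. Once every other entry is settled, its value is inferred from the requirement that the array's total size stays the same before and after the reshape. Normally at most one -1 is allowed. When -4 is present there can be two, because a -1 among the two values that follow -4 is resolved locally from the dimension being split.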

  • input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
  • input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
  • input shape = (2,3,4), shape=(-1,), output shape = (24,)
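The inference itself is plain arithmetic: multiply the explicit sizes together and divide the input size by that product. Below is a minimal sketch that handles only positive values and a single -1. `infer_minus_one` is a hypothetical helper:

```python
from math import prod


def infer_minus_one(in_shape, shape):
    """Hypothetical helper: fill in a single -1 so that the total size is unchanged.

    Only positive values and at most one -1 are handled here.
    """
    if list(shape).count(-1) > 1:
        raise ValueError("at most one -1 can be inferred (outside of -4)")
    total = prod(in_shape)
    known = prod(s for s in shape if s != -1)  # product of the explicit sizes
    out = tuple(total // known if s == -1 else s for s in shape)
    if prod(out) != total:  # also catches the case where known does not divide total
        raise ValueError(f"cannot reshape {tuple(in_shape)} into {tuple(shape)}")
    return out


print(infer_minus_one((2, 3, 4), (6, 1, -1)))  # (6, 1, 4)
print(infer_minus_one((2, 3, 4), (3, -1, 8)))  # (3, 1, 8)
print(infer_minus_one((2, 3, 4), (-1,)))       # (24,)
```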

-2

Unlike -1, -2 can stand for several dimensions. It copies all remaining input dimensions, from the current position to the end, into the output.

  • input shape = (2,3,4), shape = (-2,), output shape = (2,3,4) # -2 takes all the dimensions
  • input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4) # 2 takes the first dimension, -2 copies the remaining (3,4)
  • input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1) # -2 copies (2,3,4), and (1,1) are two new dimensions
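The key point is that -2 is resolved from wherever the walk over the input dimensions currently stands. Below is a minimal pure-Python sketch that assumes only positive values, 0 and -2 appear. `expand_minus_two` is a hypothetical helper, and the final total-size check is omitted for brevity:

```python
def expand_minus_two(in_shape, shape):
    """Hypothetical helper: resolve positive values, 0 and -2 using an input cursor."""
    out = []
    src = 0  # index of the next input dimension to consume
    for s in shape:
        if s == -2:
            out.extend(in_shape[src:])  # copy every remaining input dimension
            src = len(in_shape)
        elif s == 0:
            out.append(in_shape[src])  # copy the current input dimension
            src += 1
        elif s > 0:
            out.append(s)  # explicit size; it still consumes one input dimension
            src += 1
        else:
            raise ValueError("this sketch only handles positive values, 0 and -2")
    return tuple(out)


print(expand_minus_two((2, 3, 4), (-2,)))       # (2, 3, 4)
print(expand_minus_two((2, 3, 4), (2, -2)))     # (2, 3, 4)
print(expand_minus_two((2, 3, 4), (-2, 1, 1)))  # (2, 3, 4, 1, 1)
```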

-3

-3 merges two consecutive input dimensions into one dimension whose size is the product of the two.

  • input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
  • input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
  • input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
  • input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)
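The same cursor idea covers -3: it reads two input dimensions at once and emits their product. The sketch below is hypothetical (`merge_minus_three` is not an MXNet function). It also handles 0 and -2, because the examples above combine them with -3:

```python
def merge_minus_three(in_shape, shape):
    """Hypothetical helper: resolve positive values, 0, -2 and -3 using an input cursor."""
    out = []
    src = 0  # index of the next input dimension to consume
    for s in shape:
        if s == -3:
            if src + 1 >= len(in_shape):
                raise ValueError("-3 needs two remaining input dimensions")
            out.append(in_shape[src] * in_shape[src + 1])  # merge two neighbours
            src += 2
        elif s == -2:
            out.extend(in_shape[src:])  # copy every remaining input dimension
            src = len(in_shape)
        elif s == 0:
            out.append(in_shape[src])  # copy the current input dimension
            src += 1
        elif s > 0:
            out.append(s)  # explicit size; consumes one input dimension
            src += 1
        else:
            raise ValueError("this sketch only handles positive values, 0, -2 and -3")
    return tuple(out)


print(merge_minus_three((2, 3, 4), (-3, 4)))      # (6, 4)
print(merge_minus_three((2, 3, 4, 5), (-3, -3)))  # (6, 20)
print(merge_minus_three((2, 3, 4), (0, -3)))      # (2, 12)
print(merge_minus_three((2, 3, 4), (-3, -2)))     # (6, 4)
```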

-4

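Unlike -3, -4 splits one input dimension into two. -4 is followed by two numbers that give the sizes of the resulting dimensions. One of them (but not both) may be -1, in which case it is inferred from the dimension being split.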

  • input shape = (2,3,4), shape = (-4,1,2,-2), output shape = (1,2,3,4) # split 2 into 1x2; the remaining (3,4) go to -2
  • input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4) # split 3 into 1x3 (-1 = 3/3 = 1); the remaining 4 goes to -2
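A -4 is the only special value that consumes three entries of shape at once: the -4 itself and the two split factors. Below is a minimal sketch that handles positive values, -2 and -4, with an optional -1 as one split factor. `split_minus_four` is a hypothetical helper:

```python
def split_minus_four(in_shape, shape):
    """Hypothetical helper: resolve positive values, -2 and -4 (one split factor may be -1)."""
    out = []
    src = 0  # index of the next input dimension to consume
    i = 0    # index into `shape`; -4 consumes three entries at once
    while i < len(shape):
        s = shape[i]
        if s == -4:
            if i + 2 >= len(shape):
                raise ValueError("-4 must be followed by two values")
            d0 = in_shape[src]  # the input dimension being split
            src += 1
            d1, d2 = shape[i + 1], shape[i + 2]
            if d1 == -1 and d2 == -1:
                raise ValueError("the two split factors cannot both be -1")
            if d1 == -1:
                d1 = d0 // d2  # infer the left factor
            if d2 == -1:
                d2 = d0 // d1  # infer the right factor
            if d1 * d2 != d0:
                raise ValueError(f"{d1} x {d2} does not equal {d0}")
            out += [d1, d2]
            i += 3
            continue
        if s == -2:
            out.extend(in_shape[src:])  # copy every remaining input dimension
            src = len(in_shape)
        elif s > 0:
            out.append(s)  # explicit size; consumes one input dimension
            src += 1
        else:
            raise ValueError("this sketch only handles positive values, -2 and -4")
        i += 1
    return tuple(out)


print(split_minus_four((2, 3, 4), (-4, 1, 2, -2)))      # (1, 2, 3, 4)
print(split_minus_four((2, 3, 4), (2, -4, -1, 3, -2)))  # (2, 1, 3, 4)
```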

reverse

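If the argument reverse is set to 1, the special values are inferred from right to left. The input dimensions are matched against shape starting from the last one instead of the first.

  • without reverse=1: input shape = (10,5,4), shape = (-1,0), output shape = (40,5) # 0 copies 5, the second input dimension; -1 = 200/5 = 40
  • with reverse=1: input shape = (10,5,4), shape = (-1,0), output shape = (50,4) # 0 copies 4, the last input dimension; -1 = 200/4 = 50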
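The following pure-Python sketch puts all the rules together. `infer_reshape` is a hypothetical name, not part of MXNet. It walks over shape with a cursor into the input dimensions and fills in -1 last. It implements reverse=1 by reversing both tuples, resolving them left to right and reversing the result. That is an assumed reading of the right-to-left rule: it reproduces the docstring examples, but it is an illustration and may differ from MXNet in edge cases:

```python
from math import prod


def infer_reshape(in_shape, shape, reverse=False):
    """Sketch of reshape shape inference covering 0, -1, -2, -3, -4 and reverse."""
    dims = list(in_shape)
    params = list(shape)
    if reverse:
        # Assumption: right-to-left matching = reverse both tuples, resolve, reverse back.
        dims.reverse()
        params.reverse()
    out, src, i, inf_idx = [], 0, 0, None
    while i < len(params):
        s = params[i]
        if s == 0:  # copy the current input dimension
            out.append(dims[src])
            src += 1
        elif s == -1:  # placeholder, filled in after the loop
            if inf_idx is not None:
                raise ValueError("only one -1 can be inferred")
            inf_idx = len(out)
            out.append(1)
            src += 1
        elif s == -2:  # copy all remaining input dimensions
            out.extend(dims[src:])
            src = len(dims)
        elif s == -3:  # merge two consecutive input dimensions
            out.append(dims[src] * dims[src + 1])
            src += 2
        elif s == -4:  # split one input dimension into the next two values
            if i + 2 >= len(params):
                raise ValueError("-4 must be followed by two values")
            d0, d1, d2 = dims[src], params[i + 1], params[i + 2]
            src += 1
            if d1 == -1 and d2 == -1:
                raise ValueError("the two split factors cannot both be -1")
            d1 = d0 // d2 if d1 == -1 else d1
            d2 = d0 // d1 if d2 == -1 else d2
            if d1 * d2 != d0:
                raise ValueError(f"{d1} x {d2} does not equal {d0}")
            out += [d1, d2]
            i += 2  # skip the two split factors
        elif s > 0:  # explicit size; consumes one input dimension
            out.append(s)
            src += 1
        else:
            raise ValueError(f"unsupported value {s}")
        i += 1
    if inf_idx is not None:
        out[inf_idx] = prod(dims) // prod(out)  # out holds 1 at inf_idx here
    if prod(out) != prod(dims):
        raise ValueError(f"cannot reshape {tuple(in_shape)} into {tuple(shape)}")
    if reverse:
        out.reverse()
    return tuple(out)


print(infer_reshape((10, 5, 4), (-1, 0)))                # (40, 5)
print(infer_reshape((10, 5, 4), (-1, 0), reverse=True))  # (50, 4)
```

The same function also reproduces the two reshapes in the GN example that follows. An input of (2, 64, 5, 5) with shape (-1, -4, 32, -1, -2) gives (2, 32, 2, 5, 5), and (-1, -3, -2) takes that back to (2, 64, 5, 5).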

An example: implementing GN (Group Normalization)

import mxnet as mx
import numpy as np


class GroupNorm(mx.gluon.HybridBlock):
    r"""Group Normalization

    refer to paper 

    """
    def __init__(self,
                 in_channels,
                 groups=32,
                 gamma_initializer='ones',
                 beta_initializer='zeros',
                 **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.groups = min(in_channels, groups)
        assert in_channels % self.groups == 0, "Channel number should be divisible by groups."
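        # SpecialAttrScope is not defined in MXNet itself; it appears to be a helper from the
        # original codebase that supplies optional settings such as mirroring_level and gn_eps.
        attrs = SpecialAttrScope.current.attrs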
        self.mirroring_level = attrs.get('mirroring_level', 0)
        self.eps = attrs.get('gn_eps', 2e-5)
        self.use_fp16 = False
        with self.name_scope():
            self.gamma = self.params.get('gamma',
                                         grad_req='write',
                                         shape=(1, in_channels, 1, 1),
                                         init=gamma_initializer,
                                         allow_deferred_init=True,
                                         differentiable=True)
            self.beta = self.params.get('beta',
                                        grad_req='write',
                                        shape=(1, in_channels, 1, 1),
                                        init=beta_initializer,
                                        allow_deferred_init=True,
                                        differentiable=True)

    def cast(self, dtype):
        self.use_fp16 = False
        if np.dtype(dtype).name == 'float16':
            self.use_fp16 = True
            dtype = 'float32'
        super(GroupNorm, self).cast(dtype)

    def hybrid_forward(self, F, x, gamma, beta):
        _kwargs = {}
        if F is mx.symbol and self.mirroring_level >= 3:
            _kwargs['force_mirroring'] = 'True'

        if self.use_fp16:
            x = F.cast(data=x, dtype='float32')

        # (N, C, H, W) --> (N, G, C//G, H, W)
        x = F.reshape(x, shape=(-1, -4, self.groups, -1, -2))

        # y = (x - mean) / sqrt(var + eps)
        mean = F.mean(x, axis=(2, 3, 4), keepdims=True, **_kwargs)
        y = F.broadcast_sub(x, mean, **_kwargs)
        var = F.mean(y**2, axis=(2, 3, 4), keepdims=True, **_kwargs)
        y = F.broadcast_div(y, F.sqrt(var + self.eps))

        # (N, G, C//G, H, W) --> (N, C, H, W)
        y = F.reshape(y, shape=(-1, -3, -2))

        y = F.broadcast_mul(y, gamma, **_kwargs)
        y = F.broadcast_add(y, beta, **_kwargs)

        if self.use_fp16:
            y = F.cast(data=y, dtype='float16')

        return y
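To see what the two reshapes buy, here is a pure-Python reference for the normalisation step. It uses the hypothetical helper `group_norm_reference`, leaves out gamma, beta, the fp16 cast and mirroring, and uses nested lists in place of an (N, C, H, W) array. Splitting C into (G, C//G) in row-major order puts channel c in group c // (C//G). Each sample and group therefore gets its own mean and variance, which matches averaging over axes (2, 3, 4) in hybrid_forward:

```python
import math


def group_norm_reference(x, groups, eps=2e-5):
    """Normalise a nested-list (N, C, H, W) input per (sample, group); no gamma/beta."""
    channels = len(x[0])
    if channels % groups:
        raise ValueError("Channel number should be divisible by groups.")
    per_group = channels // groups  # C // G channels share one set of statistics
    out = []
    for sample in x:
        normed = [None] * channels
        for g in range(groups):
            chans = range(g * per_group, (g + 1) * per_group)  # channels of group g
            values = [v for c in chans for row in sample[c] for v in row]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)  # biased, like F.mean(y**2)
            scale = 1.0 / math.sqrt(var + eps)
            for c in chans:
                normed[c] = [[(v - mean) * scale for v in row] for row in sample[c]]
        out.append(normed)
    return out


# One sample, four 1x1 channels, two groups: {1, 3} and {10, 20}.
y = group_norm_reference([[[[1.0]], [[3.0]], [[10.0]], [[20.0]]]], groups=2)
print([round(ch[0][0], 3) for ch in y[0]])  # [-1.0, 1.0, -1.0, 1.0]
```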
