【PyTorch】Implementing subpixel upsampling and downsampling

A PyTorch implementation of subpixel upsampling and downsampling, analogous to TensorFlow's tf.depth_to_space and tf.space_to_depth.

def shuffle_down(inputs, scale):
    """Space-to-depth: fold each scale x scale spatial block into the channel dimension."""
    N, C, iH, iW = inputs.size()
    oH = iH // scale
    oW = iW // scale

    # (N, C, oH, scale_h, oW, scale_w) -> (N, C, scale_w, scale_h, oH, oW)
    output = inputs.view(N, C, oH, scale, oW, scale)
    output = output.permute(0, 1, 5, 3, 2, 4).contiguous()
    # flatten (C, scale_w, scale_h) into a single channel dimension of size C * scale**2
    return output.view(N, -1, oH, oW)


def shuffle_up(inputs, scale):
    """Depth-to-space: the inverse of shuffle_down, spreads channel groups back into space."""
    N, C, iH, iW = inputs.size()
    oH = iH * scale
    oW = iW * scale
    oC = C // (scale ** 2)

    # (N, oC, scale_w, scale_h, iH, iW) -> (N, oC, iH, scale_h, iW, scale_w)
    output = inputs.view(N, oC, scale, scale, iH, iW)
    output = output.permute(0, 1, 4, 3, 5, 2).contiguous()
    output = output.view(N, oC, oH, oW)
    return output
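
A quick round-trip check (a minimal sketch; the 2 x 12 x 8 x 8 shape is just an illustrative choice). Note that shuffle_down / shuffle_up use their own channel ordering, which differs from TensorFlow's (the forum classes below match TF exactly), but the two functions invert each other:

import torch

x = torch.rand(2, 12, 8, 8)      # any (N, C, H, W) with H and W divisible by scale
down = shuffle_down(x, 2)        # -> (2, 48, 4, 4): channels grow by scale**2
up = shuffle_up(down, 2)         # -> (2, 12, 8, 8): original resolution restored

print(down.shape, up.shape)
print(torch.equal(x, up))        # True: shuffle_up inverts shuffle_down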

==================
An example from the PyTorch forums

from torch import nn

class DepthToSpace(nn.Module):

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W)  # (N, bs, bs, C//bs^2, H, W)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)
        return x


class SpaceToDepth(nn.Module):

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
        return x
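
Note that DepthToSpace above reproduces TensorFlow's channel grouping, where the channel axis is viewed as (block_size, block_size, C // block_size**2). torch.nn.PixelShuffle instead groups channels as (C // block_size**2, block_size, block_size), so the two produce the same output shape but a different element layout. A small sketch (the input shape is arbitrary):

import torch
from torch import nn

x = torch.rand(1, 16, 4, 4)
a = DepthToSpace(2)(x)     # TF-style grouping: (bs, bs, C // bs**2)
b = nn.PixelShuffle(2)(x)  # PyTorch grouping: (C // bs**2, bs, bs)

print(a.shape, b.shape)    # both torch.Size([1, 4, 8, 8])
print(torch.equal(a, b))   # generally False: same shape, different layout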


import tensorflow as tf  # NOTE: the check below uses the TensorFlow 1.x graph API (tf.Session)
import torch

# pytorch
x1 = torch.rand(64, 256, 8, 8)
x2 = DepthToSpace(2)(x1)
x3 = SpaceToDepth(2)(x2)
print(x1.size())
print(x2.size())
print(x3.size())
print((x1 == x3).all())

# tensorflow
y1 = tf.transpose(x1.numpy(), [0, 2, 3, 1])  # NCHW -> NHWC
y2 = tf.depth_to_space(y1, 2)
y3 = tf.space_to_depth(y2, 2)

y1 = tf.transpose(y1, [0, 3, 1, 2])  # NHWC -> NCHW
y2 = tf.transpose(y2, [0, 3, 1, 2])
y3 = tf.transpose(y3, [0, 3, 1, 2])

y1, y2, y3 = tf.Session().run([y1, y2, y3])
print(y1.shape)
print(y2.shape)
print(y3.shape)
print((y1 == y3).all())

# check consistency
print((x1.numpy() == y1).all())
print((x2.numpy() == y2).all())
print((x3.numpy() == y3).all())
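
On TensorFlow 2.x there is no tf.Session; a minimal sketch of the same consistency check in eager mode (assuming tf.nn.depth_to_space / tf.nn.space_to_depth, which expect NHWC input by default):

import numpy as np
import tensorflow as tf  # 2.x, eager execution
import torch

x1 = torch.rand(64, 256, 8, 8)
x2 = DepthToSpace(2)(x1)

y1 = tf.transpose(x1.numpy(), [0, 2, 3, 1])  # NCHW -> NHWC
y2 = tf.nn.depth_to_space(y1, 2)             # eager op, returns a tf.Tensor directly
y2 = tf.transpose(y2, [0, 3, 1, 2]).numpy()  # NHWC -> NCHW

print(np.array_equal(x2.numpy(), y2))        # expected True: a pure permutation, no arithmetic involved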
