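PyTorch 实现 sub-pixel 上采样及下采样，类似于 TensorFlow 的 tf.depth_to_space、tf.space_to_depth。注意：下面的 shuffle_up / shuffle_down 互为逆运算，但通道排列顺序与 TensorFlow 不同；后文的 DepthToSpace / SpaceToDepth 与 TensorFlow 的排列一致。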
def shuffle_down(inputs, scale):
    """Fold each ``scale x scale`` spatial block of ``inputs`` into channels.

    Args:
        inputs: 4-D tensor of shape (N, C, H, W). H and W are expected to be
            multiples of ``scale``; otherwise ``view`` raises.
        scale: Integer block size.

    Returns:
        Tensor of shape (N, C * scale**2, H // scale, W // scale). Output
        channel ``c * scale**2 + sw * scale + sh`` at (oh, ow) holds input
        pixel ``(c, oh * scale + sh, ow * scale + sw)``. This is the inverse
        of ``shuffle_up``. The channel order differs from both
        ``tf.space_to_depth`` and ``torch.nn.functional.pixel_unshuffle``.
    """
    N, C, iH, iW = inputs.size()
    oH = iH // scale
    oW = iW // scale
    # Split H and W into (block index, offset): (N, C, oH, s_h, oW, s_w).
    output = inputs.view(N, C, oH, scale, oW, scale)
    # Reorder to (N, C, s_w, s_h, oH, oW) so the offsets join the channel dim.
    output = output.permute(0, 1, 5, 3, 2, 4).contiguous()
    # Spell out the channel count instead of -1: for an empty batch (N == 0)
    # a -1 dimension is ambiguous and view() would raise.
    return output.view(N, C * scale * scale, oH, oW)
def shuffle_up(inputs, scale):
    """Unfold groups of ``scale**2`` channels into ``scale x scale`` pixel blocks.

    Args:
        inputs: 4-D tensor of shape (N, C, H, W); C is expected to be a
            multiple of ``scale**2``.
        scale: Integer upsampling factor.

    Returns:
        Tensor of shape (N, C // scale**2, H * scale, W * scale). Output pixel
        ``(c, h * scale + sh, w * scale + sw)`` comes from input channel
        ``c * scale**2 + sw * scale + sh`` at (h, w). This is the inverse of
        ``shuffle_down``.
    """
    batch, channels, height, width = inputs.size()
    out_channels = channels // (scale ** 2)
    up_h, up_w = height * scale, width * scale

    # (N, C, H, W) -> (N, C_out, s_w, s_h, H, W)
    blocks = inputs.view(batch, out_channels, scale, scale, height, width)
    # -> (N, C_out, H, s_h, W, s_w): interleave offsets with spatial dims.
    blocks = blocks.permute(0, 1, 4, 3, 5, 2).contiguous()
    upsampled = blocks.view(batch, out_channels, up_h, up_w)
    return upsampled
==================
PyTorch 论坛上的例子：
from torch import nn
class DepthToSpace(nn.Module):
    """NCHW depth-to-space: move channel blocks into ``bs x bs`` spatial tiles.

    For input of shape (N, C, H, W) the output has shape
    (N, C // bs**2, H * bs, W * bs). Output pixel ``(d, h * bs + i, w * bs + j)``
    is taken from input channel ``(i * bs + j) * (C // bs**2) + d`` at (h, w),
    i.e. the block offset is the slow-varying part of the channel index.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        bs = self.bs
        batch, channels, height, width = x.size()
        depth = channels // (bs ** 2)
        # (N, C, H, W) -> (N, bs_h, bs_w, C_out, H, W)
        tiles = x.view(batch, bs, bs, depth, height, width)
        # -> (N, C_out, H, bs_h, W, bs_w)
        tiles = tiles.permute(0, 3, 4, 1, 5, 2).contiguous()
        # -> (N, C_out, H * bs, W * bs)
        return tiles.view(batch, depth, height * bs, width * bs)
class SpaceToDepth(nn.Module):
    """NCHW space-to-depth: fold ``bs x bs`` spatial tiles into channels.

    For input of shape (N, C, H, W) the output has shape
    (N, C * bs**2, H // bs, W // bs). Output channel ``(i * bs + j) * C + c``
    at (oh, ow) holds input pixel ``(c, oh * bs + i, ow * bs + j)``; this is
    the inverse of ``DepthToSpace``.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        bs = self.bs
        batch, channels, height, width = x.size()
        out_h, out_w = height // bs, width // bs
        # (N, C, H, W) -> (N, C, H//bs, bs_h, W//bs, bs_w)
        tiles = x.view(batch, channels, out_h, bs, out_w, bs)
        # -> (N, bs_h, bs_w, C, H//bs, W//bs)
        tiles = tiles.permute(0, 3, 5, 1, 2, 4).contiguous()
        # -> (N, C * bs**2, H//bs, W//bs)
        return tiles.view(batch, channels * (bs ** 2), out_h, out_w)
import tensorflow as tf
import torch
# PyTorch: depth -> space -> depth round trip; x3 should equal x1 exactly
# because both ops are pure permutations of elements.
x1 = torch.rand(64, 256, 8, 8)
x2 = DepthToSpace(2)(x1)
x3 = SpaceToDepth(2)(x2)
print(x1.size())
print(x2.size())
print(x3.size())
print((x1 == x3).all())

# TensorFlow reference. depth_to_space / space_to_depth work on NHWC here,
# so transpose in and back out to compare with the NCHW PyTorch results.
# tf.nn.* is used because the top-level tf.depth_to_space / tf.space_to_depth
# aliases were removed in TensorFlow 2.x.
y1 = tf.transpose(x1.numpy(), [0, 2, 3, 1])  # NCHW -> NHWC
y2 = tf.nn.depth_to_space(y1, 2)
y3 = tf.nn.space_to_depth(y2, 2)
y1 = tf.transpose(y1, [0, 3, 1, 2])  # NHWC -> NCHW
y2 = tf.transpose(y2, [0, 3, 1, 2])
y3 = tf.transpose(y3, [0, 3, 1, 2])
if tf.executing_eagerly():
    # TF 2.x (eager): tensors already hold values; tf.Session no longer exists.
    y1, y2, y3 = y1.numpy(), y2.numpy(), y3.numpy()
else:
    # TF 1.x graph mode: evaluate in a session and close it afterwards.
    # NOTE(review): tf.compat.v1 assumes TensorFlow >= 1.13 — confirm.
    with tf.compat.v1.Session() as sess:
        y1, y2, y3 = sess.run([y1, y2, y3])
print(y1.shape)
print(y2.shape)
print(y3.shape)
print((y1 == y3).all())

# Check that the PyTorch modules agree element-wise with TensorFlow.
print((x1.numpy() == y1).all())
print((x2.numpy() == y2).all())
print((x3.numpy() == y3).all())