import torch
import torch.utils.data
from torch import nn, optim
from padding_same_conv import Conv2d
def toTensor(img):
    """Wrap a 4-D NumPy array as a torch tensor with axes ordered (0, 3, 1, 2).

    Args:
        img: NumPy array with four axes. Presumably a batch of images laid
            out as (batch, height, width, channels) -- TODO confirm with
            the callers.

    Returns:
        A ``torch.Tensor`` whose axis 1 is the input's last axis and whose
        axes 2 and 3 are the input's axes 1 and 2. ``torch.from_numpy``
        does not copy, so the tensor shares memory with ``img``.
    """
    channels_moved = img.transpose((0, 3, 1, 2))
    return torch.from_numpy(channels_moved)
def var_to_np(img_var):
    """Return the values of a tensor as a NumPy array.

    Args:
        img_var: A ``torch.Tensor``; it may require grad and may live on
            any device.

    Returns:
        A NumPy array holding the tensor's values, obtained after moving
        the tensor to the CPU.
    """
    # ``detach()`` is the supported replacement for the legacy ``.data``
    # attribute: it yields the same values without gradient tracking, but
    # keeps autograd able to detect in-place modifications.
    return img_var.detach().cpu().numpy()
class _ConvLayer(nn.Sequential):
    """Sequential block: a 5x5, stride-2 ``Conv2d`` followed by LeakyReLU(0.1).

    The sub-modules are registered under the names ``conv2`` and
    ``leakyrelu``; those names are part of the state-dict keys.

    Args:
        input_features: Number of input channels passed to ``Conv2d``.
        output_features: Number of output channels passed to ``Conv2d``.
    """

    def __init__(self, input_features, output_features):
        super(_ConvLayer, self).__init__()
        # ``Conv2d`` comes from the project's ``padding_same_conv`` module;
        # presumably it pads so that stride 2 halves height and width --
        # TODO confirm against that module.
        downsample = Conv2d(input_features, output_features,
                            kernel_size=5, stride=2)
        activation = nn.LeakyReLU(0.1, inplace=True)
        for name, module in (('conv2', downsample), ('leakyrelu', activation)):
            self.add_module(name, module)
class _UpScale(nn.Sequential):
    """Sequential block: 3x3 ``Conv2d`` -> LeakyReLU(0.1) -> ``_PixelShuffler``.

    The convolution emits ``output_features * 4`` channels; the pixel
    shuffler then trades that factor of 4 for a 2x larger height and width,
    leaving ``output_features`` channels.

    The sub-modules are registered as ``conv2_``, ``leakyrelu`` and
    ``pixelshuffler``; those names are part of the state-dict keys.

    Args:
        input_features: Number of input channels passed to ``Conv2d``.
        output_features: Number of channels left after the pixel shuffle.
    """

    def __init__(self, input_features, output_features):
        super(_UpScale, self).__init__()
        # Four times the requested channels: one per position of the 2x2
        # sub-pixel grid that ``_PixelShuffler`` unfolds.
        expanded = Conv2d(input_features, output_features * 4, kernel_size=3)
        activation = nn.LeakyReLU(0.1, inplace=True)
        shuffle = _PixelShuffler()
        stages = (
            ('conv2_', expanded),
            ('leakyrelu', activation),
            ('pixelshuffler', shuffle),
        )
        for name, module in stages:
            self.add_module(name, module)
class Flatten(nn.Module):
    """Collapse every axis after the first into a single axis."""

    def forward(self, input):
        """Flatten ``input`` to two dimensions.

        Args:
            input: Tensor whose first axis has size ``N``.

        Returns:
            Tensor of shape ``(N, -1)`` with the remaining axes merged in
            their existing order.
        """
        # ``reshape`` instead of ``view``: merging axes with ``view`` raises
        # on non-contiguous input (e.g. a permuted or channels_last tensor).
        # For contiguous input ``reshape`` still returns a view, so the
        # usual path is unchanged.
        return input.reshape(input.size(0), -1)
class Reshape(nn.Module):
    """View a flat feature tensor as ``(-1, channels, height, width)``.

    Args:
        channels: Size of axis 1 of the output. Defaults to 1024.
        height: Size of axis 2 of the output. Defaults to 4.
        width: Size of axis 3 of the output. Defaults to 4.

    With the defaults, ``Reshape()`` produces ``(-1, 1024, 4, 4)``.
    """

    # Class-level default so instances created without running ``__init__``
    # (e.g. un-pickled from an older checkpoint of the whole module) still
    # reshape to (1024, 4, 4).
    target_shape = (1024, 4, 4)

    def __init__(self, channels=1024, height=4, width=4):
        super(Reshape, self).__init__()
        self.target_shape = (channels, height, width)

    def forward(self, input):
        """Return ``input`` viewed as ``(-1, channels, height, width)``.

        The leading axis is inferred from the number of elements.
        """
        return input.view(-1, *self.target_shape)
class _PixelShuffler(nn.Module):
    """Rearrange channels into a 2x2 sub-pixel grid (channel-first tensors).

    The channel axis is read as ``(2, 2, out_channels)``, i.e. the two
    upscale factors come first. This ordering is fixed by the ``view`` and
    ``permute`` calls in ``forward``.
    """

    def forward(self, input):
        """Turn ``(N, C, H, W)`` into ``(N, C // 4, 2 * H, 2 * W)``.

        Args:
            input: 4-D tensor; the first ``view`` only succeeds when ``C``
                is a multiple of 4.

        Returns:
            Contiguous tensor where output pixel ``(2*h + i, 2*w + j)`` of
            channel ``k`` is input channel ``(2*i + j) * (C // 4) + k`` at
            ``(h, w)``.
        """
        n, channels, height, width = input.shape
        up_h, up_w = 2, 2
        out_channels = channels // (up_h * up_w)

        # Split the channel axis into (up_h, up_w, out_channels).
        blocks = input.view(n, up_h, up_w, out_channels, height, width)
        # Interleave: (n, out_channels, height, up_h, width, up_w).
        interleaved = blocks.permute(0, 3, 4, 1, 5, 2).contiguous()
        # Merge (height, up_h) and (width, up_w); channel stays first.
        return interleaved.view(n, out_channels, height * up_h, width * up_w)
class Autoencoder(nn.Module):
    """One shared encoder feeding two separate decoders, ``A`` and ``B``.

    Encoder: four ``_ConvLayer`` stages (3 -> 128 -> 256 -> 512 -> 1024
    channels), ``Flatten``, ``Linear(1024*4*4, 1024)``,
    ``Linear(1024, 1024*4*4)`` (no activation in between), ``Reshape`` and
    one ``_UpScale(1024, 512)``.

    Each decoder: three ``_UpScale`` stages (512 -> 256 -> 128 -> 64), a
    ``Conv2d(64, 3, kernel_size=5, padding=1)`` and a ``Sigmoid``.

    NOTE(review): the first ``Linear`` expects 1024*4*4 features, so the
    convolution stack must output 4x4 maps. That presumably means 64x64
    input images, given four stride-2 stages -- confirm against the
    project's ``Conv2d`` padding behaviour.
    """

    def __init__(self):
        super(Autoencoder, self).__init__()
        encoder_stages = [
            _ConvLayer(3, 128),
            _ConvLayer(128, 256),
            _ConvLayer(256, 512),
            _ConvLayer(512, 1024),
            Flatten(),
            nn.Linear(1024 * 4 * 4, 1024),
            nn.Linear(1024, 1024 * 4 * 4),
            Reshape(),
            _UpScale(1024, 512),
        ]
        self.encoder = nn.Sequential(*encoder_stages)
        # The decoders have the same architecture but independent weights.
        self.decoder_A = self._build_decoder()
        self.decoder_B = self._build_decoder()

    @staticmethod
    def _build_decoder():
        """Create one decoder stack: 512 channels in, 3 sigmoid channels out."""
        return nn.Sequential(
            _UpScale(512, 256),
            _UpScale(256, 128),
            _UpScale(128, 64),
            Conv2d(64, 3, kernel_size=5, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x, select='A'):
        """Encode ``x`` and decode it with the selected decoder.

        Args:
            x: Input batch passed to the encoder.
            select: ``'A'`` uses ``decoder_A``; any other value uses
                ``decoder_B``.

        Returns:
            The chosen decoder's output (sigmoid-activated).
        """
        decoder = self.decoder_A if select == 'A' else self.decoder_B
        return decoder(self.encoder(x))
class PixelShuffler(Layer):
    """Keras layer that moves channel blocks into the spatial axes.

    A 4-D input with ``C`` channels and spatial size ``(H, W)`` becomes an
    output with ``C // (rh * rw)`` channels and spatial size
    ``(H * rh, W * rw)``, where ``(rh, rw)`` is ``size``.

    Args:
        size: Upscale factors ``(rh, rw)``; normalised to a 2-tuple with
            ``conv_utils.normalize_tuple``.
        data_format: ``'channels_first'``, ``'channels_last'`` or ``None``;
            normalised with ``conv_utils.normalize_data_format``.
        **kwargs: Forwarded to the base ``Layer`` constructor.

    NOTE(review): ``Layer``, ``conv_utils`` and ``K`` are not imported in
    the visible part of this file; they must be provided by Keras imports
    before this class definition can run.
    """

    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(PixelShuffler, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')  # e.g. (2, 2)

    @staticmethod
    def _require_rank_4(input_shape):
        """Raise ``ValueError`` unless ``input_shape`` has exactly 4 entries."""
        if len(input_shape) != 4:
            # Build ONE message string. The previous code separated the two
            # parts with a comma, so the exception carried two arguments
            # instead of one readable message.
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape: ' + str(input_shape))

    def call(self, inputs):
        """Apply the shuffle to ``inputs``.

        Args:
            inputs: 4-D backend tensor laid out according to
                ``self.data_format``.

        Returns:
            The rearranged tensor (see the class docstring for its shape).

        Raises:
            ValueError: If ``inputs`` is not rank 4.
        """
        input_shape = K.int_shape(inputs)
        self._require_rank_4(input_shape)
        rh, rw = self.size

        if self.data_format == 'channels_first':
            batch_size, c, h, w = input_shape
            if batch_size is None:
                batch_size = -1  # let the backend infer the batch axis
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            # Channel axis is read as (rh, rw, oc), then interleaved into
            # (oc, h, rh, w, rw) before merging the spatial pairs.
            out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
            out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
            out = K.reshape(out, (batch_size, oc, oh, ow))
            return out

        elif self.data_format == 'channels_last':
            batch_size, h, w, c = input_shape
            if batch_size is None:
                batch_size = -1  # let the backend infer the batch axis
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            # Channel axis is read as (rh, rw, oc), then interleaved into
            # (h, rh, w, rw, oc) before merging the spatial pairs.
            out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
            out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
            out = K.reshape(out, (batch_size, oh, ow, oc))
            return out

    def compute_output_shape(self, input_shape):
        """Return the output shape for a given 4-D ``input_shape``.

        Spatial entries that are ``None`` stay ``None``.

        Raises:
            ValueError: If ``input_shape`` is not rank 4, or if the channel
                count is not divisible by ``size[0] * size[1]``.
        """
        self._require_rank_4(input_shape)

        if self.data_format == 'channels_first':
            height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
            width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
            channels = input_shape[1] // self.size[0] // self.size[1]

            if channels * self.size[0] * self.size[1] != input_shape[1]:
                raise ValueError('channels of input and size are incompatible')

            return (input_shape[0],
                    channels,
                    height,
                    width)

        elif self.data_format == 'channels_last':
            height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
            width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
            channels = input_shape[3] // self.size[0] // self.size[1]

            if channels * self.size[0] * self.size[1] != input_shape[3]:
                raise ValueError('channels of input and size are incompatible')

            return (input_shape[0],
                    height,
                    width,
                    channels)

    def get_config(self):
        """Return the layer config including ``size`` and ``data_format``.

        Without this override the constructor arguments are not part of the
        serialised config, so a saved model cannot restore them.
        """
        config = super(PixelShuffler, self).get_config()
        config.update({'size': self.size, 'data_format': self.data_format})
        return config
# References:
# 1. https://oldpan.me/archives/upsample-convolve-efficient-sub-pixel-convolutional-layers
# 2. https://github.com/Oldpan/Faceswap-Deepfake-Pytorch/blob/master/models.py
# 3. https://arxiv.org/ftp/arxiv/papers/1609/1609.07009.pdf
# 4. https://www.jianshu.com/p/dd7df75c5679