Introduction
Basic steps
The flowchart is shown below:
The procedure is essentially the same as described above, except that two extra inputs are added: the auxiliary input a and the conditional matrix h.
Note that the 2D convolution used here is just an ordinary 2D convolution with simple weight normalization applied.
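As a rough illustration of that building block, here is a minimal PyTorch sketch of a plain 3x3 convolution wrapped with weight normalization; the variable name wn_conv and the channel counts are illustrative assumptions, not part of the original code:

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

# a plain 3x3 convolution with weight normalization applied to its weight tensor
wn_conv = weight_norm(nn.Conv2d(in_channels=32, out_channels=32,
                                kernel_size=3, padding=1))

x = torch.randn(16, 32, 32, 32)   # [batch, channels, height, width]
y = wn_conv(x)                    # padding=1 keeps the spatial size unchanged
print(y.shape)                    # torch.Size([16, 32, 32, 32])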
Auxiliary input a
Conditional matrix h
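In the TensorFlow reference code below, the auxiliary input a is passed through a nin ("network in network") layer, which is simply a 1x1 projection to num_filters channels, while the conditional matrix h is projected with a learned matrix hw. A rough PyTorch sketch of the nin projection follows; the class name NIN and the channel counts are my own illustrative assumptions:

import torch
import torch.nn as nn

# 1x1 "network in network" projection: maps the auxiliary input `a`
# to `num_filters` channels so it can be added to c1
class NIN(nn.Module):
    def __init__(self, in_channels, num_filters):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, num_filters, kernel_size=1)

    def forward(self, a):
        return self.proj(a)

a = torch.randn(16, 64, 32, 32)   # auxiliary input with 64 channels
nin = NIN(64, 32)
print(nin(a).shape)               # torch.Size([16, 32, 32, 32])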
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):
    xs = int_shape(x)
    num_filters = xs[-1]
    # first convolution
    c1 = conv(nonlinearity(x), num_filters)
    # add a short-cut connection if the auxiliary input 'a' is given
    if a is not None:
        c1 += nin(nonlinearity(a), num_filters)
    # apply the nonlinearity
    c1 = nonlinearity(c1)
    if dropout_p > 0:
        c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
    # second convolution, doubling the number of channels for the gate
    c2 = conv(c1, num_filters * 2, init_scale=0.1)
    # add a projection of the h vector if it is given: conditional generation
    if h is not None:
        with tf.variable_scope(get_name('conditional_weights', counters)):
            hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,
                                   initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
        if init:
            hw = hw.initialized_value()
        c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])
    # split c2 into two halves 'a' and 'b' along the channel (last) axis
    a, b = tf.split(c2, 2, 3)
    c3 = a * tf.nn.sigmoid(b)
    return x + c3
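The last three lines split c2 into two halves a and b along the channel axis and compute a * sigmoid(b), which is exactly a gated linear unit (GLU). In PyTorch the same gate can be written with the built-in F.glu; a small sketch, assuming NCHW tensors with channels on dim 1:

import torch
import torch.nn.functional as F

# c2 carries 2 * num_filters channels; one half modulates the other through a sigmoid gate
c2 = torch.randn(16, 64, 32, 32)
a, b = c2.chunk(2, dim=1)               # split along the channel axis
manual = a * torch.sigmoid(b)           # explicit gated activation
builtin = F.glu(c2, dim=1)              # built-in gated linear unit, same computation
print(torch.allclose(manual, builtin))  # True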
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

class GatedResNet(nn.Module):
    def __init__(self, num_filters, nonlinearity=F.elu, dropout_p=0.0):
        super(GatedResNet, self).__init__()
        self.num_filters = num_filters
        self.nonlinearity = nonlinearity
        self.dropout_p = dropout_p
        # first convolution
        self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        # self.conv1 = weight_norm(self.conv1)
        # second convolution; the output has 2 * num_filters channels for the gating mechanism
        self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)
        # self.conv2 = weight_norm(self.conv2)
        # conditional weights for h, created lazily in forward once the size of h is known
        # (run one forward pass before building the optimizer so this parameter is registered)
        self.hw = None

    def forward(self, x, a=None, h=None):
        c1 = self.conv1(self.nonlinearity(x))
        # add the auxiliary input 'a' if it is given
        if a is not None:
            c1 += a  # or project a with a 1x1 (NIN) convolution if its channel count differs
        c1 = self.nonlinearity(c1)
        if self.dropout_p > 0:
            c1 = F.dropout(c1, p=self.dropout_p, training=self.training)
        c2 = self.conv2(c1)
        # add the projection of the conditional input h if it is given
        if h is not None:
            if self.hw is None:
                self.hw = nn.Parameter(torch.randn(h.size(1), 2 * self.num_filters, device=h.device) * 0.05)
            c2 += (h @ self.hw).view(h.size(0), 2 * self.num_filters, 1, 1)
        # split the channels into two halves 'a' and 'b' for the gate
        a, b = c2.chunk(2, dim=1)
        c3 = a * torch.sigmoid(b)
        return x + c3
# Test
x = torch.randn(16, 32, 32, 32)  # [batch size, channels, height, width]
a = torch.randn(16, 32, 32, 32)  # auxiliary input with the same shape as x
h = torch.randn(16, 64)          # optional conditioning variable
model = GatedResNet(32)
out = model(x, a, h)
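Because the block returns x + c3, the output keeps the shape of the input, so several blocks can be chained. A quick sanity check, reusing the GatedResNet class above with a and h omitted:

print(out.shape)  # torch.Size([16, 32, 32, 32])

# the residual connection preserves the shape, so blocks can be stacked directly
stacked = x
for _ in range(3):
    stacked = GatedResNet(32)(stacked)
print(stacked.shape)  # torch.Size([16, 32, 32, 32])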