python train.py \
--data_set=cifar \
--model=h12_noup_smallkey \
--nr_logistic_mix=10 \
--nr_filters=256 \
--batch_size=8 \
--init_batch_size=8 \
--dropout_p=0.5 \
--polyak_decay=0.9995 \
--save_interval=10
python train.py \
--data_set=imagenet \
--model=h12_noup_smallkey \
--nr_logistic_mix=32 \
--nr_filters=256 \
--batch_size=8 \
--init_batch_size=8 \
--learning_rate=0.001 \
--dropout_p=0.0 \
--polyak_decay=0.9997 \
--save_interval=1
I honestly can't endorse how this code is written: all of the processing logic is crammed into a single train.py, which makes it hard to follow. It is TensorFlow 1.x code and not particularly readable, so I won't spend much attention on it here and will only read the part that builds the model.
The code that builds and calls the model:
# build the model
model_opt = {'nr_resnet': args.nr_resnet, 'nr_filters': args.nr_filters,
             'nr_logistic_mix': args.nr_logistic_mix, 'resnet_nonlinearity': args.resnet_nonlinearity}
# create a model template; the template can be called repeatedly without re-creating variables
model = tf.make_template('model', getattr(pxpp_models, args.model + "_spec"))
# data-dependent parameter initialization
with tf.device('/gpu:0'):
    gen_par = model(x_init, h_init, init=True,
                    dropout_p=args.dropout_p, **model_opt)
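Since tf.make_template is what lets the same spec be called again later without duplicating variables, here is a minimal standalone illustration (my own, not taken from train.py) of what the template does:

import tensorflow as tf

def linear_spec(x):
    # creates (or reuses) a variable scoped under the template's name
    w = tf.get_variable('w', [int(x.shape[-1]), 4])
    return tf.matmul(x, w)

linear = tf.make_template('linear', linear_spec)
y_init = linear(tf.zeros([2, 3]))   # first call: creates the variable 'linear/w'
y_train = linear(tf.zeros([8, 3]))  # later calls: reuse 'linear/w', no new variables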
h12_noup_smallkey_spec = functools.partial(_base_noup_smallkey_spec, attn_rep=12)
h12_pool2_smallkey_spec = functools.partial(_base_noup_smallkey_spec, attn_rep=12, att_downsample=2)
h8_noup_smallkey_spec = functools.partial(_base_noup_smallkey_spec, attn_rep=8)
Following the paper, let's look at the overall structure of the model. It is built mainly from two repeated modules: a stack of gated residual blocks followed by a causal attention block. (The paper's architecture figure shows the detailed execution flow.) The original program code is given below.
def h6_shift_spec(x, h=None, init=False, ema=None, dropout_p=0.5, nr_resnet=5, nr_filters=160, nr_logistic_mix=10, resnet_nonlinearity='concat_elu'):
    """
    We receive a Tensor x of shape (N,H,W,D1) (e.g. (12,32,32,3)) and produce
    a Tensor x_out of shape (N,H,W,D2) (e.g. (12,32,32,100)), where each fiber
    of the x_out tensor describes the predictive distribution for the RGB at
    that position.
    'h' is an optional N x K matrix of values to condition our generative model on
    """
    counters = {}
    with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense, nn.nin, nn.mem_saving_causal_shift_nin], counters=counters, init=init, ema=ema, dropout_p=dropout_p):

        # parse resnet nonlinearity argument
        if resnet_nonlinearity == 'concat_elu':
            resnet_nonlinearity = nn.concat_elu
        elif resnet_nonlinearity == 'elu':
            resnet_nonlinearity = tf.nn.elu
        elif resnet_nonlinearity == 'relu':
            resnet_nonlinearity = tf.nn.relu
        else:
            raise ValueError('resnet nonlinearity ' +
                             resnet_nonlinearity + ' is not supported')

        with arg_scope([nn.gated_resnet], nonlinearity=resnet_nonlinearity, h=h):
            # // up pass through pixelCNN
            xs = nn.int_shape(x)
            # per-position row / column coordinate channels, broadcast to the shape of x
            background = tf.concat(
                [
                    ((tf.range(xs[1], dtype=tf.float32) - xs[1] / 2) / xs[1])[None, :, None, None] + 0. * x,
                    ((tf.range(xs[2], dtype=tf.float32) - xs[2] / 2) / xs[2])[None, None, :, None] + 0. * x,
                ],
                axis=3
            )
            # add channel of ones to distinguish image from padding later on
            x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

            ul_list = [nn.causal_shift_nin(x_pad, nr_filters)]  # stream for up and to the left

            for attn_rep in range(6):
                # a block of gated residual layers with causal-shift convolutions ...
                for rep in range(nr_resnet):
                    ul_list.append(nn.gated_resnet(
                        ul_list[-1], conv=nn.mem_saving_causal_shift_nin))

                ul = ul_list[-1]

                # ... followed by a causal attention layer
                hiers = [1, ]
                hier = hiers[attn_rep % len(hiers)]
                raw_content = tf.concat([x, ul, background], axis=3)
                key, mixin = tf.split(nn.nin(nn.gated_resnet(raw_content, conv=nn.nin), nr_filters * 2 // 2), 2, axis=3)
                raw_q = tf.concat([ul, background], axis=3)
                if hier != 1:
                    raw_q = raw_q[:, ::hier, ::hier, :]
                query = nn.nin(nn.gated_resnet(raw_q, conv=nn.nin), nr_filters // 2)
                if hier != 1:
                    key = tf.nn.pool(key, [hier, hier], "AVG", "SAME", strides=[hier, hier])
                    mixin = tf.nn.pool(mixin, [hier, hier], "AVG", "SAME", strides=[hier, hier])
                mixed = nn.causal_attention(key, mixin, query, causal_unit=1 if hier == 1 else xs[2] // hier)
                if hier != 1:
                    mixed = tf.depth_to_space(tf.tile(mixed, [1, 1, 1, hier * hier]), hier)

                ul_list.append(nn.gated_resnet(ul, mixed, conv=nn.nin))

            x_out = nn.nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)

            return x_out
Function parameters: x is the input image batch (NHWC); h is an optional N x K conditioning matrix; init=True runs the data-dependent initialization pass; ema holds the exponential-moving-average (Polyak-averaged) copy of the parameters; dropout_p is the dropout rate; nr_resnet is the number of gated residual blocks per attention block; nr_filters is the channel width; nr_logistic_mix is the number of components in the discretized logistic mixture; resnet_nonlinearity selects the nonlinearity used inside the residual blocks.
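One detail worth calling out in the spec above is the background tensor: normalized row and column coordinate maps concatenated onto the features, so the attention layers have positional information (in the TF code each map is broadcast against x via the + 0. * x trick). A PyTorch sketch of the same idea (my own; it keeps a single channel per coordinate, NCHW layout):

import torch

def make_background(x):
    # x: (N, C, H, W); returns (N, 2, H, W) of normalized row / column coordinates
    n, _, h, w = x.shape
    rows = ((torch.arange(h, dtype=torch.float32) - h / 2) / h).view(1, 1, h, 1).expand(n, 1, h, w)
    cols = ((torch.arange(w, dtype=torch.float32) - w / 2) / w).view(1, 1, 1, w).expand(n, 1, h, w)
    return torch.cat([rows, cols], dim=1)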
Below, following the code and the flow chart, we go through the concrete implementations of the causal convolution, the gated residual network, and the causal attention module.
The causal convolution here is defined differently from PixelCNN. Instead of masking the kernel, the author defines four differently shifted convolution variants (such as down_shifted_conv2d and down_right_shifted_conv2d below), which makes the construction fairly involved. In a 2D causal convolution, every output pixel may depend only on the pixels above it and to its left. The author's code guarantees this with two mechanisms: asymmetric padding before a VALID convolution (padding only the top and/or left of the input), and explicitly shifting feature maps down or right before they are combined. Taken together, these ensure that each element only receives information from the region above and to its left.
Consider a 1D input sequence $x = [x_0, x_1, \dots, x_{n-1}]$ and a 1D convolution kernel $h = [h_0, h_1, \dots, h_{m-1}]$. The causal convolution output $y$ is defined as
$$ y_t = \sum_{i=0}^{m-1} h_i \, x_{t-i}, \qquad t \ge i. $$
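As a quick worked example (my own, not from the repo), padding $m-1$ zeros on the left of the sequence and then running an ordinary convolution realizes exactly this causal behaviour: every output depends only on the current and earlier inputs. (PyTorch's conv1d computes a cross-correlation, so the stored kernel is the reverse of $h$ above, but causality is unaffected.)

import torch
import torch.nn.functional as F

x = torch.tensor([[[1., 2., 3., 4.]]])   # input sequence, shape (N=1, C=1, n=4)
w = torch.tensor([[[0.5, 1.0]]])         # kernel of length m=2, shape (1, 1, 2)
x_pad = F.pad(x, (w.shape[-1] - 1, 0))   # pad m-1 zeros on the left only
y = F.conv1d(x_pad, w)                   # y_t = 0.5*x_{t-1} + 1.0*x_t, with x_{-1} = 0
print(y)                                 # tensor([[[1.0000, 2.5000, 4.0000, 5.5000]]])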
The concrete 2D version from the original code is as follows.
def down_shifted_conv2d(x, num_filters, filter_size=[2, 3], stride=[1, 1], **kwargs):
    # Pad the input, which has four dimensions: N, H, W, C.
    # Dim 0 (batch) is not padded.
    # Dim 1 (H) gets filter_size[0] - 1 rows of padding at the top and none at the bottom,
    #   i.e. only the top is extended.
    # Dim 2 (W) gets int((filter_size[1] - 1) / 2) columns of padding on each side,
    #   i.e. only the left and right are extended.
    # Dim 3 (channels) is not padded.
    x = tf.pad(x, [[0, 0], [filter_size[0] - 1, 0],
                   [int((filter_size[1] - 1) / 2), int((filter_size[1] - 1) / 2)], [0, 0]])
    # VALID convolution on the padded input
    return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)
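For comparison, here is a minimal PyTorch sketch of the same down-shifted convolution (my own translation; NCHW layout and the functional API rather than the repo's TF wrappers):

import torch
import torch.nn.functional as F

def down_shifted_conv2d_torch(x, weight, filter_size=(2, 3)):
    # x: (N, C, H, W); weight: (num_filters, C, filter_size[0], filter_size[1])
    pad_w = (filter_size[1] - 1) // 2
    # F.pad takes (left, right, top, bottom) for 4D input:
    # pad the left/right columns and the top rows, never the bottom
    x = F.pad(x, (pad_w, pad_w, filter_size[0] - 1, 0))
    return F.conv2d(x, weight)  # VALID convolution; spatial size is preserved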
Here the author defines his own 2D convolution layer with weight normalization, because a plain convolution layer does not include weight normalization.
Advantages of the weight-normalized 2D convolution layer: decoupling the direction of each filter from its norm tends to make optimization better conditioned and speeds up convergence.
Advantages of a plain 2D convolution: it is simpler and has no extra normalization cost in the forward pass.
The role of the scale factor $g$ and the bias weight $b$: using broadcasting, the scale factor multiplies every element of the normalized weight, $\text{Scaled Weight} = g \times \text{Normalized Weight}$. After weight normalization every filter has unit norm, so the scale factor is what lets the effective weight take on any norm. Note that the normalization is carried out per output channel.
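Putting these pieces together, the standard weight-normalization reparameterization, written per output channel $o$, is
$$ W_o = g_o \, \frac{V_o}{\lVert V_o \rVert_2}, \qquad y = W * x + b, $$
where $V_o$ is the raw weight of output channel $o$ and the norm is taken over its input-channel and spatial dimensions; this is exactly what the custom layer below computes in its forward pass.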
Here I give two implementations: a weight-normalized convolution layer written from scratch in PyTorch, and an equivalent layer built with PyTorch's built-in weight-normalization wrapper.
# Weight-normalized convolution layer
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import init


class WeightNormConv2d(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=0,
                 nonlinearity=None, init_scale=1.):
        super(WeightNormConv2d, self).__init__()
        # optional nonlinearity applied after the convolution
        self.nonlinearity = nonlinearity
        self.init_scale = init_scale
        # underlying convolution layer (its weight provides the direction v)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # scale factor g and bias b, declared as trainable parameters;
        # the conv weight has shape (C_out, C_in, H, W), so g broadcasts over C_out
        self.g = nn.Parameter(torch.ones(out_channels, 1, 1, 1))
        self.b = nn.Parameter(torch.zeros(out_channels))
        # data-dependent parameter initialization
        self.reset_parameters()

    def reset_parameters(self):
        # initialize the weight from a normal distribution
        init.normal_(self.conv.weight, mean=0, std=0.05)
        # initialize the bias to zero
        init.zeros_(self.conv.bias)
        # run one forward pass on random input to compute the initial g and b
        with torch.no_grad():
            x_init = F.conv2d(torch.randn(1, *self.conv.weight.shape[1:]),
                              self.conv.weight)
            m_init, v_init = x_init.mean(), x_init.var()
            # scale factor
            scale_init = self.init_scale / torch.sqrt(v_init + 1e-8)
            self.g.data.fill_(scale_init)
            # bias; .data.fill_ sets every element without tracking gradients
            self.b.data.fill_(-m_init * scale_init)

    def forward(self, x):
        # apply weight normalization: W = g * v / ||v||, norm taken per output channel
        W = self.conv.weight * (self.g / torch.sqrt((self.conv.weight ** 2).sum([1, 2, 3], keepdim=True)))
        # convolution with the normalized weight
        x = F.conv2d(x, W, self.b, self.conv.stride, self.conv.padding)
        # optional nonlinearity
        if self.nonlinearity is not None:
            x = self.nonlinearity(x)
        return x


# quick test
conv_layer = WeightNormConv2d(3, 16, [3, 3], stride=1, nonlinearity=F.relu)
x = torch.randn(8, 3, 64, 64)  # NCHW layout
out = conv_layer(x)
print(out.shape)
# Using PyTorch's built-in weight normalization
import torch
from torch import nn
from torch.nn.utils import weight_norm

# create a standard Conv2d layer
conv_layer = nn.Conv2d(3, 16, 3, 1)
# apply weight normalization to it
conv_layer = weight_norm(conv_layer)

# test the layer
x = torch.randn(8, 3, 64, 64)  # input tensor of shape [batch_size, channels, height, width]
out = conv_layer(x)
print(out.shape)  # the output shape should be [8, 16, 62, 62]
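After wrapping, the layer stores the two factors as separate parameters (PyTorch's naming; with the default dim=0 the normalization is per output channel, just like the custom layer above):

print(conv_layer.weight_g.shape)  # scale g, torch.Size([16, 1, 1, 1])
print(conv_layer.weight_v.shape)  # direction v, torch.Size([16, 3, 3, 3])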
@add_arg_scope
def down_right_shifted_conv2d(x, num_filters, filter_size=[2, 2], stride=[1, 1], **kwargs):
    # pad only the top and the left, so each output position sees nothing below or to its right
    x = tf.pad(x, [[0, 0], [filter_size[0] - 1, 0],
                   [filter_size[1] - 1, 0], [0, 0]])
    return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)
def right_shift(x, step=1):
    xs = int_shape(x)
    # prepend `step` zero columns and drop the last `step` columns (NHWC layout)
    return tf.concat([tf.zeros([xs[0], xs[1], step, xs[3]]), x[:, :, :xs[2] - step, :]], 2)
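A minimal PyTorch equivalent of this right_shift (my own sketch; NCHW layout instead of the NHWC layout used above):

import torch.nn.functional as F

def right_shift_torch(x, step=1):
    # x: (N, C, H, W); drop the last `step` columns and prepend `step` zero columns
    return F.pad(x[:, :, :, :-step], (step, 0, 0, 0))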
# Putting everything together into the final model
class PixelSNAIL(nn.Module):
    '''
    PixelSNAIL model
    '''
    def __init__(self, nr_resnet=5, nr_filters=32, attn_rep=12, nr_logistic_mix=10, att_downsample=1):
        super(PixelSNAIL, self).__init__()
        # hyper-parameters
        self.nr_resnet = nr_resnet
        self.nr_filters = nr_filters
        self.attn_rep = attn_rep
        self.nr_logistic_mix = nr_logistic_mix
        self.att_downsample = att_downsample
        # sub-modules
        # causal convolutions
        self.down_shifted_conv2d = weight_norm(nn.Conv2d(3, self.nr_filters, kernel_size=(1, 3)))
        self.down_right_shifted_conv2d = weight_norm(nn.Conv2d(3, self.nr_filters, kernel_size=(2, 1)))
        # ModuleList holding the gated residual blocks
        self.gated_resnets = nn.ModuleList([GatedResNet(self.nr_filters) for _ in range(self.nr_resnet)])
        # linear ("nin") layers
        self.nin1 = nn.Linear(self.nr_filters, self.nr_filters // 2 + 16)  # assuming q_size = 16
        self.nin2 = nn.Linear(self.nr_filters, 16)                         # assuming q_size = 16
        # causal attention modules
        self.causal_attentions = nn.ModuleList([CausalAttention() for _ in range(self.attn_rep)])
        # final convolution producing the mixture-of-logistics parameters
        self.final_conv = nn.Conv2d(self.nr_filters, 10 * self.nr_logistic_mix, kernel_size=1)

    def forward(self, x):
        ul_list = []
        # custom down-shift and right-shift operations, padding in (left, right, top, bottom) order
        down_shifted = F.pad(x, (1, 1, 0, 0))
        right_shifted = F.pad(x, (0, 0, 1, 0))
        # causal convolutions
        ds_conv = self.down_shifted_conv2d(down_shifted)
        drs_conv = self.down_right_shifted_conv2d(right_shifted)
        ul = ds_conv + drs_conv
        ul_list.append(ul)
        for causal_attention in self.causal_attentions:
            # a stack of gated residual blocks ...
            for gated_resnet in self.gated_resnets:
                ul = gated_resnet(ul)
                ul_list.append(ul)
            # ... followed by a causal attention block
            last_ul = ul_list[-1]
            # raw content from which keys and values are derived
            raw_content = torch.cat([x, last_ul], dim=1)  # assuming the background channels are already part of x
            # produce key and mixin (value)
            raw = self.nin1(raw_content)
            key, mixin = raw.split(18, dim=1)  # assuming q_size = 16
            # produce the query
            raw_q = last_ul
            query = self.nin2(raw_q)
            # causal attention
            mixed = causal_attention(key, mixin, query)
            ul_list.append(mixed)
        x_out = F.elu(ul_list[-1])
        x_out = self.final_conv(x_out)
        return x_out


x = torch.randn(64, 3, 32, 32)
model = PixelSNAIL()
x_out = model(x)
print(x_out.shape)
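For reference, the $10 \times$ nr_logistic_mix output channels per pixel follow the PixelCNN++ discretized-logistic-mixture parameterization: each of the nr_logistic_mix components contributes 1 mixture logit, 3 means, 3 log-scales and 3 RGB coupling coefficients, i.e. $1 + 3 + 3 + 3 = 10$ values, so with the default nr_logistic_mix=10 the model outputs $100$ channels per pixel, matching the (12, 32, 32, 100) example in the docstring of h6_shift_spec.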