# CIFAR-10 training command for the h12_noup_smallkey model.
# The last flag has no trailing backslash, so the command ends here and
# does not swallow the next line.
python train.py \
  --data_set=cifar \
  --model=h12_noup_smallkey \
  --nr_logistic_mix=10 \
  --nr_filters=256 \
  --batch_size=8 \
  --init_batch_size=8 \
  --dropout_p=0.5 \
  --polyak_decay=0.9995
# ImageNet training command for the h12_noup_smallkey model.
# The last flag has no trailing backslash, so the command is self-contained.
python train.py \
  --data_set=imagenet \
  --model=h12_noup_smallkey \
  --nr_logistic_mix=32 \
  --nr_filters=256 \
  --batch_size=8 \
  --init_batch_size=8 \
  --learning_rate=0.001 \
  --dropout_p=0.0 \
  --polyak_decay=0.9997
这份代码的写法实在不敢苟同:所有的处理逻辑都放在一个 train.py 中,看起来很混乱。它基于 TensorFlow 1.x 系列的 API,可读性并不好,所以这里不投入太多关注,仅阅读模型的生成部分。
# Build the model. These hyper-parameters are forwarded to every call of the
# model spec.
model_opt = dict(
    nr_resnet=args.nr_resnet,
    nr_filters=args.nr_filters,
    nr_logistic_mix=args.nr_logistic_mix,
    resnet_nonlinearity=args.resnet_nonlinearity,
)
# Wrap the selected spec (e.g. --model=h12_noup_smallkey resolves to
# h12_noup_smallkey_spec) in a template, so repeated calls reuse the same
# variables instead of creating new ones.
model = tf.make_template('model', getattr(pxpp_models, args.model + "_spec"))
# Data-dependent parameter initialisation pass (init=True) on the first GPU.
with tf.device('/gpu:0'):
    gen_par = model(x_init, h_init, init=True, dropout_p=args.dropout_p, **model_opt)
# Named model specs: pre-bound variants of _base_noup_smallkey_spec (defined
# elsewhere) that differ only in keyword arguments. The training script looks
# them up by name as `args.model + "_spec"`.
# NOTE(review): attn_rep presumably sets the number of attention repetitions;
# confirm against _base_noup_smallkey_spec.
h12_noup_smallkey_spec = functools.partial(_base_noup_smallkey_spec, attn_rep=12)
# Same as above, additionally passing att_downsample=2.
h12_pool2_smallkey_spec = functools.partial(_base_noup_smallkey_spec, attn_rep=12, att_downsample=2)
# Variant with attn_rep=8.
h8_noup_smallkey_spec = functools.partial(_base_noup_smallkey_spec, attn_rep=8)
def h6_shift_spec(x, h=None, init=False, ema=None, dropout_p=0.5, nr_resnet=5, nr_filters=160, nr_logistic_mix=10, resnet_nonlinearity='concat_elu'):
    """
    We receive a Tensor x of shape (N,H,W,D1) (e.g. (12,32,32,3)) and produce
    a Tensor x_out of shape (N,H,W,D2) (e.g. (12,32,32,100)), where each fiber
    of the x_out tensor describes the predictive distribution for the RGB at
    that position.
    'h' is an optional N x K matrix of values to condition our generative model on

    Structure: one causal-shift 1x1 input layer, then 6 repetitions of
    [nr_resnet causal-shift gated resnets -> one causal attention block],
    then a 1x1 projection to 10 * nr_logistic_mix channels.

    Raises:
        ValueError: if resnet_nonlinearity is not 'concat_elu', 'elu' or 'relu'.
    """
    counters = {}
    with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense, nn.nin, nn.mem_saving_causal_shift_nin], counters=counters, init=init, ema=ema, dropout_p=dropout_p):

        # parse resnet nonlinearity argument
        if resnet_nonlinearity == 'concat_elu':
            resnet_nonlinearity = nn.concat_elu
        elif resnet_nonlinearity == 'elu':
            resnet_nonlinearity = tf.nn.elu
        elif resnet_nonlinearity == 'relu':
            resnet_nonlinearity = tf.nn.relu
        else:
            # Raising a plain string is itself a TypeError in Python 3, so raise
            # a real exception. format() also copes with non-string values.
            raise ValueError('resnet nonlinearity {} is not supported'.format(resnet_nonlinearity))

        with arg_scope([nn.gated_resnet], nonlinearity=resnet_nonlinearity, h=h):

            # // up pass through pixelCNN
            xs = nn.int_shape(x)
            # Positional "background" channels: the normalised row index and the
            # normalised column index, (i - size/2) / size. "+ 0. * x" broadcasts
            # each of them to the full shape of x.
            background = tf.concat(
                [
                    ((tf.range(xs[1], dtype=tf.float32) - xs[1] / 2) / xs[1])[None, :, None, None] + 0. * x,
                    ((tf.range(xs[2], dtype=tf.float32) - xs[2] / 2) / xs[2])[None, None, :, None] + 0. * x,
                ],
                axis=3,
            )
            # add channel of ones to distinguish image from padding later on
            x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], 3)

            ul_list = [nn.causal_shift_nin(x_pad, nr_filters)]  # stream for up and to the left

            for attn_rep in range(6):
                for rep in range(nr_resnet):
                    ul_list.append(nn.gated_resnet(
                        ul_list[-1], conv=nn.mem_saving_causal_shift_nin))

                ul = ul_list[-1]

                # Spatial downsampling factor for this attention block. With
                # hiers == [1] it is always 1, so every `hier != 1` branch below
                # is inactive.
                hiers = [1, ]
                hier = hiers[attn_rep % len(hiers)]
                # Keys and values ("mixin") see the image, the stream and the
                # positions. nr_filters * 2 // 2 channels are split evenly in two.
                raw_content = tf.concat([x, ul, background], axis=3)
                key, mixin = tf.split(nn.nin(nn.gated_resnet(raw_content, conv=nn.nin), nr_filters * 2 // 2), 2, axis=3)
                # Queries see only the stream and the positions.
                raw_q = tf.concat([ul, background], axis=3)
                if hier != 1:
                    raw_q = raw_q[:, ::hier, ::hier, :]
                query = nn.nin(nn.gated_resnet(raw_q, conv=nn.nin), nr_filters // 2)
                if hier != 1:
                    key = tf.nn.pool(key, [hier, hier], "AVG", "SAME", strides=[hier, hier])
                    mixin = tf.nn.pool(mixin, [hier, hier], "AVG", "SAME", strides=[hier, hier])
                mixed = nn.causal_attention(key, mixin, query, causal_unit=1 if hier == 1 else xs[2] // hier)
                if hier != 1:
                    # Upsample the attention result back to full resolution.
                    mixed = tf.depth_to_space(tf.tile(mixed, [1, 1, 1, hier * hier]), hier)

                # Merge the attention output into the stream via a gated resnet.
                ul_list.append(nn.gated_resnet(ul, mixed, conv=nn.nin))

            x_out = nn.nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)

            return x_out
一维输入序列 $x = [x_0, x_1, \dots, x_{n-1}]$

一维卷积核 $h = [h_0, h_1, \dots, h_{m-1}]$

因果卷积的输出 $y$ 定义如下:

$$y_t = \sum_{i=0}^{m-1} h_i \, x_{t-i}, \qquad t \ge i$$

即只累加满足 $t - i \ge 0$ 的项。$y_t$ 只依赖当前时刻及之前的输入,不会用到未来的输入 $x_{t+1}, x_{t+2}, \dots$。
def down_shifted_conv2d(x, num_filters, filter_size=[2, 3], stride=[1, 1], **kwargs):
    """Convolve so that each output row only sees its own input row and the rows above.

    The input is laid out as (N, H, W, C). Before a ``pad='VALID'`` call to
    ``conv2d`` it is zero-padded as follows:

    * H: ``filter_size[0] - 1`` rows at the top and none at the bottom.
    * W: ``int((filter_size[1] - 1) / 2)`` columns on both the left and the right.
    * N and C: not padded.

    Assuming ``conv2d`` is a standard VALID convolution, an odd filter width
    and stride 1 keep H and W unchanged.
    """
    rows_on_top = filter_size[0] - 1
    cols_per_side = int((filter_size[1] - 1) / 2)
    paddings = [
        [0, 0],                          # N: batch, untouched
        [rows_on_top, 0],                # H: extend upwards only
        [cols_per_side, cols_per_side],  # W: extend left and right equally
        [0, 0],                          # C: channels, untouched
    ]
    padded = tf.pad(x, paddings)
    return conv2d(padded, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)
缩放因子 $g$ 和偏置 $b$ 的作用:
利用广播机制,把缩放因子与归一化后的权重逐元素相乘:$\text{Scaled Weight} = g \times \text{Normalized Weight}$。
权重归一化之后,每个输出通道的权重向量范数都被固定为 1;乘上可学习的缩放因子 $g$ 后,权重可以取任意范数,从而适应更多的尺度。
# Weight-normalised convolution layer: effective kernel W = g * V / ||V||.
class WeightNormConv2d(nn.Module):
    """2-D convolution with weight normalisation.

    The raw kernel (direction) V is stored in ``self.conv.weight``, with shape
    (C_out, C_in, kH, kW). The kernel used in ``forward`` is ``g * V / ||V||``,
    where the L2 norm is taken per output channel. The bias ``b`` and the
    optional ``nonlinearity`` are applied afterwards.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: as for ``nn.Conv2d``.
        nonlinearity: optional callable applied to the output (e.g. ``F.relu``).
        init_scale: target standard deviation of the pre-activation output at
            initialisation, measured on standard-normal input.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=0,
                 nonlinearity=None, init_scale=1.):
        super(WeightNormConv2d, self).__init__()
        # Optional nonlinear activation applied at the end of forward().
        self.nonlinearity = nonlinearity
        self.init_scale = init_scale
        # Holds V plus stride/padding. bias=False because self.b is the only
        # bias used in forward(); a conv bias would never receive a gradient.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        # Per-output-channel scale g (broadcasts against the (C_out, C_in, kH, kW)
        # kernel) and bias b, both trainable.
        self.g = nn.Parameter(torch.ones(out_channels, 1, 1, 1))
        self.b = nn.Parameter(torch.zeros(out_channels))
        # Without this call g would stay 1 and b would stay 0.
        self.reset_parameters()

    def _direction(self):
        """Return V scaled to unit L2 norm per output channel."""
        v = self.conv.weight
        return v / torch.sqrt((v ** 2).sum([1, 2, 3], keepdim=True))

    def reset_parameters(self):
        """Re-draw V and choose g and b so the initial output is roughly N(0, init_scale**2).

        A probe pass on random standard-normal input (not real data) measures
        the mean and variance of the *normalised* convolution, i.e. the same
        kernel forward() uses. g and b are then chosen to cancel that mean and
        variance. One scalar mean/variance is computed over all channels, so
        every channel starts with the same g and b.
        """
        # Initialise the direction V from N(0, 0.05**2).
        init.normal_(self.conv.weight, mean=0, std=0.05)
        with torch.no_grad():
            # Several probe samples keep the variance well-defined even when
            # out_channels == 1.
            probe = torch.randn(16, *self.conv.weight.shape[1:])
            x_init = F.conv2d(probe, self._direction())
            m_init, v_init = x_init.mean(), x_init.var()
            # Scale that brings the probe output to standard deviation init_scale.
            scale_init = self.init_scale / torch.sqrt(v_init + 1e-8)
            # In-place fills under no_grad: values change, autograd records nothing.
            self.g.fill_(scale_init.item())
            self.b.fill_((-m_init * scale_init).item())

    def forward(self, x):
        """Apply the weight-normalised convolution, bias ``b`` and optional nonlinearity."""
        W = self.g * self._direction()
        x = F.conv2d(x, W, self.b, self.conv.stride, self.conv.padding)
        if self.nonlinearity is not None:
            x = self.nonlinearity(x)
        return x
# Smoke test: a 3 -> 16 channel, 3x3 WeightNormConv2d with ReLU on a random batch.
conv_layer = WeightNormConv2d(3, 16, [3, 3], stride=1, nonlinearity=F.relu)
x = torch.randn(8, 3, 64, 64) # NCHW layout
out = conv_layer(x)
# The same idea using PyTorch's built-in weight-normalisation utility.
import torch
from torch import nn
from torch.nn.utils import weight_norm

# weight_norm re-parameterises the conv weight in place and returns the same
# module, so it can wrap the constructor call directly.
conv_layer = weight_norm(nn.Conv2d(3, 16, 3, 1))

# Input batch in [batch_size, channels, height, width] layout.
x = torch.randn(8, 3, 64, 64)
out = conv_layer(x)
# 3x3 kernel, stride 1, no padding: 64 -> 62, so this prints torch.Size([8, 16, 62, 62]).
print(out.shape)
def down_right_shifted_conv2d(x, num_filters, filter_size=[2, 2], stride=[1, 1], **kwargs):
    """Convolve so each output sees only its own position, rows above and columns to the left.

    Assumes an (N, H, W, C) input, the same layout as down_shifted_conv2d.
    The input is zero-padded with ``filter_size[0] - 1`` rows at the top and
    ``filter_size[1] - 1`` columns on the left, with nothing at the bottom or
    right. It is then passed to ``conv2d`` with ``pad='VALID'``, which
    (assuming a standard VALID convolution) keeps H and W unchanged at stride 1.
    """
    rows_on_top = filter_size[0] - 1
    cols_on_left = filter_size[1] - 1
    padded = tf.pad(x, [[0, 0], [rows_on_top, 0], [cols_on_left, 0], [0, 0]])
    return conv2d(padded, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)
def right_shift(x, step=1):
    """Shift ``x`` ``step`` positions to the right along axis 2, zero-filling on the left.

    Assumes an (N, H, W, C) layout, so axis 2 is the width. The right-most
    ``step`` columns are dropped, so the output has the same shape as ``x``.
    Fixes the original call to the non-existent ``tf.cobncat``.
    """
    xs = int_shape(x)
    # Zeros for the vacated left-most columns, in x's dtype so the concat
    # also works for non-float32 inputs.
    left_fill = tf.zeros([xs[0], xs[1], step, xs[3]], dtype=x.dtype)
    return tf.concat([left_fill, x[:, :, :xs[2] - step, :]], 2)
# Final model: a PyTorch sketch of PixelSNAIL (causal convolutions + causal attention).
class PixelSNAIL(nn.Module):
    """PyTorch sketch of the PixelSNAIL "noup_smallkey" model.

    Input is an (N, 3, H, W) image batch. Output is
    (N, 10 * nr_logistic_mix, H, W) mixture parameters.

    NOTE(review): GatedResNet and CausalAttention are defined elsewhere. This
    class assumes:
    * ``GatedResNet(nr_filters)`` maps (N, nr_filters, H, W) to the same shape
      and is causal.
    * ``CausalAttention()(key, mixin, query)`` returns an
      (N, mixin_channels, H, W) tensor and masks strictly, so a query never
      attends to its own position (the keys include the raw pixel x).
    Confirm both.

    Args:
        nr_resnet: gated residual blocks run before each attention block.
        nr_filters: channel width of the main stream.
        attn_rep: number of [resnets -> attention] repetitions.
        nr_logistic_mix: number of logistic mixture components.
        att_downsample: stored for parity with the TF spec; not used here.
        q_size: channel size of attention keys and queries (previously a
            hard-coded 16).
    """

    def __init__(self, nr_resnet=5, nr_filters=32, attn_rep=12, nr_logistic_mix=10, att_downsample=1, q_size=16):
        # nn.Module must be initialised before any submodule is assigned.
        super(PixelSNAIL, self).__init__()
        self.nr_resnet = nr_resnet
        self.nr_filters = nr_filters
        self.attn_rep = attn_rep
        self.nr_logistic_mix = nr_logistic_mix
        self.att_downsample = att_downsample
        self.q_size = q_size
        # Causal input convolutions. forward() pads the input and then shifts
        # the outputs down/right so no position sees its own pixel.
        self.down_shifted_conv2d = weight_norm(nn.Conv2d(3, self.nr_filters, kernel_size=(1, 3)))
        self.down_right_shifted_conv2d = weight_norm(nn.Conv2d(3, self.nr_filters, kernel_size=(2, 1)))
        # Gated residual blocks. These are reused by every attention repetition.
        # NOTE(review): the TF reference creates fresh resnets per repetition.
        self.gated_resnets = nn.ModuleList([GatedResNet(self.nr_filters) for _ in range(self.nr_resnet)])
        # 1x1 convolutions ("network in network") that act on the channel axis
        # of NCHW tensors. nn.Linear would act on the last (width) axis instead.
        # nin1 sees cat([x, ul]) (3 + nr_filters channels) and emits key + mixin.
        self.nin1 = nn.Conv2d(3 + self.nr_filters, self.q_size + self.nr_filters // 2, kernel_size=1)
        self.nin2 = nn.Conv2d(self.nr_filters, self.q_size, kernel_size=1)
        # One causal attention module per repetition.
        self.causal_attentions = nn.ModuleList([CausalAttention() for _ in range(self.attn_rep)])
        # Projects each attention output (mixin width) back to nr_filters so it
        # can be added to the stream.
        self.attn_projs = nn.ModuleList(
            [nn.Conv2d(self.nr_filters // 2, self.nr_filters, kernel_size=1) for _ in range(self.attn_rep)])
        # Output head: 10 parameters per mixture component.
        self.final_conv = nn.Conv2d(self.nr_filters, 10 * self.nr_logistic_mix, kernel_size=1)

    @staticmethod
    def _down_shift(t):
        """Shift an NCHW tensor down by one row, zero-filling the top row."""
        return F.pad(t, (0, 0, 1, 0))[:, :, :-1, :]

    @staticmethod
    def _right_shift(t):
        """Shift an NCHW tensor right by one column, zero-filling the left column."""
        return F.pad(t, (1, 0, 0, 0))[:, :, :, :-1]

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, 10 * nr_logistic_mix, H, W) mixture parameters."""
        # F.pad order is (left, right, top, bottom). The (1,3) kernel needs one
        # column on each side; the (2,1) kernel needs one row on top.
        down_padded = F.pad(x, (1, 1, 0, 0))
        right_padded = F.pad(x, (0, 0, 1, 0))
        # After the shifts, the first term only sees rows strictly above and the
        # second only sees columns strictly to the left, so the current pixel
        # never leaks into its own prediction.
        ul = (self._down_shift(self.down_shifted_conv2d(down_padded))
              + self._right_shift(self.down_right_shifted_conv2d(right_padded)))
        for causal_attention, attn_proj in zip(self.causal_attentions, self.attn_projs):
            for gated_resnet in self.gated_resnets:
                ul = gated_resnet(ul)
            # Keys/values see the image and the stream; queries see only the
            # stream. Background channels are assumed to be already in x.
            raw_content = torch.cat([x, ul], dim=1)
            key, mixin = self.nin1(raw_content).split([self.q_size, self.nr_filters // 2], dim=1)
            query = self.nin2(ul)
            mixed = causal_attention(key, mixin, query)
            # Feed the attention result back into the stream. Previously it
            # was computed and then discarded.
            ul = ul + attn_proj(mixed)
        return self.final_conv(F.elu(ul))
# Smoke run: push a random batch of 64 three-channel 32x32 images through the model.
x = torch.randn(64,3,32,32)
model = PixelSNAIL()
x_out = model(x)