在上一篇就是介绍了矢量量化变分模型的具体实现,就是一个编码器和解码器,只能生成和原来图片一样的图片,没啥意义。这里需要生成一个新的码字序列,解码器能够接受这部分数据,然后解码成对应的新的图片。
作者使用PixelCNN去训练这些码本
PixelCNN是一个自回归模型,根据已有的序列生成下一个位置的值。在这个任务里,就是生辰新的码字序列,然后使用训练好的解码器生成对应的新的图片。
注意:
自回归创建一个显式密度模型,该模型学习训练数据的最大似然。但是处理多个维度\特征的数据时,需要完整如下的步骤
图像中某一个像素具有特定强度值的概率由先前像素的值确定
图像的概率(所有像素的联合分布)是其所有像素的概率的组合
因此,自回归模型使用链式法则,将数据样本x的可能性分解成一维分布的乘积,将联合建模问题变成了序列问题,学习了在给定所有先前像素的情况下,预测下一个像素的过程。
经历过残差序列模块的处理之后,经过两个带有relu的1*1的卷积层,然后经过softmax进行输出,预测像素所有可能的预测值。模型的输出是具有与输入图像大小相同的格式乘以可能值的数量。
# 掩码卷积
class PixelConvLayer(Layer):
"""
掩码卷积层,分别是两种类型,A类和B类
"""
def __init__(self,mask_type,**kwargs):
super().__init__()
self.mask_type = mask_type
self.conv = Conv2D(**kwargs)
self.mask = None
def build(self,input_shape):
# 创建二维卷积层,并初始化对应的卷积核权重参数
self.conv.build(input_shape)
kernel_shape = self.conv.kernel.get_shape()
# 生成同等大小的掩码层,用来抑制没有预测的卷积层
self.mask = np.zeros(shape = kernel_shape)
self.mask[:kernel_shape[0] // 2,...] = 1.0
self.mask[kernel_shape[0] // 2,:kernel_shape[1] // 2,...] = 1.0
if self.mask_type == "B":
self.mask[kernel_shape[0] // 2,kernel_shape[1]//2,...] = 1.0
def call(self,inputs):
# 根据掩码层,来修改更正卷积层的结果
self.conv.kernel.assign(self.conv.kernel * self.mask)
return self.conv(inputs)
# 掩码卷积
class PixelConvLayer(Layer):
"""
掩码卷积层,分别是两种类型,A类和B类
"""
def __init__(self,mask_type,**kwargs):
super().__init__()
self.mask_type = mask_type
self.conv = Conv2D(**kwargs)
self.mask = None
def build(self,input_shape):
# 创建二维卷积层
self.conv.build(input_shape)
kernel_shape = self.conv.kernel.get_shape()
# 将卷积层改成,掩码卷积,这里是生成掩码层
self.mask = np.zeros(shape = kernel_shape)
self.mask[:kernel_shape[0] // 2,...] = 1.0
self.mask[kernel_shape[0] // 2,:kernel_shape[1] // 2,...] = 1.0
if self.mask_type == "B":
self.mask[kernel_shape[0] // 2,kernel_shape[1]//2,...] = 1.0
def call(self,inputs):
self.conv.kernel.assign(self.conv.kernel * self.mask)
return self.conv(inputs)
class ResidualBlock(Layer):
"""
基于掩码卷积层的残差模块
"""
def __init__(self,filters,**kwargs):
super().__init__()
# 第一个卷积模块
self.conv1 = Conv2D(
filters = filters,kernel_size = 1,activation="relu"
)
# 中间的B类掩码卷积,进行概率估计
self.pixel_conv = PixelConvLayer(
mask_type="B",
kernel_size = 3,
activation = "relu",
padding = "same",
filters = filters // 2
)
# 最后一个卷积模块
self.conv2 = Conv2D(
filters = filters,kernel_size = 1,activation="relu"
)
def call(self,inputs):
# 前向传播中,数据传播的方式
x = self.conv1(inputs)
x = self.pixel_conv(x)
x = self.conv2(x)
# 残差模块中的直连
res = add([inputs,x])
return res
class PixelCNN:
"""
PixelCNN:像素卷积模型
功能:用于生成新的码本序列
"""
def __init__(self,
input_shape,
num_residual_blocks,
num_pixelcnn_layers,
num_embeddings
):
self.input_shape = input_shape
self.num_residual_blocks = num_residual_blocks
self.num_pixelcnn_layers = num_pixelcnn_layers
self.num_embeddings = num_embeddings
# 定义不同的层
self.model = None
# 构建模型
self._build()
def _build(self):
"""
网络结构为:Aconv3*3,residual815,Rconv,Rconv,Softmax
:return:
"""
PixelCNN_Input = Input(shape = self.input_shape,name = "PixelCNN input")
x = PixelConvLayer(mask_type="A",
filters = 128,
kernel_size = 7,
strides = 1,
activation = "relu",
padding = "same"
)(PixelCNN_Input)
# 添加卷积模块
for i in range(self.num_residual_blocks):
x = ResidualBlock(filters = 128)(x)
# 添加后续的像素卷积层
for i in range(self.num_pixelcnn_layers):
x = PixelConvLayer(
mask_type="B",
filters = 128,
kernel_size = 1,
strides = 1,
activation = "relu",
padding = "valid"
)(x)
out = Conv2D(
filters=1,
kernel_size=1,
strides=1,
activation="sigmoid",
padding="valid"
)(x)
self.model = Model(PixelCNN_Input,out)
def summary(self):
self.model.summary()
def compile(self, learning_rate=0.0001):
""" 指定损失函数和优化器,并对模型进行优化 """
optimizer = Adam(learning_rate=learning_rate)
self.model.compile(
optimizer=optimizer,
loss="binary_crossentropy",
)
# 3.3 增加训练函数
def train(self, x_train, batch_size, num_epochs):
self.model.fit(
x_train,
x_train,
batch_size=batch_size,
epochs=num_epochs,
shuffle=True
)
问题:我们知道,PixelCNN的原理是将联合分布拆解成多个条件概率的分布,也就是说预测每一个像素需要考虑之前每一个值,但是掩码卷积仅仅是使用了卷积核之内的数据,并没有考虑到每一个数据。
PixelCNN使用了掩码卷积(masked convolution)来限制卷积核只能看到先前生成的像素值,从而确保每个像素的生成仅依赖于其前面的像素。确实并不能考虑到之前所有的像素点,只能考虑到当前像素点的感受野中的元素,还是条件概率,但是并不是之前所有的像素点。
具体可以看这个图片,五角星的那个像素点,确实只能考虑圆圈圈出来的点,计算他们的条件概率,并不能考虑到没有圈处来的其他的点,但这是只有一层卷积,如果有多层卷积,感受野会更大,考虑的也会更加全面。