本文内容参考了Conditional VAE (Variational Auto Encoder) 条件付きVAE
是对官方网页Train Variational Autoencoder (VAE) to Generate Images的改进,该网页的翻译可见MATLAB实现自编码器(四)——变分自编码器实现图像生成Train Variational Autoencoder (VAE) to Generate Images。
下载数据并解压,然后加载,没有变化
与官方网页相比,改进了代码,更加细致,同时增加了解析
自动编码器包含两个部分:编码器和解码器。 编码器接受图像输入并输出压缩表示(编码),压缩表示是大小为latent_dim的矢量,在此示例中为20。 解码器获取压缩表示,对其进行解码,然后重新创建原始图像。
变分自编码器的结构如下图所示,后面带有具体部分的意义
编码器使用卷积层,relu层和完全连接层将图像和标签信息压缩到潜在空间上。
latentDim = 20; % 压缩表示的大小
imageSize = [28 28 1]; % 输入图像的大小
numClasses = size(countcats(YTrain),1); % 定义输入图像的种类数(在此演示中:10)
encoderLG = layerGraph([
imageInputLayer(imageSize,'Name','input_encoder','Normalization','none') % 图像输入层
concatenationLayer(3,2,'Name','cat') % a layer to gather the two-input % 收集两个输入的层
convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1') % 卷积层1,32个3*3的卷积核,全0填充,步长为2
reluLayer('Name','relu1') % 激活函数层1
convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2') % 卷积层2,32个3*3的卷积核,全0填充,步长为2
reluLayer('Name','relu2') % 激活函数层2
fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder') % 编码器输出潜在空间上值的均值和方差
]);
类别(class)信息被转换为独热向量,其中与目标图像相对应的索引为1(其余为0),具体见下图。
标签输入被转换为嵌入向量,并被整形为28×28。 下面的示例为4 x 4。
要嵌入和调整标签输入的形状,请使用附加到此示例作为支持文件的自定义层embedAndReshapeLayer。
embeddingDimension = 50; % 标签调整后的大小
encoderLabel = [ % 对标签进行编码
imageInputLayer([1 1],'Name','labels','Normalization','none')
embedAndReshapeLayer(imageSize,embeddingDimension,numClasses,'emb')];
lgraphDiscriminator = addLayers(encoderLG,encoderLabel); % 标签编码层添加到编码器中
lgraphDiscriminator = connectLayers(lgraphDiscriminator,'emb','cat/in2'); % 标签编码的输出是cat层的两个输入之一
analyzeNetwork(lgraphDiscriminator) % 分析整个编码器网络
在此演示中,我使用ReLu层进行激活,但Leaky ReLu倾向于在GAN(生成对抗网络)中使用。
解码器将21 x 1的输入比例放大为28 x 28。 上采样如下图所示。
其中转置卷积层的介绍见transposedConv2dLayer
decoderLG = layerGraph([
imageInputLayer([1 1 latentDim],'Name','i','Normalization','none') % 潜在维度在最后一部分中定义
concatenationLayer(3,2,'Name','cat') % 收集两个输入的层
transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1') %转置卷积层进行上采样,64个7*7的卷积核,步长为7
reluLayer('Name','relu1') % 激活层
transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2')
reluLayer('Name','relu2')
transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3')
reluLayer('Name','relu3')
transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4')
]);
如编码器中所述,标签输入被嵌入到向量中并被整形以用作解码器的输入。
embeddingDimension = 20; % 标签调整后的大小
projectionSize = [1 1];
layers = [ % 对标签进行编码和整形
imageInputLayer([1 1],'Name','labels','Normalization','none')
embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'emb')];
lgraphGenerator = addLayers(decoderLG,layers);
lgraphGenerator = connectLayers(lgraphGenerator,'emb','cat/in2'); % 标签编码的输出是cat层的两个输入之一
analyzeNetwork(lgraphGenerator) % 分析整个解码器生成网络
要使用自定义训练循环训练两个网络并启用自动微分,将层图转换为dlnetwork对象。
encoderNet = dlnetwork(lgraphDiscriminator);
decoderNet = dlnetwork(lgraphGenerator);
梯度函数的介绍与原网页一致,但是将原来解码编码过程图替换为新的图片,可以结合二者进行理解。
这一部分没有变化,见原网页代码。
增加了不进行训练的选项,可以直接调用保存好的网络(有点迁移学习的味道),没有进行具体的翻译,暂且搁置。
cpu版的电脑不建议进行训练,耗时太长!!
doTraining=1; % 设置为1进行PC训练
if doTraining==1
for epoch = 1:numEpochs % the tranining data is learned in total of "numEpochs" times
tic;
for i = 1:numIterations
iteration = iteration + 1;
idx = (i-1)*miniBatchSize+1:i*miniBatchSize; % the mini-batch is not shuffled
XBatch = XTrain(:,:,:,idx); %Obtain the next mini-batch from the training set.
XBatch = dlarray(single(XBatch), 'SSCB'); % conver the mini-batch data to use VAE
YBatch = permute(YTrain(idx),[2 3 4 1]);%exchange the dimension in the TValidation
% For example, the 2nd dimension goes to 4th dimension
YBatch = dlarray(single(YBatch), 'SSCB'); % Convert the mini-batch to a dlarray object,
% making sure to specify the dimension labels 'SSCB' (spatial, spatial, channel, batch).
% if your GPU is available for the training,convert the dlarray to a gpuArray object.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
XBatch = gpuArray(XBatch);
end
% Evaluate the model gradients using the dlfeval and modelGradients functions.
[infGrad, genGrad] = dlfeval(...
@modelGradients, encoderNet, decoderNet, XBatch,YBatch);
% Update the network learnables and the average gradients for both networks, using the adamupdate function.
[decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ...
adamupdate(decoderNet.Learnables, ...
genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr);
[encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ...
adamupdate(encoderNet.Learnables, ...
infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr);
end
elapsedTime = toc;
% estimate the mean and variace with log scale on the latent space
% assuming it has a normal distribution
[z, zMean, zLogvar] = sampling(encoderNet, XTraindl,YTraindl);
% decode the compressed value to the image-shape value
xPred = sigmoid(forward(decoderNet,z,YTraindl));
% calculate the ELBO loss. the loss decreases when the
elbo = ELBOloss(XTraindl, xPred, zMean, zLogvar);
addpoints(lineLossTrain,iteration,double(gather(extractdata(elbo))))
title("Loss During Training: Epoch - " + epoch + "; Iteration - " + iteration)
drawnow
end
else
load encoderNet; load encoderNet
end
更加细致,三部分作了具体介绍
要可视化和解释结果,请使用帮助程序可视化功能。 这些帮助器功能在本示例的最后定义。
VisualizeReconstruction函数显示从每个类中随机选择的一位,并在通过自动编码器后对其进行重构。
visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)
VisualizeLatentSpace函数采用将测试图像通过编码器网络后生成的均值和方差编码(每个维度为20),并对包含每个图像编码的矩阵执行t-SNE。 然后,您可以可视化由均值定义的潜在空间以及以t-SNE为特征的二维方差。
有关t-SNE的详细信息,请参阅下面的补充文件或Federico Errica博士撰写的页面。Step-By-Step Derivation of SNE and t-SNE gradients
visualizeLatentSpace(XTest, YTest, encoderNet)
该功能是新增加的,首先检查训练数据中每个数字被压缩到的潜在空间。 然后使用每个数字的潜在空间来合成每个数字。
将合成数字保存到GIF文件中。
numRepeat=30;% 重复合成过程的次数。 每一次操作合成一百个数字。
TileForGif=generateDigits(decoderNet,encoderNet,XTrain,YTrain,numRepeat);% this fumction was made in the end of this script
fig=figure;set(gcf,'visible','on') % impose the figure outside the script
filename = 'conditional_VAE.gif'; % Specify the output file name
for i=1:numRepeat
imshow(TileForGif(:,:,:,i))
title(sprintf('pattern: %d',i))
pause(.1)
drawnow
frame = getframe(fig);
[A,map] = rgb2ind(frame.cdata,256);
if i == 1
imwrite(A,map,filename,'gif','LoopCount',Inf,'DelayTime',.1);
else
imwrite(A,map,filename,'gif','WriteMode','append','DelayTime',.1);
end
end
可变自动编码器只是用于执行生成任务的众多可用模型之一。 它们适用于图像较小且具有明确定义的特征的数据集(例如MNIST)。 对于具有较大图像的更复杂的数据集,生成对抗网络(GAN)往往会表现更好,并生成噪声较小的图像。 有关显示如何实施GAN生成64×64RGB图像的示例,请参阅训练生成对抗网络Train Generative Adversarial Network (GAN)。
相关函数请看文章MATLAB实现自编码器(五)——变分自编码器(VAE)实现图像生成的帮助函数
有些改进,更加具体