MATLAB实现自编码器(六)——变分自编码器(VAE)官网代码的改进

本文内容参考了Conditional VAE (Variational Auto Encoder) 条件付きVAE

是对官方网页Train Variational Autoencoder (VAE) to Generate Images的改进,该网页的翻译可见MATLAB实现自编码器(四)——变分自编码器实现图像生成Train Variational Autoencoder (VAE) to Generate Images。

1.Load Data

下载数据并解压,然后加载,没有变化

2.Construct Network

与官方网页相比,改进了代码,更加细致,同时增加了解析

2.1.Overview

自动编码器包含两个部分:编码器和解码器。 编码器接受图像输入并输出压缩表示(编码),压缩表示是大小为latent_dim的矢量,在此示例中为20。 解码器获取压缩表示,对其进行解码,然后重新创建原始图像。
变分自编码器的结构如下图所示,后面带有具体部分的意义
MATLAB实现自编码器(六)——变分自编码器(VAE)官网代码的改进_第1张图片

2.1.1编码器

  • 面板[a]显示尺寸为28 x 28的图像输入。
  • 面板[b]是独热向量形式的标签输入。下一部分将对此进行说明。
  • 在面板[c]中,标签输入被嵌入到向量中,并被整形为28 x 28。
  • 输入的图像和重塑的标签覆盖在面板[d]中。
  • 在面板[e]中,分层图像被卷积并向下采样。
  • 在面板[f]中,神经网络返回潜在空间的均值和方差。请注意,本演示假定图像可以被压缩为矢量。我们还假设向量(潜空间)具有正态分布。如果采用其他类型的分布,则可以自定义分布的类型,这可能会产生更好的结果。
  • 从面板的正态分布中抽样随机值[g]。

2.1.2解码器

  • 如面板[h]所示,标签信息输入到解码器。
  • 如面板[i]中所示,使用编码器估算的均值和方差对随机值进行采样。
  • 在面板[j]中,将随机值和标签信息连接在一起。
  • 在面板[k]中,连接的向量按比例放大为28×28。
  • 在面板[l]中,将输入图像重构为最终输出图像。

2.2.Define an encoder ([a], [d], [e], and [f])

编码器使用卷积层,relu层和完全连接层将图像和标签信息压缩到潜在空间上。

latentDim = 20; % 压缩表示的大小
imageSize = [28 28 1]; % 输入图像的大小
numClasses = size(countcats(YTrain),1); % 定义输入图像的种类数(在此演示中:10)
encoderLG = layerGraph([
    imageInputLayer(imageSize,'Name','input_encoder','Normalization','none')  % 图像输入层
    concatenationLayer(3,2,'Name','cat') % a layer to gather the two-input    % 收集两个输入的层
    convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1') % 卷积层1323*3的卷积核,全0填充,步长为2
    reluLayer('Name','relu1')  % 激活函数层1
    convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2') % 卷积层2323*3的卷积核,全0填充,步长为2
    reluLayer('Name','relu2')  % 激活函数层2
    fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder') % 编码器输出潜在空间上值的均值和方差
    ]);

2.3.Define an encoder ([b] and [c])

类别(class)信息被转换为独热向量,其中与目标图像相对应的索引为1(其余为0),具体见下图。
标签输入被转换为嵌入向量,并被整形为28×28。 下面的示例为4 x 4。
要嵌入和调整标签输入的形状,请使用附加到此示例作为支持文件的自定义层embedAndReshapeLayer。
MATLAB实现自编码器(六)——变分自编码器(VAE)官网代码的改进_第2张图片

embeddingDimension = 50;  % 标签调整后的大小
encoderLabel = [          % 对标签进行编码
    imageInputLayer([1 1],'Name','labels','Normalization','none')
    embedAndReshapeLayer(imageSize,embeddingDimension,numClasses,'emb')];

lgraphDiscriminator = addLayers(encoderLG,encoderLabel); % 标签编码层添加到编码器中
lgraphDiscriminator = connectLayers(lgraphDiscriminator,'emb','cat/in2'); % 标签编码的输出是cat层的两个输入之一
analyzeNetwork(lgraphDiscriminator)  % 分析整个编码器网络

2.4.Construct decoder ([i], [j], [k] and [l])

在此演示中,我使用ReLu层进行激活,但Leaky ReLu倾向于在GAN(生成对抗网络)中使用。
解码器将21 x 1的输入比例放大为28 x 28。 上采样如下图所示。
MATLAB实现自编码器(六)——变分自编码器(VAE)官网代码的改进_第3张图片
其中转置卷积层的介绍见transposedConv2dLayer

decoderLG = layerGraph([
    imageInputLayer([1 1 latentDim],'Name','i','Normalization','none') % 潜在维度在最后一部分中定义
    concatenationLayer(3,2,'Name','cat')   % 收集两个输入的层
    transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1') %转置卷积层进行上采样,647*7的卷积核,步长为7
    reluLayer('Name','relu1')  % 激活层
    transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4')
    ]);

2.5.Construct decoder ([g] and [h])

如编码器中所述,标签输入被嵌入到向量中并被整形以用作解码器的输入。

embeddingDimension = 20; % 标签调整后的大小
projectionSize = [1 1];   
layers = [               % 对标签进行编码和整形   
    imageInputLayer([1 1],'Name','labels','Normalization','none')
    embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'emb')];
lgraphGenerator = addLayers(decoderLG,layers);
lgraphGenerator = connectLayers(lgraphGenerator,'emb','cat/in2'); % 标签编码的输出是cat层的两个输入之一
analyzeNetwork(lgraphGenerator)   % 分析整个解码器生成网络

2.6.图层转换

要使用自定义训练循环训练两个网络并启用自动微分,将层图转换为dlnetwork对象。

encoderNet = dlnetwork(lgraphDiscriminator);
decoderNet = dlnetwork(lgraphGenerator);

3.Define Model Gradients Function定义模型梯度函数

梯度函数的介绍与原网页一致,但是将原来解码编码过程图替换为新的图片,可以结合二者进行理解。
MATLAB实现自编码器(六)——变分自编码器(VAE)官网代码的改进_第4张图片

4.Specify Training Options指定训练选项

这一部分没有变化,见原网页代码。

5.Train Model模型训练

增加了不进行训练的选项,可以直接调用保存好的网络(有点迁移学习的味道),没有进行具体的翻译,暂且搁置。
cpu版的电脑不建议进行训练,耗时太长!!

doTraining=1; % 设置为1进行PC训练
if doTraining==1
for epoch = 1:numEpochs % the tranining data is learned in total of "numEpochs" times
    tic;
    for i = 1:numIterations 
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize; % the mini-batch is not shuffled
        XBatch = XTrain(:,:,:,idx); %Obtain the next mini-batch from the training set.
        XBatch = dlarray(single(XBatch), 'SSCB'); % conver the mini-batch data to use VAE
        YBatch = permute(YTrain(idx),[2 3 4 1]);%exchange the dimension in the TValidation
        % For example, the 2nd dimension goes to 4th dimension
        YBatch = dlarray(single(YBatch), 'SSCB'); % Convert the mini-batch to a dlarray object, 
        % making sure to specify the dimension labels 'SSCB' (spatial, spatial, channel, batch).
        
        % if your GPU is available for the training,convert the dlarray to a gpuArray object.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            XBatch = gpuArray(XBatch);           
        end 
        
        % Evaluate the model gradients using the dlfeval and modelGradients functions.    
        [infGrad, genGrad] = dlfeval(...
            @modelGradients, encoderNet, decoderNet, XBatch,YBatch);
        
        % Update the network learnables and the average gradients for both networks, using the adamupdate function.
        [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ...
            adamupdate(decoderNet.Learnables, ...
                genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr);
        [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ...
            adamupdate(encoderNet.Learnables, ...
                infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr);
        
    end
    elapsedTime = toc;
    
    % estimate the mean and variace with log scale on the latent space
    % assuming it has a normal distribution
    [z, zMean, zLogvar] = sampling(encoderNet, XTraindl,YTraindl);
    
    % decode the compressed value to the image-shape value 
    xPred = sigmoid(forward(decoderNet,z,YTraindl));
    
    % calculate the ELBO loss. the loss decreases when the  
    elbo = ELBOloss(XTraindl, xPred, zMean, zLogvar);
    addpoints(lineLossTrain,iteration,double(gather(extractdata(elbo))))
    title("Loss During Training: Epoch - " + epoch + "; Iteration - " + iteration)
    drawnow   
end
else
    load encoderNet; load encoderNet
end

6.Visualize Results可视化结果

更加细致,三部分作了具体介绍

6.1.Encode-Decode

要可视化和解释结果,请使用帮助程序可视化功能。 这些帮助器功能在本示例的最后定义。
VisualizeReconstruction函数显示从每个类中随机选择的一位,并在通过自动编码器后对其进行重构。

visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)

6.2.Display the distribution of the latent space after dimension reduction with t-SNE

VisualizeLatentSpace函数采用将测试图像通过编码器网络后生成的均值和方差编码(每个维度为20),并对包含每个图像编码的矩阵执行t-SNE。 然后,您可以可视化由均值定义的潜在空间以及以t-SNE为特征的二维方差。
有关t-SNE的详细信息,请参阅下面的补充文件或Federico Errica博士撰写的页面。Step-By-Step Derivation of SNE and t-SNE gradients

visualizeLatentSpace(XTest, YTest, encoderNet)

6.3.Synthesizing digits using the conditional VAE constructed in this script

该功能是新增加的,首先检查训练数据中每个数字被压缩到的潜在空间。 然后使用每个数字的潜在空间来合成每个数字。
将合成数字保存到GIF文件中。

numRepeat=30;% 重复合成过程的次数。 每一次操作合成一百个数字。
TileForGif=generateDigits(decoderNet,encoderNet,XTrain,YTrain,numRepeat);% this fumction was made in the end of this script
fig=figure;set(gcf,'visible','on') % impose the figure outside the script
filename = 'conditional_VAE.gif'; % Specify the output file name
for i=1:numRepeat
    imshow(TileForGif(:,:,:,i))
    title(sprintf('pattern: %d',i))
    pause(.1)
    drawnow
    frame = getframe(fig);
    [A,map] = rgb2ind(frame.cdata,256);
    if i == 1
        imwrite(A,map,filename,'gif','LoopCount',Inf,'DelayTime',.1);
    else
        imwrite(A,map,filename,'gif','WriteMode','append','DelayTime',.1);
    end
end

7.下一步

可变自动编码器只是用于执行生成任务的众多可用模型之一。 它们适用于图像较小且具有明确定义的特征的数据集(例如MNIST)。 对于具有较大图像的更复杂的数据集,生成对抗网络(GAN)往往会表现更好,并生成噪声较小的图像。 有关显示如何实施GAN生成64×64RGB图像的示例,请参阅训练生成对抗网络Train Generative Adversarial Network (GAN)。

8.Helper Functions帮助函数

相关函数请看文章MATLAB实现自编码器(五)——变分自编码器(VAE)实现图像生成的帮助函数
有些改进,更加具体

你可能感兴趣的:(MATLAB深度学习,深度学习,matlab,变分自编码器)