MATLAB实现自编码器(五)——变分自编码器(VAE)实现图像生成的帮助函数

本文是对Train Variational Autoencoder (VAE) to Generate Images网页的翻译,该网页实现了变分自编码的图像生成,以MNIST手写数字为训练数据,生成了相似的图像。本文主要翻译了网页中帮助函数外的部分。主要部分见MATLAB实现自编码器(四)——变分自编码器实现图像生成Train Variational Autoencoder (VAE) to Generate Images。

processImagesMNIST

首先是两个用于处理mnist数据集的函数,分别处理图片和标签,使其符合网络的输入要求。

function X = processImagesMNIST(filename)
% The MNIST processing functions extract the data from the downloaded IDX
% files into MATLAB arrays. The processImagesMNIST function performs these
% operations: Check if the file can be opened correctly. Obtain the magic
% number by reading the first four bytes. The magic number is 2051 for
% image data, and 2049 for label data. Read the next 3 sets of 4 bytes,
% which return the number of images, the number of rows, and the number of
% columns. Read the image data. Reshape the array and swaps the first two
% dimensions due to the fact that the data was being read in column major
% format. Ensure the pixel values are in the range  [0,1] by dividing them
% all by 255, and converts the 3-D array to a 4-D dlarray object. Close the
% file.

[fileID,errmsg] = fopen(filename,'r','b');
if fileID < 0
    error(errmsg);
end

magicNum = fread(fileID,1,'int32',0,'b');
if magicNum == 2051
    fprintf('\nRead MNIST image data...\n')
end

numImages = fread(fileID,1,'int32',0,'b');
fprintf('Number of images in the dataset: %6d ...\n',numImages);
numRows = fread(fileID,1,'int32',0,'b');
numCols = fread(fileID,1,'int32',0,'b');

X = fread(fileID,inf,'unsigned char');

X = reshape(X,numCols,numRows,numImages);
X = permute(X,[2 1 3]);
X = X./255;
X = reshape(X, [28,28,1,size(X,3)]);
X = dlarray(X, 'SSCB');

fclose(fileID);
end

processImagesMNIST

处理标签,使其符合网络的输入要求

function Y = processLabelsMNIST(filename)
% The processLabelsMNIST function operates similarly to the
% processImagesMNIST function. After opening the file and reading the magic
% number, it reads the labels and returns a categorical array containing
% their values.

[fileID,errmsg] = fopen(filename,'r','b');

if fileID < 0
    error(errmsg);
end

magicNum = fread(fileID,1,'int32',0,'b');
if magicNum == 2049
    fprintf('\nRead MNIST label data...\n')
end

numItems = fread(fileID,1,'int32',0,'b');
fprintf('Number of labels in the dataset: %6d ...\n',numItems);

Y = fread(fileID,inf,'unsigned char');

Y = categorical(Y);

fclose(fileID);
end

Model Gradients Function

The modelGradients function takes the encoder and decoder dlnetwork objects and a mini-batch of input data X, and returns the gradients of the loss with respect to the learnable parameters in the networks. The function performs three operations:

  • Obtain the encodings by calling the sampling function on the mini-batch of images that passes through the encoder network.
  • Obtain the loss by passing the encodings through the decoder network and calling the ELBOloss function.
  • Compute the gradients of the loss with respect to the learnable parameters of both networks by calling the dlgradient function.

modelGradients函数获取编码器和解码器的dlnetwork对象以及输入数据X的小批量,并返回网络中可训练参数的损失梯度。 该函数执行三个操作:

  • 通过在通过编码器网络的微型图像批次上调用采样函数来获取编码。
  • 通过使编码通过解码器网络并调用ELBOloss函数来获得损耗。
  • 通过调用dlgradient函数,针对两个网络的可学习参数计算损耗的梯度。
function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x)
[z, zMean, zLogvar] = sampling(encoderNet, x);
xPred = sigmoid(forward(decoderNet, z));
loss = ELBOloss(x, xPred, zMean, zLogvar);
[genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ...
    encoderNet.Learnables);
end

Sampling and Loss Functions

The sampling function obtains encodings from input images. Initially, it passes a mini-batch of images through the encoder network and splits the output of size (2*latentDim)miniBatchSize into a matrix of means and a matrix of variances, each of size latentDimbatchSize. Then, it uses these matrices to implement the reparameterization trick and to compute the encoding. Finally, it converts this encoding to a dlarray object in SSCB format.

Sampling 函数从输入图像获取编码。 最初,它通过编码器网络传递一个图像的小批量,并将大小(2 × latentDim) × miniBatchSize的输出分成均值矩阵和方差矩阵,每个大小均为latentDim × batchSize。 然后,它使用这些矩阵来实现重新参数化技巧并计算编码。 最后,它将这种编码转换为SSCB格式的dlarray对象。

function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
compressed = forward(encoderNet, x);
d = size(compressed,1)/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);

sz = size(zMean);
epsilon = randn(sz);
sigma = exp(.5 * zLogvar);
z = epsilon .* sigma + zMean;
z = reshape(z, [1,1,sz]);
zSampled = dlarray(z, 'SSCB');
end

ELBOloss

The ELBOloss function takes the encodings of the means and the variances returned by the sampling function, and uses them to compute the ELBO loss.

ELBOloss函数采用均值和采样函数返回的方差的编码,并使用它们来计算ELBO损耗。

function elbo = ELBOloss(x, xPred, zMean, zLogvar)
squares = 0.5*(xPred-x).^2;
reconstructionLoss  = sum(squares, [1,2,3]);

KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);

elbo = mean(reconstructionLoss + KL);
end

Visualization Functions

The VisualizeReconstruction function randomly chooses two images for each digit of the MNIST data set, passes them through the VAE, and plots the reconstruction side by side with the original input. Note that to plot the information contained inside a dlarray object, you need to extract it first using the extractdata and gather functions.

VisualizeReconstruction函数为MNIST数据集的每个数字随机选择两个图像,将它们通过VAE,然后与原始输入并排绘制。 请注意,要绘制dlarray对象中包含的信息,需要先使用extractdata and gather函数将其提取出来。

function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet)
f = figure;
figure(f)
title("Example ground truth image vs. reconstructed image")
for i = 1:2
    for c=0:9
        idx = iRandomIdxOfClass(YTest,c);
        X = XTest(:,:,:,idx);

        [z, ~, ~] = sampling(encoderNet, X);
        XPred = sigmoid(forward(decoderNet, z));
        
        X = gather(extractdata(X));
        XPred = gather(extractdata(XPred));

        comparison = [X, ones(size(X,1),1), XPred];
        subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),
    end
end
end

function idx = iRandomIdxOfClass(T,c)
idx = T == categorical(c);
idx = find(idx);
idx = idx(randi(numel(idx),1));
end

VisualizeLatentSpace

The VisualizeLatentSpace function visualizes the latent space defined by the mean and the variance matrices that form the output of the encoder network, and locates the clusters formed by the latent space representations of each digit.

VisualizeLatentSpace函数可视化由形成编码器网络输出的均值和方差矩阵定义的潜在空间,并找到由每个数字的潜在空间表示形式形成的聚类。

The function starts by extracting the mean and the variance matrices from the dlarray objects. Because transposing a matrix with channel/batch dimensions (C and B) is not possible, the function calls stripdims before transposing the matrices. Then, it carries out a principal component analysis (PCA) on both matrices. To visualize the latent space in two dimensions, the function keeps the first two principal components and plots them against each other. Finally, the function colors the digit classes so that you can observe clusters.

该函数首先从dlarray对象中提取均值和方差矩阵。 由于无法转置具有通道/批处理尺寸(C和B)的矩阵,因此该函数在转置矩阵之前调用stripdims。 然后,它对两个矩阵执行主成分分析(PCA)。 为了在两个维度上可视化潜在空间,该函数保留前两个主要成分并将其相互绘制。 最后,该函数为数字类着色,以便观察群集。

function visualizeLatentSpace(XTest, YTest, encoderNet)
[~, zMean, zLogvar] = sampling(encoderNet, XTest);

zMean = stripdims(zMean)';
zMean = gather(extractdata(zMean));

zLogvar = stripdims(zLogvar)';
zLogvar = gather(extractdata(zLogvar));

[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);

c = parula(10);
f1 = figure;
figure(f1)
title("Latent space")

ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);

ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
cb = colorbar;  cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);
axis equal
end

generate

The generate function tests the generative capabilities of the VAE. It initializes a dlarray object containing 25 randomly generated encodings, passes them through the decoder network, and plots the outputs.

生成函数测试VAE的生成能力。 它初始化包含25个随机生成的编码的dlarray对象,将它们传递通过解码器网络,并绘制输出。

function generate(decoderNet, latentDim)
randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
generatedImage = extractdata(generatedImage);

f3 = figure;
figure(f3)
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated samples of digits")
drawnow
end

你可能感兴趣的:(MATLAB深度学习,matlab,深度学习,变分自编码器)