一、InfoGAN原理简介
普通的GAN存在无约束、不可控、噪声信号z难以解释等问题。InfoGAN对GAN做了一些改动:把生成器的输入分为噪声z和隐含编码c两部分,并最大化隐含编码c与生成数据之间的互信息,从而成功地让网络学到了可解释的特征。网络训练完成之后,我们可以通过设定输入生成器的隐含编码来控制生成数据的特征。InfoGAN的基本结构如下图所示:
其中,真实数据Real_data只用于与生成数据Fake_data混合在一起进行真假判断,并根据判断结果更新生成器和判别器,从而使生成数据逐渐接近真实数据。生成数据除了参与真假判断之外,还需要与隐变量C_vector计算互信息,并根据互信息更新生成器和判别器,从而使生成图像尽可能多地保留隐变量C_vector的信息。(在下面的代码中,判别器输出的第1行用于真假判断,其余各行用于预测隐变量,以此近似估计互信息。)
二、MATLAB代码实战
clear all;
close all;
clc;
%% Info Generative Adversarial Network
% Top-level script section: loads MNIST, collects the training settings in
% the struct `args`, and allocates the generator's learnable parameters in
% the struct `paramsGen`.

%% Load Data
% NOTE(review): assumes mnistAll.mat defines a struct `mnist` with fields
% train_images / train_labels / test_images / test_labels — confirm.
load('mnistAll.mat')
trainX = preprocess(mnist.train_images);
trainY = mnist.train_labels;
testX  = preprocess(mnist.test_images);
testY  = mnist.test_labels;

%% Settings
% Built in one call; field order is kept identical to the original
% one-field-per-assignment definitions.
args = struct( ...
    'maxepochs',   50,       ...
    'c_weight',    0.5,      ...
    'z_dim',       62,       ...
    'batch_size',  16,       ...
    'image_size',  [28,28,1],...
    'lrD',         0.0002,   ...
    'lrG',         0.001,    ...
    'beta1',       0.5,      ...
    'beta2',       0.999,    ...
    'cc_dim',      1,        ...
    'dc_dim',      10,       ...
    'sample_size', 100);

%% Weights, Biases, Offsets and Scales
% Generator. Judging by the field names (presumably — confirm against
% Generator): FC* = fully connected weights/biases, BNo*/BNs* = per-channel
% offset/scale pairs, TC* = transposed-convolution weights/biases.
% The initializeGaussian calls stay in their original order so that any
% random draws they make are unchanged.
zeroCol = @(n) dlarray(zeros(n,1,'single'));   % n-by-1 single zeros as dlarray
oneCol  = @(n) dlarray(ones(n,1,'single'));    % n-by-1 single ones as dlarray

paramsGen = struct();
% Input width = noise + continuous codes + discrete codes.
paramsGen.FCW1 = dlarray(initializeGaussian( ...
    [1024, args.z_dim + args.cc_dim + args.dc_dim]));
paramsGen.FCb1 = zeroCol(1024);
paramsGen.BNo1 = zeroCol(1024);
paramsGen.BNs1 = oneCol(1024);

paramsGen.FCW2 = dlarray(initializeGaussian([128*7*7, 1024]));
paramsGen.FCb2 = zeroCol(128*7*7);
paramsGen.BNo2 = zeroCol(128*7*7);
paramsGen.BNs2 = oneCol(128*7*7);

paramsGen.TCW1 = dlarray(initializeGaussian([4,4,64,128]));
paramsGen.TCb1 = zeroCol(64);
paramsGen.BNo3 = zeroCol(64);
paramsGen.BNs3 = oneCol(64);

paramsGen.TCW2 = dlarray(initializeGaussian([4,4,1,64]));
paramsGen.TCb2 = zeroCol(1);

% Drop the helper handles so the workspace matches the original script.
clear zeroCol oneCol
%% Progress Plot
function progressplot(args,paramsGen,stGen)
% PROGRESSPLOT  Draw a grid of generator samples into the current figure.
%   progressplot(args,paramsGen,stGen) builds args.sample_size generator
%   inputs laid out as [noise; continuous code; discrete code]:
%     * the noise part is all zeros;
%     * the samples are split into args.dc_dim equal groups, and every
%       sample in group i gets the one-hot discrete code for category i;
%     * within each group the first continuous code sweeps
%       linspace(-2,2,groupSize); any further continuous codes stay 0.
%   The generated images are tiled with imtile, rescaled to [0,1] and
%   shown with imagesc, replacing everything the current figure held.
%
%   args      - settings struct; reads z_dim, cc_dim, dc_dim, sample_size.
%   paramsGen - generator learnables, passed through to Generator.
%   stGen     - generator state, passed through to Generator.
%
%   With the default settings (sample_size = 100, dc_dim = 10) there are
%   10 groups of 10 samples, exactly as before. If sample_size is not a
%   multiple of dc_dim, the leftover trailing samples keep an all-zero
%   discrete code.

nGroups   = args.dc_dim;
groupSize = floor(args.sample_size/nGroups);

fixednoise = zeros(args.z_dim,args.sample_size);
cc = zeros(args.cc_dim,args.sample_size);
dc = zeros(args.dc_dim,args.sample_size);
for i = 1:nGroups
    cols = (i-1)*groupSize+1 : i*groupSize;
    if args.cc_dim > 0
        % Sweep only the first continuous code across the group.
        cc(1,cols) = linspace(-2,2,groupSize);
    end
    dc(i,cols) = 1;   % one-hot: category i for every sample in the group
end

fake_data = dlarray(cat(1,fixednoise,cc,dc),'CB');
% Use the GPU only when one is available; an unconditional gpuArray call
% errors on CPU-only machines.
if canUseGPU
    fake_data = gpuArray(fake_data);
end
% gather() copies a GPU result back to host memory for the image
% functions below; for data already on the CPU it returns it unchanged.
fake_images = gather(extractdata(Generator(fake_data,paramsGen,stGen)));

fig = gcf;
if ~isempty(fig.Children)
    delete(fig.Children)
end
I = imtile(fake_images);
I = rescale(I);
imagesc(I)
title("Generated Images")
drawnow;
end
%% Report Progress
function [d_loss,g_loss] = reportprogress(x,z,paramsDis,...
    paramsGen,args,stDis,stGen)
% REPORTPROGRESS  Evaluate the InfoGAN losses for monitoring.
%   [d_loss,g_loss] = reportprogress(x,z,paramsDis,paramsGen,args,stDis,stGen)
%   generates images from the generator input z, scores the real batch x
%   and the generated batch with the discriminator, and returns the same
%   discriminator/generator losses that modelGradients differentiates.
%   Updated network states are not requested from Generator/Discriminator
%   here, so nothing is returned besides the two losses.
fake_images = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,args,stDis);
d_output_fake = Discriminator(fake_images,paramsDis,args,stDis);
[d_loss,g_loss] = infoganLosses(z,d_output_real,d_output_fake,args);
end
%% Model Gradients
function [GradDis,GradGen,stDis,stGen] = modelGradients(x,z,paramsDis,...
    paramsGen,args,stDis,stGen)
% MODELGRADIENTS  Gradients of the InfoGAN losses w.r.t. both networks.
%   [GradDis,GradGen,stDis,stGen] = modelGradients(x,z,paramsDis,paramsGen,
%   args,stDis,stGen) runs the generator on z and the discriminator on the
%   real batch x and on the generated batch, then returns
%     GradGen - gradient of the generator loss w.r.t. paramsGen,
%     GradDis - gradient of the discriminator loss w.r.t. paramsDis,
%     stGen   - generator state from the forward pass,
%     stDis   - discriminator state from the pass on the generated batch.
%   dlgradient only works inside a function evaluated with dlfeval, so this
%   function is meant to be called through dlfeval (call site not visible).
[fake_images,stGen] = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,args,stDis);
% Only the state from the generated-batch pass is kept, as before.
[d_output_fake,stDis] = Discriminator(fake_images,paramsDis,args,stDis);
[d_loss,g_loss] = infoganLosses(z,d_output_real,d_output_fake,args);
% RetainData keeps the trace alive for the second dlgradient call.
GradGen = dlgradient(g_loss,paramsGen,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end
%% InfoGAN Losses
function [d_loss,g_loss] = infoganLosses(z,d_output_real,d_output_fake,args)
% INFOGANLOSSES  Loss computation shared by reportprogress/modelGradients.
%   Row layout relied on by the indexing below:
%     z          - rows 1:z_dim noise, next cc_dim rows continuous code,
%                  remaining rows discrete code.
%     d_output_* - row 1 is treated as the probability that the input is
%                  real (log and log(1-.) are taken of it), the next cc_dim
%                  rows as the predicted continuous code, and the remaining
%                  rows as scores for the discrete code.
%   NOTE(review): the discrete-code term assumes those scores are
%   log-probabilities (e.g. log-softmax) — confirm in Discriminator.
cc = z(args.z_dim+1:args.z_dim+args.cc_dim,:);
dc = z(args.z_dim+args.cc_dim+1:end,:);
output_cc = d_output_fake(1+(1:args.cc_dim),:);
output_dc = d_output_fake(2+args.cc_dim:end,:);

% Loss due to true or not (binary cross-entropy on row 1)
d_loss_a = -mean(log(d_output_real(1,:))+log(1-d_output_fake(1,:)));
g_loss_a = -mean(log(d_output_fake(1,:)));

% cc loss: Gaussian negative log-likelihood (fixed std 0.5, constants
% dropped) of the continuous code that was actually fed to the generator.
% Comparing the prediction with 0 instead of cc would carry no
% mutual-information signal about the code.
d_loss_cc = mean(sum(((output_cc-cc)/0.5).^2,1));

% Discrete-code classification loss. The dc.*dc term is a constant offset
% (z is an input, not a function of the parameters), so it changes the
% reported value but not the gradients.
d_loss_dc = -(mean(sum(dc.*output_dc,1))+mean(sum(dc.*dc,1)));

% Both networks share the mutual-information terms.
d_loss = d_loss_a+args.c_weight*d_loss_cc+d_loss_dc;
g_loss = g_loss_a+args.c_weight*d_loss_cc+d_loss_dc;
end
三、结果展示