目录
1. 原论文
2. libtorch复现
2.1 生成器
2.2 判别器
2.3 完整训练代码
2.4 训练效果
论文:https://arxiv.org/pdf/1406.2661.pdf
pytorch源码:https://github.com/devnag/pytorch-generative-adversarial-networks/blob/master/gan_pytorch.py
详细介绍可参考 CSDN 博客:《pytorch实现GAN》(作者 Mr.Q)。
生成器是全连接网络,输入是正态分布的随机数,size是(64,100),经过5个全连接层后得到(64,784),再view成(64,1,28,28)。由于最后有个tanh双曲正切激活函数,所以输出值的范围在(-1,1)之间。
// Linear + (optional) BatchNorm1d + LeakyReLU building block, mirroring the
// `block()` helper of the reference PyTorch GAN implementation.
class LnBnReluImpl : public torch::nn::Module {
public:
    // in_c / out_c: input / output feature counts.
    // normalize: when true, a BatchNorm1d layer is inserted after the Linear.
    LnBnReluImpl(int64_t in_c, int64_t out_c, bool normalize);
    torch::Tensor forward(torch::Tensor x);
private:
    torch::nn::Linear ln = nullptr;
    torch::nn::BatchNorm1d bn = nullptr;       // only constructed when normalize == true
    torch::nn::LeakyReLU LReLU = nullptr;
    bool normalize = true;
};
TORCH_MODULE(LnBnRelu);

LnBnReluImpl::LnBnReluImpl(int64_t in_c, int64_t out_c, bool normalize)
{
    this->normalize = normalize;
    ln = torch::nn::Linear(in_c, out_c);
    if (normalize)
        // NOTE(review): eps(0.8) reproduces the reference code, where 0.8 was
        // most likely intended as the *momentum*, not epsilon — confirm before
        // reusing this block elsewhere.
        bn = torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(out_c).eps(0.8));
    LReLU = torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true));
    register_module("block ln", ln);
    if (normalize)
        register_module("block bn", bn);
    // Register the activation as well so the module tree is complete and
    // consistent with ln/bn (LeakyReLU holds no parameters, so training is
    // unaffected either way).
    register_module("block lrelu", LReLU);
}

// Applies Linear -> [BatchNorm1d] -> LeakyReLU to a (B, in_c) tensor.
torch::Tensor LnBnReluImpl::forward(torch::Tensor x)
{
    x = ln->forward(x);
    if (normalize)                 // bn exists only when normalize == true
        x = bn->forward(x);
    x = LReLU->forward(x);
    return x;
}
// Generator: a fully-connected network — four Linear(+BN)+LeakyReLU blocks,
// then a final Linear followed by tanh.
// Input: noise (B, NOISE_SIZE); output: images (B, 1, 28, 28) with values in (-1, 1).
class GeneratorImpl : public torch::nn::Module {
public:
    GeneratorImpl();
    torch::Tensor forward(torch::Tensor x);
private:
    LnBnRelu fc1 = nullptr;
    LnBnRelu fc2 = nullptr;
    LnBnRelu fc3 = nullptr;
    LnBnRelu fc4 = nullptr;
    torch::nn::Linear fc5{ nullptr };
};
TORCH_MODULE(Generator);

GeneratorImpl::GeneratorImpl() {
    fc1 = LnBnRelu(NOISE_SIZE, 128, false);    // no batch-norm on the first block
    fc2 = LnBnRelu(128, 256, true);
    fc3 = LnBnRelu(256, 512, true);
    fc4 = LnBnRelu(512, 1024, true);
    // Final projection to a flat image vector (1 * 28 * 28 = 784).
    fc5 = torch::nn::Linear(1024, int(IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2]));
    register_module("generator fc1", fc1);
    register_module("generator fc2", fc2);
    register_module("generator fc3", fc3);
    register_module("generator fc4", fc4);
    register_module("generator fc5", fc5);
}   // stray ';' after the constructor body removed

// x: (B, NOISE_SIZE) noise -> (B, 1, 28, 28) image batch in (-1, 1).
torch::Tensor GeneratorImpl::forward(torch::Tensor x)
{
    x = fc1->forward(x);
    x = fc2->forward(x);
    x = fc3->forward(x);
    x = fc4->forward(x);
    x = fc5->forward(x);
    x = torch::tanh(x);            // squash outputs into (-1, 1)
    x = x.view({ x.size(0), IMAGE_SHAPE[0], IMAGE_SHAPE[1], IMAGE_SHAPE[2] }); // (B,1,28,28)
    return x;
}
判别器也是全连接网络,输入的数据是(64,1,28,28)大小,先view成(64,784)向量,再经过3个全连接层,得到size为(64,1)的值,最后经过sigmoid输出分数值在(0,1)之间。值越接近1,说明判别器认为输入越真。
// Discriminator: a fully-connected network — two Linear+LeakyReLU layers,
// then a Linear to one logit, finished with a sigmoid classifier.
// Input: images (B, 1, 28, 28); output: realness score (B, 1) in (0, 1).
class DiscriminatorImpl : public torch::nn::Module {
public:
    DiscriminatorImpl();
    torch::Tensor forward(torch::Tensor x);
private:
    torch::nn::Linear fc1{ nullptr };
    torch::nn::Linear fc2{ nullptr };
    torch::nn::Linear fc3{ nullptr };
    torch::nn::LeakyReLU relu1{ nullptr };
    torch::nn::LeakyReLU relu2{ nullptr };
};
TORCH_MODULE(Discriminator);

DiscriminatorImpl::DiscriminatorImpl() {
    fc1 = torch::nn::Linear(IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2], 512);
    relu1 = torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true));
    fc2 = torch::nn::Linear(512, 256);
    relu2 = torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true));
    fc3 = torch::nn::Linear(256, 1);
    // Submodule names: fixed the "disciminator" misspelling. NOTE(review):
    // this changes the serialized parameter keys — retrain or re-save any
    // checkpoints written with the old names.
    register_module("discriminator fc1", fc1);
    register_module("discriminator fc2", fc2);
    register_module("discriminator fc3", fc3);
}

// x: (B, 1, 28, 28) images -> (B, 1) realness probability.
torch::Tensor DiscriminatorImpl::forward(torch::Tensor x)
{
    x = x.view({ x.size(0), -1 });         // flatten to (B, 784)
    x = fc1->forward(x);
    x = relu1->forward(x);
    x = fc2->forward(x);
    x = relu2->forward(x);
    torch::Tensor validity = fc3->forward(x);   // (B, 1) logit
    validity = torch::sigmoid(validity);        // probability in (0, 1)
    return validity;
}
这里是基于MNIST数据集,生成手写数字图片。后面将探索实现生成人脸数据。
#include
#include
#include
#include
#include
#include
const std::string DATA_FOLDER = "./data/MNIST/raw";
const int64_t BATCH_SIZE = 64;
const int64_t N_EPOCHS = 200;
const int64_t NOISE_SIZE = 100;
torch::Device device(torch::kCPU);
std::vector IMAGE_SHAPE{ 1, 28, 28 };
// Linear + (optional) BatchNorm1d + LeakyReLU building block, mirroring the
// `block()` helper of the reference PyTorch GAN implementation.
class LnBnReluImpl : public torch::nn::Module {
public:
    // in_c / out_c: input / output feature counts.
    // normalize: when true, a BatchNorm1d layer is inserted after the Linear.
    LnBnReluImpl(int64_t in_c, int64_t out_c, bool normalize);
    torch::Tensor forward(torch::Tensor x);
private:
    torch::nn::Linear ln = nullptr;
    torch::nn::BatchNorm1d bn = nullptr;       // only constructed when normalize == true
    torch::nn::LeakyReLU LReLU = nullptr;
    bool normalize = true;
};
TORCH_MODULE(LnBnRelu);

LnBnReluImpl::LnBnReluImpl(int64_t in_c, int64_t out_c, bool normalize)
{
    this->normalize = normalize;
    ln = torch::nn::Linear(in_c, out_c);
    if (normalize)
        // NOTE(review): eps(0.8) reproduces the reference code, where 0.8 was
        // most likely intended as the *momentum*, not epsilon — confirm before
        // reusing this block elsewhere.
        bn = torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(out_c).eps(0.8));
    LReLU = torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true));
    register_module("block ln", ln);
    if (normalize)
        register_module("block bn", bn);
    // Register the activation as well so the module tree is complete and
    // consistent with ln/bn (LeakyReLU holds no parameters, so training is
    // unaffected either way).
    register_module("block lrelu", LReLU);
}

// Applies Linear -> [BatchNorm1d] -> LeakyReLU to a (B, in_c) tensor.
torch::Tensor LnBnReluImpl::forward(torch::Tensor x)
{
    x = ln->forward(x);
    if (normalize)                 // bn exists only when normalize == true
        x = bn->forward(x);
    x = LReLU->forward(x);
    return x;
}
// Generator: a fully-connected network — four Linear(+BN)+LeakyReLU blocks,
// then a final Linear followed by tanh.
// Input: noise (B, NOISE_SIZE); output: images (B, 1, 28, 28) with values in (-1, 1).
class GeneratorImpl : public torch::nn::Module {
public:
    GeneratorImpl();
    torch::Tensor forward(torch::Tensor x);
private:
    LnBnRelu fc1 = nullptr;
    LnBnRelu fc2 = nullptr;
    LnBnRelu fc3 = nullptr;
    LnBnRelu fc4 = nullptr;
    torch::nn::Linear fc5{ nullptr };
};
TORCH_MODULE(Generator);

GeneratorImpl::GeneratorImpl() {
    fc1 = LnBnRelu(NOISE_SIZE, 128, false);    // no batch-norm on the first block
    fc2 = LnBnRelu(128, 256, true);
    fc3 = LnBnRelu(256, 512, true);
    fc4 = LnBnRelu(512, 1024, true);
    // Final projection to a flat image vector (1 * 28 * 28 = 784).
    fc5 = torch::nn::Linear(1024, int(IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2]));
    register_module("generator fc1", fc1);
    register_module("generator fc2", fc2);
    register_module("generator fc3", fc3);
    register_module("generator fc4", fc4);
    register_module("generator fc5", fc5);
}   // stray ';' after the constructor body removed

// x: (B, NOISE_SIZE) noise -> (B, 1, 28, 28) image batch in (-1, 1).
torch::Tensor GeneratorImpl::forward(torch::Tensor x)
{
    x = fc1->forward(x);
    x = fc2->forward(x);
    x = fc3->forward(x);
    x = fc4->forward(x);
    x = fc5->forward(x);
    x = torch::tanh(x);            // squash outputs into (-1, 1)
    x = x.view({ x.size(0), IMAGE_SHAPE[0], IMAGE_SHAPE[1], IMAGE_SHAPE[2] }); // (B,1,28,28)
    return x;
}
// Discriminator: a fully-connected network — two Linear+LeakyReLU layers,
// then a Linear to one logit, finished with a sigmoid classifier.
// Input: images (B, 1, 28, 28); output: realness score (B, 1) in (0, 1).
class DiscriminatorImpl : public torch::nn::Module {
public:
    DiscriminatorImpl();
    torch::Tensor forward(torch::Tensor x);
private:
    torch::nn::Linear fc1{ nullptr };
    torch::nn::Linear fc2{ nullptr };
    torch::nn::Linear fc3{ nullptr };
    torch::nn::LeakyReLU relu1{ nullptr };
    torch::nn::LeakyReLU relu2{ nullptr };
};
TORCH_MODULE(Discriminator);

DiscriminatorImpl::DiscriminatorImpl() {
    fc1 = torch::nn::Linear(IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2], 512);
    relu1 = torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true));
    fc2 = torch::nn::Linear(512, 256);
    relu2 = torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true));
    fc3 = torch::nn::Linear(256, 1);
    // Submodule names: fixed the "disciminator" misspelling. NOTE(review):
    // this changes the serialized parameter keys — retrain or re-save any
    // checkpoints written with the old names.
    register_module("discriminator fc1", fc1);
    register_module("discriminator fc2", fc2);
    register_module("discriminator fc3", fc3);
}

// x: (B, 1, 28, 28) images -> (B, 1) realness probability.
torch::Tensor DiscriminatorImpl::forward(torch::Tensor x)
{
    x = x.view({ x.size(0), -1 });         // flatten to (B, 784)
    x = fc1->forward(x);
    x = relu1->forward(x);
    x = fc2->forward(x);
    x = relu2->forward(x);
    torch::Tensor validity = fc3->forward(x);   // (B, 1) logit
    validity = torch::sigmoid(validity);        // probability in (0, 1)
    return validity;
}
// Render up to the first 10 generated samples side by side in an OpenCV window.
// `samples` is expected to be (B, 1, H, W) with values in (-1, 1) from tanh.
void Visualize(const torch::Tensor& samples)
{
    // Guard against batches smaller than 10 (original indexed out of range then).
    const int n = samples.size(0) < 10 ? static_cast<int>(samples.size(0)) : 10;
    const int h = static_cast<int>(samples.size(2));
    const int w = static_cast<int>(samples.size(3));
    cv::Mat scene(cv::Size(w * n, h), CV_32F);   // one row of n tiles
    for (int i = 0; i < n; i++)
    {
        // Map (-1,1) -> (0,1) so imshow displays the full range instead of
        // clipping negative values to black. .contiguous() guarantees
        // data_ptr() points at a dense row-major buffer before wrapping it
        // in a cv::Mat (the Mat borrows the tensor's memory).
        auto image_tensor = ((samples[i].detach().cpu() + 1.0) / 2.0).contiguous(); // (1,H,W)
        cv::Mat image_mat(h, w, CV_32F, image_tensor.data_ptr());
        image_mat.copyTo(scene(cv::Rect(w * i, 0, w, h)));   // tile i at x = w*i
    }
    cv::namedWindow("visualize", cv::WINDOW_NORMAL);
    cv::imshow("visualize", scene);
    cv::waitKey(1);   // 1 ms: pump the GUI event loop without blocking training
}
// Train a vanilla GAN on MNIST: alternate one generator and one discriminator
// update per batch with BCE loss, visualizing generated digits as training runs.
int main()
{
    if (torch::cuda::is_available())
        device = torch::Device(torch::kCUDA);

    // MNIST pipeline (files expected under DATA_FOLDER, see
    // http://yann.lecun.com/exdb/mnist). Normalize with mean 0.5 / stddev 0.5
    // so pixels land in (-1, 1), matching the generator's tanh output range.
    auto dataset = torch::data::datasets::MNIST(DATA_FOLDER)
        .map(torch::data::transforms::Normalize<>(0.5, 0.5))
        .map(torch::data::transforms::Stack<>());
    // Fixed: the original wrote `static_cast(BATCH_SIZE)` with no target type,
    // which does not compile; double avoids truncating integer division.
    const int64_t batches_per_epoch =
        std::ceil(dataset.size().value() / static_cast<double>(BATCH_SIZE));

    auto options = torch::data::DataLoaderOptions();
    options.drop_last(true);       // keep every batch exactly BATCH_SIZE
    options.batch_size(BATCH_SIZE);
    options.workers(2);
    auto data_loader = torch::data::make_data_loader(std::move(dataset), options);

    // Initialize generator and discriminator.
    Generator generator = Generator();
    Discriminator discriminator = Discriminator();
    generator->to(device);
    discriminator->to(device);

    // Loss function (binary cross-entropy on the discriminator's sigmoid output).
    torch::nn::BCELoss adversarial_loss = torch::nn::BCELoss();
    adversarial_loss->to(device);

    // Optimizers: Adam with lr = 0.002, betas = (0.5, 0.999).
    torch::optim::Adam optimizer_G(generator->parameters(),
        torch::optim::AdamOptions(0.002).betas(std::make_tuple(0.5, 0.999)));
    torch::optim::Adam optimizer_D(discriminator->parameters(),
        torch::optim::AdamOptions(0.002).betas(std::make_tuple(0.5, 0.999)));

    for (int64_t epoch = 1; epoch <= N_EPOCHS; epoch++)
    {
        int64_t batch_index = 0;
        for (torch::data::Example<>& batch : *data_loader)
        {
            ++batch_index;   // was declared but never advanced in the original
            // Adversarial ground truths: 1 = real, 0 = fake.
            torch::Tensor valid = torch::ones({ batch.data.size(0), 1 }, torch::kFloat).to(device);
            torch::Tensor fake = torch::zeros({ batch.data.size(0), 1 }, torch::kFloat).to(device);
            torch::Tensor real_images = batch.data.to(device);   // (B,1,28,28)

            // ----------------- Train generator -----------------
            optimizer_G.zero_grad();
            // Sample standard-normal noise as generator input.
            torch::Tensor z = torch::randn({ batch.data.size(0), NOISE_SIZE }, device); // (B, NOISE_SIZE)
            torch::Tensor gen_images = generator(z);                                    // (B,1,28,28)
            // The generator is rewarded when the discriminator labels its output real.
            torch::Tensor g_loss = adversarial_loss(discriminator(gen_images), valid);
            g_loss.backward();
            optimizer_G.step();

            // --------------- Train discriminator ---------------
            optimizer_D.zero_grad();
            // The discriminator should accept real images...
            torch::Tensor real_loss = adversarial_loss(discriminator(real_images), valid);
            // ...and reject generated ones. detach() is required: it keeps this
            // backward pass from reaching generator weights (and from touching a
            // graph already freed by g_loss.backward(), a runtime error otherwise).
            torch::Tensor fake_loss = adversarial_loss(discriminator(gen_images.detach()), fake);
            torch::Tensor d_loss = (real_loss + fake_loss) / 2;
            d_loss.backward();
            optimizer_D.step();

            std::cout << "Epoch: " << epoch << "/" << N_EPOCHS <<
                " Batch: " << batch_index << "/" << batches_per_epoch <<
                " - D loss: " << d_loss.item().toDouble() << ", G loss: " << g_loss.item().toDouble() << std::endl;
            Visualize(gen_images);
        }
    }
    return 0;
}
epoch 10: