对抗网络GAN生成图片

可以说,深度学习模型中,GAN是上头条次数最多的模型。从不存在的人脸到用英伟达GAN生成老婆,怎样的妹子都可以。

什么是GAN

GAN全称生成式对抗网络(GAN, Generative Adversarial Networks),是一种近年来大热的深度学习模型。GAN是从一个简单分布中生成数据,而他的训练则是处于一种对抗博弈过程,其核心是由一个生成器(Generator)和一个判别器D(Discriminator)互相冲突的神经网络组成,这两个网络会以越来越复杂的方法来“蒙骗”对方。

我们以一个假币制造的例子形象地解释 GAN 的工作原理。

在这个过程中,我们想象有两种人:造假大师和警察。

我们看看他们的之间互相冲突的目标:

造假大师:他的主要目标就是想出各种制作假币的复杂方法,从而让警察无法区分真伪。
警察:他的主要目标就是想出各种辨别假币的复杂方法,这样就能够区分是否要假币。

随着过种的不断推进,造假大师制造假币的方式越来越高明,警察的鉴别技术也越来越高超。最终是通过博弈达到一个动态的平衡。

GAN能达到的一个最理想的状态是平衡,生成器G生成的模拟数据,判别器D输出的概率应该 0.5, 即生成的数据和真实数据一致。也就是说,它不确定来自生成器的新数据是真实还是虚假,二者的概率相等。

接下来,我们看一下Java代码是怎么实现这一过程:

//加载并输入类型进行生成图片
public static Image[] generate() throws IOException, ModelException, TranslateException {
    Criteria<int[], Image[]> criteria =
        Criteria.builder()
        .optApplication(Application.CV.IMAGE_GENERATION)
        .setTypes(int[].class, Image[].class) //设置输入和输出类型
        .optFilter("size", "256")
        .optArgument("truncation", 0.5f)
        .optProgress(new ProgressBar())
        .optEngine(PtEngine.ENGINE_NAME) //指定pytorch引擎
        .build();

    int[] input = {100, 207, 971, 970, 933};

    try (ZooModel<int[], Image[]> model = criteria.loadModel();
         Predictor<int[], Image[]> generator = model.newPredictor()) {
        return generator.predict(input);
    }
}

//将生成图片保存到指定目录下
private static void saveImages(Image[] generatedImages) throws IOException {
    Path outputPath = Paths.get("build/output/gan/");
    Files.createDirectories(outputPath);

    for (int i = 0; i < generatedImages.length; ++i) {
        Path imagePath = outputPath.resolve("image" + i + ".png");
        generatedImages[i].save(Files.newOutputStream(imagePath), "png");
    }
    logger.info("Generated images have been saved in: {}", outputPath);
}

运行上面的代码,就会在build/output/gan/目录下生成,

关注微信公众号,继续了解后续内容!

你可能感兴趣的:(Java开发者动手学习深度学习,java,深度学习,开发语言,机器学习,神经网络)