【机器学习】对抗生成网络

【机器学习】对抗生成网络_第1张图片

一、随机数据生成

【机器学习】对抗生成网络_第2张图片

随机数据生成算法

【机器学习】对抗生成网络_第3张图片

【机器学习】对抗生成网络_第4张图片

随机数据生成的显示建模和隐式建模

【机器学习】对抗生成网络_第5张图片

二、生成对抗网络结构

【机器学习】对抗生成网络_第6张图片

生成对抗网络(GAN)中,生成模型(Generator)和判别模型(Discriminator)的任务和训练目标分别是:

  • 生成模型的任务是从随机噪声中生成尽可能真实的数据,例如图像、文本、音频等。生成模型的训练目标是最小化生成数据被判别模型识别为假的概率,也就是最大化生成数据的真实性。

  • 判别模型的任务是区分输入的数据是真实的还是生成模型生成的。判别模型的训练目标是最大化真实数据被识别为真的概率和生成数据被识别为假的概率,也就是最小化判别模型的误判率。

生成模型和判别模型的训练是一个对抗的过程,它们互相竞争,不断提高自己的能力,最终达到一个平衡点,使得生成模型生成的数据无法被判别模型区分。这样,生成模型就可以生成高质量的数据,判别模型就可以提高数据的鉴别能力。

生成对抗网络的原理

【机器学习】对抗生成网络_第7张图片

【机器学习】对抗生成网络_第8张图片

【机器学习】对抗生成网络_第9张图片

【机器学习】对抗生成网络_第10张图片

【机器学习】对抗生成网络_第11张图片

三、模型的训练

【机器学习】对抗生成网络_第12张图片

GAN的训练过程是怎样的?

【机器学习】对抗生成网络_第13张图片

# 初始化生成器G和判别器D
G = Generator()
D = Discriminator()


# 设置优化器和超参数
optimizer_G = Optimizer(G.parameters(), ...)
optimizer_D = Optimizer(D.parameters(), ...)
epochs = ...
batch_size = ...
latent_dim = ...


# 循环训练epochs次
for epoch in range(epochs):
  # 循环训练每个批次的数据
  for x in data_loader(batch_size):
    # 训练判别器D
    optimizer_D.zero_grad()
    z = random_noise(latent_dim)
    fake_x = G(z)
    real_pred = D(x)
    fake_pred = D(fake_x.detach())
    loss_D = binary_cross_entropy(real_pred, 1) + binary_cross_entropy(fake_pred, 0)
    loss_D.backward()
    optimizer_D.step()


    # 训练生成器G
    optimizer_G.zero_grad()
    fake_pred = D(fake_x)
    loss_G = binary_cross_entropy(fake_pred, 1)
    loss_G.backward()
    optimizer_G.step()


    # 打印训练信息
    print(f"Epoch {epoch}, Loss_D: {loss_D}, Loss_G: {loss_G}")

四、应用和改进

GAN的变体

这些变体的原理和GAN有什么不同?

如何选择适合自己任务的GAN变体?

【机器学习】对抗生成网络_第14张图片

4.1 改进方案

CGAN

【机器学习】对抗生成网络_第15张图片

CGAN和GAN的训练过程有什么不同?

【机器学习】对抗生成网络_第16张图片

DCGAN

【机器学习】对抗生成网络_第17张图片

拉普拉斯金字塔GAN

【机器学习】对抗生成网络_第18张图片

【机器学习】对抗生成网络_第19张图片

【机器学习】对抗生成网络_第20张图片

【机器学习】对抗生成网络_第21张图片

GRAN

循环神经网络 (Recurrent Neural Networks, RNNs) 是一种深度学习的方法,它可以处理序列数据,如文本,语音,音乐等。RNNs 的特点是它们有一个内部状态,可以记住之前的信息,从而捕捉序列数据的长期依赖和结构。RNNs 可以用于生成对抗网络 (Generative Adversarial Networks, GANs) 的框架中,作为生成器或判别器,来生成或评估序列数据。这种结合了 RNNs 和 GANs 的方法称为生成循环对抗网络 (Generative Recurrent Adversarial Networks, GRANs)。GRANs 可以利用 RNNs 的能力来生成逼真和多样的序列数据,如文本,语音,音乐等。

【机器学习】对抗生成网络_第22张图片

InfoGAN

【机器学习】对抗生成网络_第23张图片

【机器学习】对抗生成网络_第24张图片

4.2 典型应用

【机器学习】对抗生成网络_第25张图片

【机器学习】对抗生成网络_第26张图片

【机器学习】对抗生成网络_第27张图片

Real-ESRGAN 超分辨率图像示例:

【机器学习】对抗生成网络_第28张图片

从 https://github.com/ai-forever/Real-ESRGAN 下载源码,从

https://huggingface.co/ai-forever/Real-ESRGAN/tree/main 手动下载模型,放在weights文件夹中。 

主程序

# 导入os模块,用于操作系统相关的功能,如文件和目录的管理
import os 
# 导入torch模块,用于深度学习的计算和模型的构建
import torch
# 导入PIL模块,用于图像的处理和显示
from PIL import Image
# 导入numpy模块,用于科学计算和数组的操作
import numpy as np
# 导入RealESRGAN模块,这是一个基于生成对抗网络的超分辨率模型,可以将低分辨率的图像转换为高分辨率的图像
from RealESRGAN import RealESRGAN




# 定义一个主函数,返回值类型为整数
def main() -> int:
    # 判断当前设备是否支持CUDA,如果支持,就使用CUDA作为设备类型,否则使用CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 创建一个RealESRGAN的实例,指定设备类型和放大倍数为4
    model = RealESRGAN(device, scale=4)
    # 加载预训练的模型权重,从本地目录中读取文件,不需要下载
    model.load_weights('Real-ESRGAN/weights/RealESRGAN_x4.pth', download=False)
    # 遍历输入目录中的所有图像文件,使用enumerate函数给每个文件编号
    for i, image in enumerate(os.listdir("Real-ESRGAN/inputs")):
        # 打开图像文件,并转换为RGB模式
        image = Image.open(f"Real-ESRGAN/inputs/{image}").convert('RGB')
        # 使用模型对图像进行预测,得到超分辨率的图像
        sr_image = model.predict(image)
        # 将超分辨率的图像保存到结果目录中,文件名为编号.png
        sr_image.save(f'Real-ESRGAN/results/{i}.png')
    # 返回0表示程序正常结束
    return 0




# 如果当前文件是作为主程序运行,而不是被其他文件导入,就执行主函数
if __name__ == '__main__':
    main()

效果:

【机器学习】对抗生成网络_第29张图片

原始图 650x650

【机器学习】对抗生成网络_第30张图片

超分辨率图 2600x2600

参考网址:

https://en.wikipedia.org/wiki/Generative_adversarial_network

https://github.com/pytorch/examples/tree/main

https://arxiv.org/abs/1406.2661

https://zhuanlan.zhihu.com/p/53473337 [GAN学习系列2] GAN的起源

https://mp.weixin.qq.com/s?__biz=MzA3MzI4MjgzMw==&mid=2650730721&idx=2&sn=95b97b80188f507c409f4c72bd0a2767&chksm=871b349fb06cbd891771f72d77563f77986afc9b144f42c8232db44c7c56c1d2bc019458c4e4&scene=21#wechat_redirect

https://www.mindspore.cn/tutorials/application/zh-CN/r1.7/cv/dcgan.html 生成式对抗网络

https://pytorch.org/examples/

https://github.com/leftthomas/SRGAN

https://pytorch.org/hub/

https://github.com/xiong-jie-y/ml-examples/tree/master

https://huggingface.co/ai-forever/Real-ESRGAN/tree/main    模型下载Real-ESRGAN

你可能感兴趣的:(机器学习,人工智能)