一、随机数据生成
随机数据生成算法
随机数据生成的显示建模和隐式建模
二、生成对抗网络结构
生成对抗网络(GAN)中,生成模型(Generator)和判别模型(Discriminator)的任务和训练目标分别是:
生成模型的任务是从随机噪声中生成尽可能真实的数据,例如图像、文本、音频等。生成模型的训练目标是最小化生成数据被判别模型识别为假的概率,也就是最大化生成数据的真实性。
判别模型的任务是区分输入的数据是真实的还是生成模型生成的。判别模型的训练目标是最大化真实数据被识别为真的概率和生成数据被识别为假的概率,也就是最小化判别模型的误判率。
生成模型和判别模型的训练是一个对抗的过程,它们互相竞争,不断提高自己的能力,最终达到一个平衡点,使得生成模型生成的数据无法被判别模型区分。这样,生成模型就可以生成高质量的数据,判别模型就可以提高数据的鉴别能力。
生成对抗网络的原理
三、模型的训练
GAN的训练过程是怎样的?
# 初始化生成器G和判别器D
G = Generator()
D = Discriminator()
# 设置优化器和超参数
optimizer_G = Optimizer(G.parameters(), ...)
optimizer_D = Optimizer(D.parameters(), ...)
epochs = ...
batch_size = ...
latent_dim = ...
# 循环训练epochs次
for epoch in range(epochs):
# 循环训练每个批次的数据
for x in data_loader(batch_size):
# 训练判别器D
optimizer_D.zero_grad()
z = random_noise(latent_dim)
fake_x = G(z)
real_pred = D(x)
fake_pred = D(fake_x.detach())
loss_D = binary_cross_entropy(real_pred, 1) + binary_cross_entropy(fake_pred, 0)
loss_D.backward()
optimizer_D.step()
# 训练生成器G
optimizer_G.zero_grad()
fake_pred = D(fake_x)
loss_G = binary_cross_entropy(fake_pred, 1)
loss_G.backward()
optimizer_G.step()
# 打印训练信息
print(f"Epoch {epoch}, Loss_D: {loss_D}, Loss_G: {loss_G}")
四、应用和改进
GAN的变体
这些变体的原理和GAN有什么不同?
如何选择适合自己任务的GAN变体?
4.1 改进方案
CGAN
CGAN和GAN的训练过程有什么不同?
DCGAN
拉普拉斯金字塔GAN
GRAN
循环神经网络 (Recurrent Neural Networks, RNNs) 是一种深度学习的方法,它可以处理序列数据,如文本,语音,音乐等。RNNs 的特点是它们有一个内部状态,可以记住之前的信息,从而捕捉序列数据的长期依赖和结构。RNNs 可以用于生成对抗网络 (Generative Adversarial Networks, GANs) 的框架中,作为生成器或判别器,来生成或评估序列数据。这种结合了 RNNs 和 GANs 的方法称为生成循环对抗网络 (Generative Recurrent Adversarial Networks, GRANs)。GRANs 可以利用 RNNs 的能力来生成逼真和多样的序列数据,如文本,语音,音乐等。
InfoGAN
4.2 典型应用
Real-ESRGAN 超分辨率图像示例:
从 https://github.com/ai-forever/Real-ESRGAN 下载源码,从
https://huggingface.co/ai-forever/Real-ESRGAN/tree/main 手动下载模型,放在weights文件夹中。
主程序
# 导入os模块,用于操作系统相关的功能,如文件和目录的管理
import os
# 导入torch模块,用于深度学习的计算和模型的构建
import torch
# 导入PIL模块,用于图像的处理和显示
from PIL import Image
# 导入numpy模块,用于科学计算和数组的操作
import numpy as np
# 导入RealESRGAN模块,这是一个基于生成对抗网络的超分辨率模型,可以将低分辨率的图像转换为高分辨率的图像
from RealESRGAN import RealESRGAN
# 定义一个主函数,返回值类型为整数
def main() -> int:
# 判断当前设备是否支持CUDA,如果支持,就使用CUDA作为设备类型,否则使用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 创建一个RealESRGAN的实例,指定设备类型和放大倍数为4
model = RealESRGAN(device, scale=4)
# 加载预训练的模型权重,从本地目录中读取文件,不需要下载
model.load_weights('Real-ESRGAN/weights/RealESRGAN_x4.pth', download=False)
# 遍历输入目录中的所有图像文件,使用enumerate函数给每个文件编号
for i, image in enumerate(os.listdir("Real-ESRGAN/inputs")):
# 打开图像文件,并转换为RGB模式
image = Image.open(f"Real-ESRGAN/inputs/{image}").convert('RGB')
# 使用模型对图像进行预测,得到超分辨率的图像
sr_image = model.predict(image)
# 将超分辨率的图像保存到结果目录中,文件名为编号.png
sr_image.save(f'Real-ESRGAN/results/{i}.png')
# 返回0表示程序正常结束
return 0
# 如果当前文件是作为主程序运行,而不是被其他文件导入,就执行主函数
if __name__ == '__main__':
main()
效果:
原始图 650x650
超分辨率图 2600x2600
参考网址:
https://en.wikipedia.org/wiki/Generative_adversarial_network
https://github.com/pytorch/examples/tree/main
https://arxiv.org/abs/1406.2661
https://zhuanlan.zhihu.com/p/53473337 [GAN学习系列2] GAN的起源
https://mp.weixin.qq.com/s?__biz=MzA3MzI4MjgzMw==&mid=2650730721&idx=2&sn=95b97b80188f507c409f4c72bd0a2767&chksm=871b349fb06cbd891771f72d77563f77986afc9b144f42c8232db44c7c56c1d2bc019458c4e4&scene=21#wechat_redirect
https://www.mindspore.cn/tutorials/application/zh-CN/r1.7/cv/dcgan.html 生成式对抗网络
https://pytorch.org/examples/
https://github.com/leftthomas/SRGAN
https://pytorch.org/hub/
https://github.com/xiong-jie-y/ml-examples/tree/master
https://huggingface.co/ai-forever/Real-ESRGAN/tree/main 模型下载Real-ESRGAN