【深度强化学习】第一个神经网络Demo :GAN生成Atari游戏图片

文章目录

  • 前言
  • 第三章 实例: 关于Atari游戏的生成对抗网络实现

前言

重读《Deep Reinforcemnet Learning Hands-on》, 常读常新, 极其深入浅出的一本深度强化学习教程。 本文的唯一贡献是对其进行了翻译和提炼, 加一点自己的理解组织成一篇中文笔记。

原英文书下载地址: 传送门
原代码地址: 传送门

第三章 实例: 关于Atari游戏的生成对抗网络实现

在开始本篇的介绍前, 先说明原代码在windows中遇到的一个极易出错的小问题:

我在envs = [InputWrapper(gym.make(name)) for name in ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')]这一步报错, 一度一头雾水。 原来这是由于我们使用的atari库是gym自带的, 而这在windows系统上是有问题的。 简单的解决办法是参考了解决方法

经实践发现只需前两步,即

pip uninstall atari-py
pip install --no-index -f https://github.com/Kojoley/atari-py/releases atari_py

重新安装atari-py即可。

本篇是 介绍,使用了pytorch 实现了一个生成对抗网络——训练一个网络,可以产生Atari游戏风格的画面截图。 作者说: MNIST手写体识别,已经是老生常谈的神经网络的hello world Demo了, 他不想拘泥于此, 因此决定用GAN,更有趣地写一个Demo。

关于GAN的简单介绍,可以知乎搜索一下GAN,之前的笔记里也有提到,这里不再赘述:简单而言就是共有两个网络, 一个生成网络和一个对抗网络。 以本例为例, 生成网络负责生成逼近Atari游戏的图片, 对抗网络负责分辨图片是不是Atari游戏的图片。 一开始,生成网络生成的图片很假,被对抗网络轻松识别。 那么生成网络就被训练,生成更逼真一些的图片, 欺骗了对抗网络。 而对抗网络马上也加以训练, 再次识别出了生成网络伪造的图片。 循环往复地迭代训练后, 最后生成网络的生成图片已经基本可以以假乱真了。

这个Demo的完整代码, 在上面给出的github库 第三章的 最后一个.py文件, 可以直接运行。 下面是对这个Demo代码的每一块进行深入解析:

class InputWrapper(gym.ObservationWrapper):
    """
    Preprocessing of input numpy array:
    1. resize image into predefined size
    2. move color channel axis to a first place
    """
    def __init__(self, *args):
        super(InputWrapper, self).__init__(*args)
        assert isinstance(self.observation_space, gym.spaces.Box)
        old_space = self.observation_space
        self.observation_space = gym.spaces.Box(self.observation(old_space.low), self.observation(old_space.high), dtype=np.float32)

    def observation(self, observation):
        # resize image
        new_obs = cv2.resize(observation, (IMAGE_SIZE, IMAGE_SIZE))
        # transform (210, 160, 3) -> (3, 210, 160)
        new_obs = np.moveaxis(new_obs, 2, 0)
        return new_obs.astype(np.float32)


首先,我们利用之前介绍过的Gym的装饰器类, 创建了一个Observation装饰器, 来对观测进行一些简单的处理,更利于网络的训练。 如注释部分所写: 主要是将图片尺寸统一修改, 然后将颜色坐标轴提到第一维。

首先用super,调用基类的初始化方法。
assert isinstance(self.observation_space, gym.spaces.Box)这是python的assert 用法,
assert expression等价于

if not expression:
    raise AssertionError(arguments)

因此,这里通过 isinstance方法判断观测是否是gym的Box类,属于程序的自己检查。
接下来

 old_space = self.observation_space
        self.observation_space = gym.spaces.Box(self.observation(old_space.low), self.observation(old_space.high), dtype=np.float32)

两句则是先获取环境本来的观测空间, 再创建一个相同维度的空间(gym.Box类)。
然后 Observation的装饰类要重写observation()方法, 来返还我们自定义的观测。 这里也就是我们提到的两处修改: 用opencv的resize方法修改图片尺寸,用numpy的moveasix方法, 把第三维移动到第一维。

接下来,是对两个网络的分别搭建: 首先是 对抗网络,其任务是鉴别输入数据是否为真/伪造的。 显然,这是一个简单的二分类网络, 实现如下:

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        # this pipe converges image into the single number
        self.conv_pipe = nn.Sequential(
            nn.Conv2d(in_channels=input_shape[0], out_channels=DISCR_FILTERS,
                      kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS, out_channels=DISCR_FILTERS*2,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS*2),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 2, out_channels=DISCR_FILTERS * 4,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 4, out_channels=DISCR_FILTERS * 8,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 8),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=1,
                      kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        conv_out = self.conv_pipe(x)
        return conv_out.view(-1, 1).squeeze(dim=1)

首先, 使用Sequential类, 组建了一个卷积神经网络,接受一个三维张量(维度由input_shape这个参数输入),最后一层维度为1, 且用Sigmoid函数激活, 来得到一个0~1的输出, 这也是分类网络的常见处理。 接下来, 使用forward()函数, 用定义的网络对输入数据进行处理,返回输出。

net_discr = Discriminator(input_shape=input_shape).to(device),创建对抗网络实例(鉴别器), 用```print(new_discr)可以打印网络:

print(net_discr)
Discriminator(
  (conv_pipe): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
    (12): Sigmoid()
  )
)

然后是生成网络: 生成网络以一个随机张量作为输入, 用 反卷积 (转置卷积)层,将随机张量转变成代表一幅图的三维张量。 我们希望其尽可能逼近 Atari原生的图片, 即,生成的图片尽可能让 对抗网络识别为真。

类似的,其网络构造代码如下:

class Generator(nn.Module):
    def __init__(self, output_shape):
        super(Generator, self).__init__()
        # pipe deconvolves input vector into (3, 64, 64) image
        self.pipe = nn.Sequential(
            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 8,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(GENER_FILTERS * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 8, out_channels=GENER_FILTERS * 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.pipe(x)

核心就是使用了nn.ConvTranspose2d层, 来实现转置卷积。 关于转置卷积的详细叙述可以自行百度。

已经有了网络, 现在我们还需要 数据集—— 由于我们的目标是产生和Atari游戏画面类似的图片, 因此, 我们需要 Atari图片作为数据集, 来训练我们的对抗网络。 (对于对抗网络,这些图片的标签为1, 即为真。)

实现代码如下:

def iterate_batches(envs, batch_size=BATCH_SIZE):
    batch = [e.reset() for e in envs]
    env_gen = iter(lambda: random.choice(envs), None)

    while True:
        e = next(env_gen)
        obs, reward, is_done, _ = e.step(e.action_space.sample())
        if np.mean(obs) > 0.01:
            batch.append(obs)
        if len(batch) == batch_size:
            # Normalising input between -1 to 1
            batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
            yield torch.tensor(batch_np)
            batch.clear()
        if is_done:
            e.reset()

详细解释下代码:
batch = [e.reset() for e in envs], batch这个列表,是用来保存我们得到的obs数据(我们想要的图片数据其实就是 obs),而一开始, 作者就用e.reset()方法,相当于batch先加入了各个环境的初始obs图片数据——rese方法的返回值就是初始obs。

env_gen = iter(lambda: random.choice(envs), None),这个用法非常神奇,似乎不是主流的用法, iter(func, None),通过这样的代码,可以使得func函数变成一个迭代器, 如同后面使用的时候所展示的那样:e = next(env_gen)等价于 e = random.choice(envs)

接下来, 作者使用经典的

while True:
	yield ....

来实现每次调用本函数iterate_batches时,都产生一组数据返回值。 yield关键词就是返还, 并且函数停留在这一点, 下次再调用该函数时,从yield处继续。百度上也有许多对yield的详细介绍。
obs, reward, is_done, _ = e.step(e.action_space.sample())
这一步则是使用env类的step方法, 得到obs数据。 我们采用了随机的动作。
为了避免一个小Bug, 作者对obs做了一个过滤——只有平均值大于0.01的obs张量才会被记录到batch中,这是因为作者发现atari游戏偶尔有闪烁,这时得到的obs可能全0, 没有训练意义。

if len(batch) == batch_size:
           # Normalising input between -1 to 1
           batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
           yield torch.tensor(batch_np)
           batch.clear()

当batch中存有的obs数量已经达到需求时, 将数据转化为numpy形式(本来是列表), 再转变为torch的张量类型,返回。 由yield关键词, 下次再调用本函数时,会从上一次的返回处开始, 即会先运行 batch.clear(), 归零列表, 由于外面的是while True循环,因此再次调用时,同样会和之前一样, 最终返回一个样本数量足够的张量结果。

最后,检查下is_done变量, 如果回合已经结束,则调用reset()方法重启。

类和函数的定义做完了, 接下来就是主函数部分:

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action='store_true', help="Enable cuda computation")
    args = parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")
    envs = [InputWrapper(gym.make(name)) for name in ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')]
    input_shape = envs[0].observation_space.shape

这一段代码没有什么。 作者使用了argpars这个库, 用来控制使用cpu 或者 gpu进行训练。 我是直接用cpu的, 忽略即可。然后根据定义的环境类(装饰过)创建了实例列表, 再将obs的维度定义为输入的维度(input_shape)。

net_discr = Discriminator(input_shape=input_shape).to(device)
net_gener = Generator(output_shape=input_shape).to(device)

objective = nn.BCELoss()
gen_optimizer = optim.Adam(params=net_gener.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
dis_optimizer = optim.Adam(params=net_discr.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
writer = SummaryWriter()

这一段则创建了两个网络的实例, 以及损失函数及优化器。 损失函数这里使用的是nn.BCELoss(), 即二元交叉熵 (GAN的损失函数就是如此, 即判断图片的真假,相当于二分类问题)。两个网络使用的都是Adam优化器。 最后, 再创建了一个Writer用于训练时tensorboard进行监控。

gen_losses = []
dis_losses = []
iter_no = 0

true_labels_v = torch.ones(BATCH_SIZE, dtype=torch.float32, device=device)
fake_labels_v = torch.zeros(BATCH_SIZE, dtype=torch.float32, device=device)

创建了两个空列表,用于记录训练中损失值的变化——保存到writer中。生成了两个标签张量: 维度就是Batch_size, 即和Batch_size个样本一一对应。 如果是真的图片,其标签值为1, 因此使用torch.ones(), 如果是假的,标签值0, 也就对应torch.zeros()

for batch_v in iterate_batches(envs):
	    # generate extra fake samples, input is 4D: batch, filters, x, y
	    gen_input_v = torch.FloatTensor(BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1).normal_(0, 1).to(device)
	    batch_v = batch_v.to(device)
	    gen_output_v = net_gener(gen_input_v)

这里,
for batch_v in iterate_batches(envs):就是将函数iterate_batches看做一个迭代器, 每次循环的时候,到yield关键字处,返回数据。 下一次循环中,再从yield关键字开始, 无限循环。 总之, batch_v 就是这一次循环中,我们用于训练网络的样本数据——真图片。 接下来,我们还要生成假图片。 首先, 随机产生一个向量gen_input_v = torch.FloatTensor(BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1).normal_(0, 1).to(device), 再通过生成网络, gen_output_v = net_gener(gen_input_v),按我们的设想,这输出的gen_output_v生成网络产生的图片, 即假图片。 它们和刚刚获取的真图片一起,组成训练数据。

dis_optimizer.zero_grad()
dis_output_true_v = net_discr(batch_v)
dis_output_fake_v = net_discr(gen_output_v.detach())
dis_loss = objective(dis_output_true_v, true_labels_v) + objective(dis_output_fake_v, fake_labels_v)
dis_loss.backward()
dis_optimizer.step()
dis_losses.append(dis_loss.item())

这里是对对抗网络的损失函数的定义, 首先用zero_grad()方法清零梯度。 然后
dis_output_true_v = net_discr(batch_v)dis_output_fake_v = net_discr(gen_output_v.detach())分别将 真图片 和假图片, 输入到对抗网络中,进行判别。显然, 损失函数值也就是两者之和了。 objective(dis_output_true_v, true_labels_v)就是对输出和标签求交叉熵。 接下来, 使用backward()方法, 进行计算图中(即神经网络中)所有优化变量的梯度求取, 然后使用优化器的step()方法,对变量进行优化。 最后,将损失值存储。 值得一提的是这里用到了.detach()方法, 该方法返回一个新的变量,但仍指向原来的空间, 且 requires_grad属性默认为False。

gen_optimizer.zero_grad()
dis_output_v = net_discr(gen_output_v)
gen_loss_v = objective(dis_output_v, true_labels_v)
gen_loss_v.backward()
gen_optimizer.step()
gen_losses.append(gen_loss_v.item())

这一段代码,则是对生成网络进行训练。gen_output_v是训练前的生成网络的输出, 这里先让他通过对抗网络, 得到对抗网络的结果。 由于我们希望,生成的图片会被判别为真, 因此, 损失值就定义为 对抗网络的输出结果 与 期望的输出值1 之间 的 交叉熵。同样的, backward() 和 step()二连, 完成本次的迭代训练。

iter_no += 1
if iter_no % REPORT_EVERY_ITER == 0:
      log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e", iter_no, np.mean(gen_losses), np.mean(dis_losses))
      writer.add_scalar("gen_loss", np.mean(gen_losses), iter_no)
      writer.add_scalar("dis_loss", np.mean(dis_losses), iter_no)
      gen_losses = []
      dis_losses = []
if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
      writer.add_image("fake", vutils.make_grid(gen_output_v.data[:64], normalize=True), iter_no)
      writer.add_image("real", vutils.make_grid(batch_v.data[:64], normalize=True), iter_no)

最后, 当训练次数达到固定值时, 打印结果, 并存储到writer中,便于tensorboard的展示。

运行代码,效果:(CPU:i9-9900k,不需要GPU也能跑)
【深度强化学习】第一个神经网络Demo :GAN生成Atari游戏图片_第1张图片

用tensorflow 可视化一下效果图:

这是训练200次时的效果, 可以看到, 生成网络得到的基本还是白噪声水平。
【深度强化学习】第一个神经网络Demo :GAN生成Atari游戏图片_第2张图片
2000次左右开始就颇具雏形了。
【深度强化学习】第一个神经网络Demo :GAN生成Atari游戏图片_第3张图片
这个大家自己可以测试一下~
tensorboard的用法在上一节中讲述了。

你可能感兴趣的:(深度强化学习,深度学习)