在看深度强化学习DQN代码时,遇到这段代码,搞了好久都没看明白。
完整代码参考这个博客。
for t in count():
#count()用法: itertools.count(start=0, step=1)
#start:序列的开始(默认为0)
#step:连续数字之间的差(默认为1)
reward = 0 #设置初始化奖励为0
m_reward = 0#求和奖励
# 每m帧完成一次action
action = select_action(state)#选择动作
#每四步更新一次奖励
for i in range(m):
#与环境交互,选择一个动作之后,获得奖励,并判断是否时最终状态
_, reward, done, _ = env.step(action.item())
if not done:
#如果不是终止状态,那么屏幕截屏到next_state_queue
next_state_queue.append(get_screen())
else:
#否则的话,就终止程序
break
m_reward += reward#然后累加奖励
if not done:
#如果不是终止状态,那么就进入下一个状态,把下一个状态连接到一起,使用tuple,不会被修改
next_state = torch.cat(tuple(next_state_queue), dim=1)
else:
#如果是终止状态,则下一个状态就没有了,获取最终奖励
next_state = None
m_reward = 150
m_reward = torch.tensor([m_reward], device=device)#把奖励转换成张量
memory.push(state, action, next_state, m_reward)#把计算出来的四个元素集存储到replay buffer中
state = next_state#把下一个状态转为当前状态
optimize_model()#开始优化模型
这个for循环的使用方式说实话我是真的不明白。
for t in count():
能找到关于count()的信息是上面的import部分
from itertools import count
然后我找了好多博客,最后这个博客给我讲明白了。
itertools.count(start,step)函数的意思是创建一个从start开始每次的步长是step的无穷序列
当count()括号里为空时,表示从0开始,每次步长为1.
我们再回到实际的代码环境中。
这段代码出现在迭代训练阶段
第一个for循环时迭代次数
在这个训练开始时,我们会使用random_start()函数计算出done, state_queue, next_state_queue,即状态的状态(终止状态和非终止状态),当前状态序列和下一个状态序列。
然后首先就要判断当前状态时是否是终止状态,不是终止状态就继续我们说的这个for循环。
那么第二个这个for循环为什么时无限制循环的呢?
for t in count():
这个循环开始,首先就是初始化奖励和初始化累计奖励
reward = 0 #设置初始化奖励为0
m_reward = 0#求和奖励
然后使用动作选择函数选择算法需要执行的动作
action = select_action(state)#选择动作
下面就开始第三个循环了
for i in range(m):
m=4,因为每个状态有四张图像
这个循环的第一行代码是
_, reward, done, _ = env.step(action.item())
作用就是将上面选择的动作输入到环境中,然后环境会给出奖励和判断该奖励是否是终止状态。
if not done:
#如果不是终止状态,那么屏幕截屏到next_state_queue
next_state_queue.append(get_screen())
else:
#否则的话,就终止程序
break
m_reward += reward#然后累加奖励
然后就开始判断该状态是否是终止状态,如果是终止状态就跳出该循环,不是的话就把当前屏幕截屏添加到next_state_queue序列中。
m=4,所以要执行四次。然后把这四次采集到的图像存储到序列中,需要提到的是,在这个for循环中,agent所使用的动作是一样的。
采集到四张图像之后,这个循环结束。
然后开始金鱼不判断状态是否结束了
if not done:
#如果不是终止状态,那么就进入下一个状态,把下一个状态连接到一起,使用tuple,不会被修改
next_state = torch.cat(tuple(next_state_queue), dim=1)
else:
#如果是终止状态,则下一个状态就没有了,获取最终奖励
next_state = None
m_reward = 150
如果没有结束,就把这个next_state_queue中的图像拼接cat起来,
如果是终止状态,那么提示没有下一个状态,给出奖励。
然后进行下一步
m_reward = torch.tensor([m_reward], device=device)#把奖励转换成张量
memory.push(state, action, next_state, m_reward)#把计算出来的四个元素集存储到replay buffer中
state = next_state#把下一个状态转为当前状态
optimize_model()#开始优化模型
这个动作执行结束后,把奖励转成张量,然后把transition四元数存储到replay buffer中。
然后更新当前状态。
开始优化模型。
在开始判断状态是否终止
并保存训练过程数据和更新网络模型参数
保存模型
if done:
episode_durations.append(t + 1)
plot_durations()
break
# 更新目标网络,复制DQN中的所有权重和偏置
if i_episode % TARGET_UPDATE == 0:
target_net.load_state_dict(policy_net.state_dict())
if i_episode % 1000 ==0:
torch.save(policy_net.state_dict(), 'weights/policy_net_weights_{0}.pth'.format(i_episode))
当我把所有的循环看完之后,终于明白。这个无限循环的for循环是为了收集replay buffer中的transition。我们设置replay buffer的容量为100000,但是由于agent’与环境交互的不可知性导致我们知道到底要多少步才能完成。所以使用了这个循环。