Python解包运算操作*和打包运算zip

背景还是我在高DQN算法的时候遇到的,下面代码的第七行。完整代码参考这个博客。

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)#从replay buffer中选择transitions
    #转置batch(有关详细说明,请参阅https://stackoverflow.com/a/19343/3343043)
    #这会将transitions的batch数组转换为batch数组的过渡。
    batch = Transition(*zip(*transitions))#解包运算
    # 计算非最终状态的掩码并连接batch元素(最终状态将是模拟结束后的状态)
    #这个tuple(map(lambda s...)函数的作用是判断状态s是否是最终状态,如果不是最终状态就把batch.next_state赋值给s
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)  #将状态连接到一起,按列组合
    action_batch = torch.cat(batch.action)#将动作连接到一起,按列组合
    reward_batch = torch.cat(batch.reward)#将奖励连接到一起,按列组合

    state_action_values = policy_net(state_batch).gather(1, action_batch)#根据动作值选择相应的状态
    next_state_values = torch.zeros(BATCH_SIZE, device=device)#生成一个尺寸为BATCH_SIZE的张量
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch#计算期望的状态动作值

    # 设置我们的损失函数
    criterion = nn.MSELoss()#
    #损失函数的输入是状态值和期望的状态值
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()#损失函数反向传播
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

根据对DQN的理解,这个transiton是一个剧名数组,构造方式如下:

#创建一个Transition容器,具名数组
Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))

再训练DQN的时候,我们需要从replay buffer中提取Transition ,然后将transition的四个变量提取出来使用,这个时候就需要使用解包操作

1、使用说明

(1)解包的意义就是将传递给函数的一个列表,元组,字典,拆分成独立的多个元素然后赋值给函数中的形参变量。

(2)解包字典有两种解法,一种用 ∗ * 解的只有key,一种用 ∗ ∗ ** 解的有key、value。但是这个方法**只能在函数定义中使用。

2. 解包方法

解包的方法分类两种, ∗ * ∗ ∗ **

其中 ∗ ∗ ** 是针对字典的。
我们先举 ∗ * 的例子,也用数组表示吧,其实列表list也一样。
来一个数组

a = (1,2,3)

常规的解包操作是这样的

a = (1,2,3)
a1,a2,a3 = (1,2,3)
print(a1)
print(a2)
print(a3)
输出是:
1
2
3

如果使用*方法解包,那么就省事很多了。

但是,一般来讲,我们会把zip(*)在一起用。
下面举个例子

Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))
#生成一个列表,列表中是元组
a = [(1,2,3,4),(11,12,13,14),(21,22,23,24),(31,32,33,34)]
b = zip(*a)
b = list(b)
c = Transition(*zip(*a))
c = list(c)
print(c)
#[(1,11,21,31),(2,12,22,32),(3,13,23,33),(4,14,24,34)]

你可能感兴趣的:(Pytorch,python,开发语言,pytorch)