Reinforcement Learning: DDQN

DDQN is essentially the same as DQN; for the details of the difference, see page 132 of the book. In short: DQN computes its learning target as r + γ·max_a Q_target(s', a), so the target network both selects and evaluates the next action, while DDQN selects the next action with the main network and evaluates it with the target network, r + γ·Q_target(s', argmax_a Q_main(s', a)), which reduces the overestimation of Q-values.
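
As a minimal sketch (illustrative function names, not the book's code), the two targets look like this in PyTorch, assuming both networks map a (batch, n_states) tensor to (batch, n_actions) Q-values:

import torch

def dqn_target(target_net, rewards, next_states, gamma):
    # DQN: the target network both selects and evaluates the next action
    next_q = target_net(next_states).detach().max(1)[0]
    return rewards + gamma * next_q

def ddqn_target(main_net, target_net, rewards, next_states, gamma):
    # DDQN: the main network selects the action, the target network evaluates it
    best_actions = main_net(next_states).detach().max(1)[1].unsqueeze(1)
    next_q = target_net(next_states).detach().gather(1, best_actions).squeeze(1)
    return rewards + gamma * next_q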
 
Code implementation

import random
from collections import namedtuple

import gym
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

import warnings
warnings.filterwarnings('ignore')


Transition=namedtuple('Transition',('state','action','next_state','reward'))
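
#each stored experience is one Transition named tuple; fields are accessed by
#name, e.g. (placeholder tensors):
#   t=Transition(torch.zeros(1,4),torch.LongTensor([[0]]),torch.zeros(1,4),torch.FloatTensor([0.0]))
#   t.reward   -> tensor([0.])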



#The heart of learning is the replay method of the Brain class below; it runs
#in four steps:
#  1. check that the replay memory already holds at least one mini-batch
#  2. create the mini-batch
#  3. compute the expected Q(s_t,a_t) values to use as the supervision signal
#  4. update the main network's parameters

#a simple 3-layer fully connected network: state in, one Q-value per action out
class Net(nn.Module):

    def __init__(self,n_in,n_mid,n_out):
        super(Net,self).__init__()
        self.fc1=nn.Linear(n_in,n_mid)
        self.fc2=nn.Linear(n_mid,n_mid)
        self.fc3=nn.Linear(n_mid,n_out)

    def forward(self,x):
        h1=F.relu(self.fc1(x))
        h2=F.relu(self.fc2(h1))
        output=self.fc3(h2)

        return output
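
#quick shape check (illustrative values matching CartPole: 4 state variables,
#2 actions):
#   net=Net(4,32,2)
#   net(torch.randn(1,4)).shape   -> torch.Size([1, 2]), one Q-value per action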

#memory class that stores past experiences (the replay buffer)
class ReplayMemory:

    def __init__(self,CAPACITY):
        self.capacity=CAPACITY    #maximum length of the memory list below
        self.memory=[] #stores past experiences
        self.index=0 #index at which the next experience will be saved



    def push(self,state,action,state_next,reward):
        "Save a transition=(state,action,state_next,reward) in memory"

        if len(self.memory)<self.capacity:
            self.memory.append(None)   #grow the list while it is not yet full

        #store as the namedtuple Transition(state,action,state_next,reward)
        self.memory[self.index]=Transition(state,action,state_next,reward)

        #advance the save index by one, wrapping back to 0 at capacity (ring buffer)
        self.index=(self.index+1)%self.capacity

    def sample(self,batch_size):
        "Randomly sample batch_size transitions and return them"
        return random.sample(self.memory,batch_size)

    def __len__(self):
        "Return the current length of memory"
        return len(self.memory)
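
#a small usage sketch (illustrative values):
#   memory=ReplayMemory(100)
#   memory.push(state,action,state_next,reward)   #store one transition
#   batch=memory.sample(32)   #list of 32 randomly chosen Transition tuples
#   len(memory)               #number of transitions currently stored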


BATCH_SIZE=32    #mini-batch size for learning
CAPACITY=10000   #capacity of the replay memory


class Brain:
    def __init__(self,num_states,num_actions):
        self.num_actions=num_actions

        self.memory=ReplayMemory(CAPACITY)

        #build the two networks: main (learned) and target (periodically copied)
        n_in,n_mid,n_out=num_states,32,num_actions
        self.main_q_network=Net(n_in,n_mid,n_out)
        self.target_q_network=Net(n_in,n_mid,n_out)
        print(self.main_q_network)


        self.optimizer=optim.Adam(self.main_q_network.parameters(),lr=0.0001)

    def replay(self):
        "Learn the network parameters by experience replay"

        #1. check the memory size
        if len(self.memory)<BATCH_SIZE:
            return 

        #2. create the mini-batch
        self.batch,self.state_batch,self.action_batch,self.reward_batch,self.non_final_next_states=self.make_minibatch()

        #3. get the Q(s_t,a_t) values to use as supervision targets
        self.expected_state_action_values=self.get_expected_state_action_values()

        #4. update the network parameters
        self.update_main_q_network()



    def decide_action(self,state,episode):
        "ε-greedy: explore at random early on, act greedily more and more over time"

        epsilon=0.5*(1/(episode+1))

        if epsilon<=np.random.uniform(0,1):
            #exploit: take the action with the highest Q-value
            self.main_q_network.eval()
            with torch.no_grad():
                action=self.main_q_network(state).max(1)[1].view(1,1)

        else:
            #explore: take a random action
            action=torch.LongTensor([[random.randrange(self.num_actions)]])

        return action
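
    #note: epsilon anneals as 0.5/(episode+1): 0.5 at episode 0, 0.25 at 1,
    #0.1 at 4, 0.05 at 9, 0.01 at 49, so exploration fades quickly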

    def make_minibatch(self):

        #sample BATCH_SIZE transitions; Transition(*zip(*transitions)) transposes
        #the list of Transition tuples into one Transition of tuples (see the sketch below)
        transitions=self.memory.sample(BATCH_SIZE)
        batch=Transition(*zip(*transitions))

        #concatenate each field into a single batched tensor; terminal transitions
        #(next_state is None) are dropped from non_final_next_states
        state_batch=torch.cat(batch.state)
        action_batch=torch.cat(batch.action)
        reward_batch=torch.cat(batch.reward)
        non_final_next_states=torch.cat([s for s in batch.next_state if s is not None])

        return batch,state_batch,action_batch,reward_batch,non_final_next_states 
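
    #the Transition(*zip(*transitions)) idiom transposes a list of per-step
    #tuples into one tuple per field; a tiny illustration:
    #   ts=[Transition(1,2,3,4),Transition(5,6,7,8)]
    #   Transition(*zip(*ts))
    #   -> Transition(state=(1,5),action=(2,6),next_state=(3,7),reward=(4,8))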

    def get_expected_state_action_values(self):

        #switch both networks to inference mode
        self.main_q_network.eval()
        self.target_q_network.eval()

        #use gather to pick out the Q-value of the action that was actually taken
        self.state_action_values=self.main_q_network(self.state_batch).gather(1,self.action_batch)

        #boolean mask: True where the transition has a next state
        #(a bool tensor; ByteTensor masks are deprecated in recent PyTorch)
        non_final_mask=torch.tensor(tuple(map(lambda s:s is not None,self.batch.next_state)),dtype=torch.bool)

        #initialize to 0 first; terminal states keep a next-state value of 0
        next_state_values=torch.zeros(BATCH_SIZE)

        a_m=torch.zeros(BATCH_SIZE).type(torch.LongTensor)

        #DDQN action selection: feed the next states into main_q_network and take the argmax action
        a_m[non_final_mask]=self.main_q_network(self.non_final_next_states).detach().max(1)[1]
        #keep only the entries that actually have a next state
        a_m_non_final_next_states=a_m[non_final_mask].view(-1,1)

        #DDQN action evaluation: feed the next states into target_q_network and read off the Q-value of that action
        next_state_values[non_final_mask]=self.target_q_network(self.non_final_next_states).gather(1,a_m_non_final_next_states).detach().squeeze()

        expected_state_action_values=self.reward_batch+GAMMA*next_state_values

        return expected_state_action_values
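
    #gather(1,index) picks, for each row, the column given by index; this is how
    #the Q-value of the chosen action is read out of the (batch,n_actions) output:
    #   q=torch.tensor([[1.,2.],[3.,4.]])
    #   q.gather(1,torch.tensor([[1],[0]]))   -> tensor([[2.],[3.]])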



    def update_main_q_network(self):

        #switch back to training mode
        self.main_q_network.train()

        #Huber loss between the predicted Q-values and the expected (target) Q-values
        loss=F.smooth_l1_loss(self.state_action_values,self.expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()



    def update_target_q_network(self):
        #hard update: copy the main network's weights into the target network
        self.target_q_network.load_state_dict(self.main_q_network.state_dict())
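
#note: the hard copy above swaps in the main network's weights wholesale every
#few episodes; a common alternative (not used in this post) is a Polyak "soft"
#update that blends the weights a little at every step; a sketch, with tau as
#an illustrative hyperparameter:
def soft_update(target_net,main_net,tau=0.005):
    #target <- tau*main + (1-tau)*target
    for t_param,m_param in zip(target_net.parameters(),main_net.parameters()):
        t_param.data.copy_(tau*m_param.data+(1.0-tau)*t_param.data)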



class Agent:
    def __init__(self,num_states,num_actions):
        self.brain=Brain(num_states,num_actions)


    def update_q_function(self):
        self.brain.replay()

    def get_action(self,state,episode):
        action=self.brain.decide_action(state,episode)
        return action 

    def memorize(self,state,action,state_next,reward):
        self.brain.memory.push(state,action,state_next,reward) 

    def update_target_q_function(self):
        self.brain.update_target_q_network()

#constants
ENV='CartPole-v0'   #name of the task to use
GAMMA=0.99          #discount factor
MAX_STEPS=200       #number of steps per episode
NUM_EPISODES=500    #maximum number of episodes


class Environment:

    def __init__(self):
        self.env=gym.make(ENV)
        #number of state variables and number of actions for this task
        self.num_states=self.env.observation_space.shape[0]
        self.num_actions=self.env.action_space.n

        #create the Agent that will act in the environment
        self.agent=Agent(self.num_states,self.num_actions)

    def run(self):

        episode_10_list=np.zeros(10)   #step counts of the most recent 10 episodes

        complete_episodes=0    #number of consecutive successful episodes
        episode_final=False    #flag marking the final episode

        frames=[]  #meant to store rendered frames so the last episode can be replayed (left unused here)

        for episode in range(NUM_EPISODES):

            print("==================>",episode)

            observation=self.env.reset()
            state=observation

            state=torch.from_numpy(state).type(torch.FloatTensor)   #convert the NumPy array to a PyTorch tensor
            #https://www.cnblogs.com/datasnail/p/13086803.html
            #add a batch dimension at dim 0: torch.Size([4]) -> torch.Size([1, 4])
            state=torch.unsqueeze(state,0)

            for step in range(MAX_STEPS):

                action=self.agent.get_action(state,episode)   #choose an action

                #e.g. if action=[[1]] then action.item()=1
                observation_next,_,done,_=self.env.step(action.item())

                #done=True in roughly two cases: the episode ran to the step limit, or the pole fell over
                if done:

                    state_next=None   #there is no next state

                    #https://blog.csdn.net/m0_37393514/article/details/79538748
                    #keep a sliding window of the step counts of the last 10 episodes
                    episode_10_list=np.hstack((episode_10_list[1:],step+1))

                    #decide which of the two cases it was
                    if step<195:
                        #fell over before 195 steps: penalty of -1, reset the success streak
                        reward=torch.FloatTensor([-1.0])
                        complete_episodes=0
                    else:
                        #stood for 195 steps or more: reward of +1
                        reward=torch.FloatTensor([1.0])
                        complete_episodes=complete_episodes+1

                else:
                    #the episode continues: reward 0 and prepare the next state
                    reward=torch.FloatTensor([0.0])
                    state_next=observation_next
                    state_next=torch.from_numpy(state_next).type(torch.FloatTensor)
                    state_next=torch.unsqueeze(state_next,0)

                self.agent.memorize(state,action,state_next,reward)

                #DDQN uses two networks

                #update the first (main) network
                self.agent.update_q_function()

                state=state_next

                if done:

                    print('%d Episode: finished after %d steps; mean steps over the last 10 episodes = %.1f'%(episode,step+1,episode_10_list.mean()))
                    if(episode % 2 == 0):
                        #update the second (target) network every 2 episodes
                        self.agent.update_target_q_function()
                    break 

            #stop once the final episode has been played
            if episode_final is True:
                break

            #after 10 consecutive successes, play one more (final) episode and then stop
            if complete_episodes>=10:

                print('Succeeded 10 episodes in a row')
                episode_final=True

cartpole_env=Environment()
cartpole_env.run()
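
One compatibility note: the code above assumes the classic Gym API, where reset() returns just the observation and step() returns 4 values. If you run it under the newer Gymnasium package instead, the call sites need a small adaptation; a sketch, assuming Gymnasium is installed:

import gymnasium as gym

env = gym.make('CartPole-v1')    # CartPole-v1 is the maintained version
observation, info = env.reset()  # reset() now also returns an info dict
observation_next, reward, terminated, truncated, info = env.step(0)
done = terminated or truncated   # the single done flag is split into two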
