Today I'd like to share how to play a game with DQN, a classic reinforcement learning algorithm.
With the continuous development of artificial intelligence, machine learning has become more and more important, and many people have started learning it. This post introduces DQN, a classic model in reinforcement learning.
Reinforcement learning algorithms fall into three broad categories: value-based, policy-based, and actor-critic. A typical value-based algorithm is DQN, which has only a value-function network and no policy network; actor-critic algorithms such as DDPG and TRPO have both a value-function network and a policy network.
P.S. For the theory, I recommend Professor Wang Shusen's lecture series, which explains it very well:
https://www.bilibili.com/video/BV12o4y197US?p=2&vd_source=c282742e9b92317ec46a907e78c4fa64
This post focuses on how to use the TD (temporal-difference) algorithm to update the parameters of the DQN network. What we collect at each step is the current state, the action taken, the reward received, and the next state. By learning from such transitions over and over, the DQN gradually approximates the true optimal Q network. After training, feeding an observed state into the network produces a score for each action.
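In formula terms, the update minimizes the squared difference between Q_eval(s, a) and the TD target y = r + GAMMA * max_a' Q_target(s', a'). Below is a minimal sketch of that target computation; the function name td_target and its arguments are illustrative only, and the actual logic lives in Dqn.learn() further down.

import torch

def td_target(reward, next_q_values, gamma=0.9):
    # reward:        shape (batch, 1)
    # next_q_values: Q_target(s', ·) for all actions, shape (batch, num_actions)
    # returns y = r + gamma * max_a' Q_target(s', a'), shape (batch, 1)
    return reward + gamma * next_q_values.max(dim=1, keepdim=True)[0]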
The code is as follows (imports and hyperparameters):
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import matplotlib.pyplot as plt
import gym

# hyperparameters
EPSILON = 0.9               # probability of choosing the greedy (highest-scoring) action
GAMMA = 0.9                 # discount factor
LR = 0.01                   # learning rate
MEMORY_CAPACITY = 2000      # size of the replay buffer
Q_NETWORK_ITERATION = 100   # sync the target network every 100 learning steps
BATCH_SIZE = 32
EPISODES = 400

env = gym.make('MountainCar-v0')  # 3 actions: push left, no push, push right
env = env.unwrapped
NUM_STATES = env.observation_space.shape[0]  # 2 (position, velocity)
NUM_ACTIONS = env.action_space.n             # 3
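As a quick sanity check (not part of the original script), you can print the spaces to confirm the two state dimensions and the three discrete actions:

print(env.observation_space)    # Box with 2 dimensions: car position and velocity
print(env.action_space)         # Discrete(3): push left, no push, push right
print(NUM_STATES, NUM_ACTIONS)  # 2 3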
The code is as follows (DQN agent and network definition):
class Dqn():
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()
        self.memory = np.zeros((MEMORY_CAPACITY, NUM_STATES * 2 + 2))
        # each row stores: state, action, reward and next state
        self.memory_counter = 0
        self.learn_counter = 0
        self.optimizer = optim.Adam(self.eval_net.parameters(), LR)
        self.loss = nn.MSELoss()
        self.fig, self.ax = plt.subplots()

    def store_trans(self, state, action, reward, next_state):
        if self.memory_counter % 500 == 0:
            print("The experience pool has collected {} transitions".format(self.memory_counter))
        index = self.memory_counter % MEMORY_CAPACITY  # overwrite the oldest transition when full
        trans = np.hstack((state, [action], [reward], next_state))  # one transition record
        self.memory[index, :] = trans
        self.memory_counter += 1

    def choose_action(self, state):
        # note that the function returns the action's index, not the real action
        state = torch.unsqueeze(torch.FloatTensor(state), 0)
        if np.random.rand() <= EPSILON:  # exploit: with probability EPSILON pick the greedy action
            action_value = self.eval_net.forward(state)  # scores for every action
            action = torch.max(action_value, 1)[1].data.numpy()  # index of the highest-scoring action
            action = action[0]  # get the action index
        else:  # explore: otherwise pick a random action
            action = np.random.randint(0, NUM_ACTIONS)
        return action

    def plot(self, ax, x):
        ax.cla()
        ax.set_xlabel("episode")
        ax.set_ylabel("total reward")
        ax.plot(x, 'b-')
        plt.pause(1e-15)

    def learn(self):
        # every Q_NETWORK_ITERATION learning steps, sync the target network
        if self.learn_counter % Q_NETWORK_ITERATION == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())  # copy the eval net's weights
        self.learn_counter += 1

        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)  # sample a batch of transitions
        batch_memory = self.memory[sample_index, :]
        batch_state = torch.FloatTensor(batch_memory[:, :NUM_STATES])
        # note that the action must be an int
        batch_action = torch.LongTensor(batch_memory[:, NUM_STATES:NUM_STATES + 1].astype(int))
        batch_reward = torch.FloatTensor(batch_memory[:, NUM_STATES + 1:NUM_STATES + 2])
        batch_next_state = torch.FloatTensor(batch_memory[:, -NUM_STATES:])

        q_eval = self.eval_net(batch_state).gather(1, batch_action)  # current Q(s, a)
        q_next = self.target_net(batch_next_state).detach()  # Q_target(s', ·); take the max below
        q_target = batch_reward + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)  # TD target
        loss = self.loss(q_eval, q_target)  # the smaller the gap, the better

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(NUM_STATES, 30)
        self.fc1.weight.data.normal_(0, 0.1)
        self.fc2 = nn.Linear(30, NUM_ACTIONS)
        self.fc2.weight.data.normal_(0, 0.1)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
The code is as follows (training loop):
def main():
    net = Dqn()
    print("The DQN is collecting experience...")
    step_counter_list = []
    for episode in range(EPISODES):
        state = env.reset()
        step_counter = 0
        while True:
            step_counter += 1
            env.render()
            action = net.choose_action(state)
            next_state, reward, done, info = env.step(action)
            # scale the reward (a simple shaping heuristic) to speed up learning
            reward = reward * 100 if reward > 0 else reward * 5
            net.store_trans(state, action, reward, next_state)  # save this transition
            if net.memory_counter >= MEMORY_CAPACITY:  # start learning once the buffer is full
                net.learn()
            if done:
                print("episode {}, the reward is {}".format(episode, round(reward, 3)))
                step_counter_list.append(step_counter)
                net.plot(net.ax, step_counter_list)
                break
            state = next_state


if __name__ == '__main__':
    main()
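Once training finishes, you can check the learned policy by running it greedily (always taking the arg-max action, with no random exploration). This is a minimal sketch, not part of the original script; evaluate is an illustrative helper that reuses the trained agent and the same env:

def evaluate(net, env, episodes=5):
    # roll out the trained eval_net greedily and report the raw episode return
    for ep in range(episodes):
        state = env.reset()
        total_reward, done = 0.0, False
        while not done:
            state_t = torch.unsqueeze(torch.FloatTensor(state), 0)
            action = int(net.eval_net(state_t).argmax(dim=1).item())
            state, reward, done, info = env.step(action)
            total_reward += reward
        print("evaluation episode {}: total reward {}".format(ep, total_reward))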
DQN is a very simple, entry-level reinforcement learning model built mainly on the temporal-difference (TD) algorithm. Its performance is not as strong as more advanced models, but it is still usable. In later posts I will cover other reinforcement learning networks whose results are better than DQN's.