import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import math
import random
import os
import gym
# Hyper Parameters
# Length of one observation vector; used as the input size of both networks.
STATE_DIM = 4
# Number of discrete actions; used as the actor's output size and one-hot width.
ACTION_DIM = 2
# Number of training iterations (one roll-out plus one update each) run by main().
STEP = 2000
# Maximum number of environment steps collected per roll-out.
SAMPLE_NUMS = 30
class ActorNetwork(nn.Module):
    """Policy network: two ReLU hidden layers followed by a log-softmax head.

    Args:
        input_size: Number of features in one input vector.
        hidden_size: Width of both hidden layers.
        action_size: Number of discrete actions (size of the output axis).
    """

    def __init__(self, input_size, hidden_size, action_size):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        """Return log-probabilities over actions for ``x``.

        Args:
            x: Tensor whose last axis has ``input_size`` features.

        Returns:
            Tensor with the same leading axes as ``x`` and ``action_size``
            entries on the last axis; ``exp`` of it sums to 1 along that axis.
        """
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        # Normalise explicitly over the action axis. The implicit-dim form is
        # deprecated and picks dim 0 for 3-D input, which is not the action axis.
        return F.log_softmax(self.fc3(out), dim=-1)
class ValueNetwork(nn.Module):
    """State-value estimator: two ReLU hidden layers and a linear output layer.

    Args:
        input_size: Number of features in one input vector.
        hidden_size: Width of both hidden layers.
        output_size: Number of values produced per input.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in fc1, fc2, fc3 order so that parameter
        # initialisation consumes the RNG in the same sequence as before.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return the un-activated output of the final linear layer for ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)
def roll_out(actor_network, task, sample_nums, value_network, init_state):
    """Collect up to ``sample_nums`` transitions from ``task`` with the actor's policy.

    Actions are sampled from the distribution produced by ``actor_network``
    (which must return log-probabilities). Collection stops early when the
    environment reports ``done``; in that case the environment is reset and
    the fresh state is returned so the caller can continue from it.

    Args:
        actor_network: Callable mapping a ``(1, state_dim)`` float tensor to a
            ``(1, num_actions)`` tensor of log-probabilities.
        task: Environment exposing ``step(action) -> (next_state, reward,
            done, info)`` and ``reset() -> state``.
        sample_nums: Maximum number of steps to collect.
        value_network: Callable mapping a ``(1, state_dim)`` float tensor to a
            value estimate; used to bootstrap when the episode did not end.
        init_state: State to start collecting from.

    Returns:
        Tuple ``(states, actions, rewards, final_r, state)`` where ``states``,
        ``actions`` (one-hot lists) and ``rewards`` are per-step lists,
        ``final_r`` is ``0`` if the episode ended and otherwise the value
        network's output for the last reached state as a numpy array, and
        ``state`` is the state to start the next roll-out from.
    """
    states = []
    actions = []
    rewards = []
    is_done = False
    final_r = 0
    state = init_state
    # If the loop never runs (sample_nums == 0) bootstrap from the start state
    # instead of hitting an unbound local below.
    final_state = init_state

    # Only sampled actions and a numeric bootstrap value leave this function,
    # so no autograd graph is needed while interacting with the environment.
    with torch.no_grad():
        for _step in range(sample_nums):
            states.append(state)
            obs = torch.as_tensor(np.asarray(state, dtype=np.float32)).unsqueeze(0)
            action_probs = torch.exp(actor_network(obs))[0].cpu().numpy()
            # Take the action count from the policy output so the sampling
            # range always matches the probability vector.
            num_actions = len(action_probs)
            action = np.random.choice(num_actions, p=action_probs)
            one_hot_action = [int(k == action) for k in range(num_actions)]
            next_state, reward, done, _info = task.step(action)
            actions.append(one_hot_action)
            rewards.append(reward)
            final_state = next_state
            state = next_state
            if done:
                is_done = True
                state = task.reset()
                break

        if not is_done:
            final_obs = torch.as_tensor(
                np.asarray(final_state, dtype=np.float32)
            ).unsqueeze(0)
            final_r = value_network(final_obs).cpu().numpy()

    return states, actions, rewards, final_r, state
def discount_reward(r, gamma, final_r):
    """Compute discounted returns for a reward sequence, bootstrapped by ``final_r``.

    For each step ``t`` the result holds
    ``r[t] + gamma * r[t+1] + ... + gamma**(T-t) * final_r``.

    Args:
        r: Sequence of per-step rewards.
        gamma: Discount factor.
        final_r: Value added after the last reward (``0`` for a finished
            episode). May be a number or a single-element array.

    Returns:
        ``float64`` numpy array with the same shape as ``r``.
    """
    # Always accumulate in floating point: inheriting an integer dtype from
    # integer rewards would silently truncate the discounted values.
    rewards = np.asarray(r, dtype=np.float64)
    discounted_r = np.zeros_like(rewards)

    # Collapse a single-element array (e.g. a (1, 1) network output) to a
    # plain float so it can be stored into scalar slots.
    bootstrap = np.asarray(final_r, dtype=np.float64)
    running_add = bootstrap.item() if bootstrap.size == 1 else bootstrap

    for t in reversed(range(len(rewards))):
        running_add = running_add * gamma + rewards[t]
        discounted_r[t] = running_add
    return discounted_r
def main():
    """Train an advantage actor-critic agent on CartPole-v0 and print test scores.

    Each of the ``STEP`` iterations collects up to ``SAMPLE_NUMS`` transitions
    with ``roll_out``, updates the actor with a policy-gradient loss weighted
    by the advantage (discounted return minus value estimate), and regresses
    the value network onto the discounted returns. Every 50 iterations the
    greedy policy is run for 10 test episodes and the mean return is printed.

    NOTE: this uses the classic gym interface, where ``reset()`` returns the
    observation and ``step()`` returns a 4-tuple.

    Returns:
        None.
    """
    gamma = 0.99
    max_grad_norm = 0.5
    test_interval = 50
    test_episodes = 10
    test_max_steps = 200

    # init a task generator for data fetching
    task = gym.make("CartPole-v0")
    init_state = task.reset()
    # One evaluation environment is reused instead of building one per test.
    test_task = gym.make("CartPole-v0")

    # init value network
    value_network = ValueNetwork(input_size=STATE_DIM, hidden_size=40, output_size=1)
    value_network_optim = torch.optim.Adam(value_network.parameters(), lr=0.01)

    # init actor network
    actor_network = ActorNetwork(STATE_DIM, 40, ACTION_DIM)
    actor_network_optim = torch.optim.Adam(actor_network.parameters(), lr=0.01)

    criterion = nn.MSELoss()
    steps = []
    test_results = []

    try:
        for step in range(STEP):
            states, actions, rewards, final_r, current_state = roll_out(
                actor_network, task, SAMPLE_NUMS, value_network, init_state
            )
            init_state = current_state

            actions_var = torch.as_tensor(actions, dtype=torch.float32).view(-1, ACTION_DIM)
            states_var = torch.as_tensor(
                np.asarray(states, dtype=np.float32)
            ).view(-1, STATE_DIM)

            # final_r is 0 or a single-element array; reduce it to a float.
            bootstrap_value = float(np.asarray(final_r).reshape(-1)[0])
            # calculate qs, shape (N,)
            qs = torch.as_tensor(
                discount_reward(rewards, gamma, bootstrap_value), dtype=torch.float32
            ).view(-1)

            # Squeeze the value output from (N, 1) to (N,). Without this,
            # `qs - values` and the MSE both broadcast to an (N, N) matrix
            # that pairs every return with every state's value.
            values = value_network(states_var).squeeze(-1)
            advantages = qs - values.detach()

            # train actor network
            actor_network_optim.zero_grad()
            log_softmax_actions = actor_network(states_var)
            chosen_log_probs = torch.sum(log_softmax_actions * actions_var, 1)
            actor_network_loss = -torch.mean(chosen_log_probs * advantages)
            actor_network_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor_network.parameters(), max_grad_norm)
            actor_network_optim.step()

            # train value network
            value_network_optim.zero_grad()
            value_network_loss = criterion(values, qs)
            value_network_loss.backward()
            torch.nn.utils.clip_grad_norm_(value_network.parameters(), max_grad_norm)
            value_network_optim.step()

            # Testing
            if (step + 1) % test_interval == 0:
                result = 0
                for _test_epi in range(test_episodes):
                    state = test_task.reset()
                    for _test_step in range(test_max_steps):
                        obs = torch.as_tensor(
                            np.asarray(state, dtype=np.float32)
                        ).unsqueeze(0)
                        with torch.no_grad():
                            softmax_action = torch.exp(actor_network(obs))
                        # Greedy action: pick the most probable one.
                        action = np.argmax(softmax_action.numpy()[0])
                        next_state, reward, done, _ = test_task.step(action)
                        result += reward
                        state = next_state
                        if done:
                            break
                print("step:", step + 1, "test result:", result / 10.0)
                steps.append(step + 1)
                test_results.append(result / 10)
    finally:
        # Release environment resources even if training is interrupted.
        test_task.close()
        task.close()
# Run training only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()