网上找了很多 SAC 的 PyTorch 代码，但很多都不能用：要么实在太复杂，把环境和网络杂糅在一起；要么代码一运行就报错。
最后找到了一个还算能用的，代码地址：
https://github.com/higgsfield/RL-Adventure-2
代码用 PyTorch 编写，对应 SAC 第一篇，我是从这里看到的：https://zhuanlan.zhihu.com/p/75937178 。
我对它做了一些魔改，现在的代码如下：
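'''
第一篇SAC
Soft Actor-Critic variant with a state-value network V plus a target copy of it,
a single soft Q network, and a tanh-squashed Gaussian policy.
'''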
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
# Compute device for every network and replay-buffer tensor: the GPU when
# PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Problem / model sizes shared by the buffer, the networks and the agent.
action_dim = 2     # width of an action vector (policy output, Q-network input)
state_dim = 12     # width of an observation vector
hidden_dim = 256   # units in each hidden layer of every network
batch_size = 128   # transitions drawn from the replay buffer per SAC.update()
class ReplayBuffer:
    """Fixed-size circular buffer of transitions held in preallocated tensors on `device`.

    Once `capacity` transitions have been pushed, each new push overwrites the
    oldest slot.
    """

    def __init__(self, capacity):
        """Allocate zero-filled storage for `capacity` transitions.

        Args:
            capacity: maximum number of transitions retained.
        """
        self.capacity = capacity
        self.state_pool = torch.zeros(self.capacity, state_dim).float().to(device)
        self.action_pool = torch.zeros(self.capacity, action_dim).float().to(device)
        self.reward_pool = torch.zeros(self.capacity, 1).float().to(device)
        self.next_state_pool = torch.zeros(self.capacity, state_dim).float().to(device)
        self.done_pool = torch.zeros(self.capacity, 1).float().to(device)
        # Total number of pushes ever made; NOT clipped to capacity.
        self.num_transition = 0

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest slot when the buffer is full.

        Args:
            state, next_state: array-like, broadcastable to (state_dim,).
            action: array-like, broadcastable to (action_dim,).
            reward: scalar (broadcast into its single-element slot).
            done: bool or number; stored as 1.0 / 0.0.
        """
        index = self.num_transition % self.capacity
        pools = (self.state_pool, self.action_pool, self.reward_pool,
                 self.next_state_pool, self.done_pool)
        values = (state, action, reward, next_state, done)
        # no_grad: a grad-requiring input must not pull the pools into an autograd graph.
        with torch.no_grad():
            for pool, value in zip(pools, values):
                # as_tensor accepts lists, NumPy arrays and tensors alike (torch.tensor
                # warns when handed a tensor); the assignment copies into the pool.
                pool[index] = torch.as_tensor(value, dtype=torch.float32, device=device)
        self.num_transition += 1

    def sample(self, batch_size):
        """Draw `batch_size` distinct stored transitions uniformly at random.

        Only slots that have actually been written are eligible, so the
        zero-initialised, never-filled tail of the pools is never returned.

        Returns:
            (states, actions, rewards, next_states, dones), all indexed by the same rows.

        Raises:
            ValueError: if fewer than `batch_size` transitions are stored.
        """
        filled = min(self.num_transition, self.capacity)
        if batch_size > filled:
            raise ValueError(
                f"cannot sample {batch_size} transitions: only {filled} stored"
            )
        index = np.random.choice(filled, batch_size, replace=False)
        return (
            self.state_pool[index],
            self.action_pool[index],
            self.reward_pool[index],
            self.next_state_pool[index],
            self.done_pool[index],
        )
class ValueNetwork(nn.Module):
    """State-value estimator V(s): two ReLU hidden layers feeding a scalar head.

    The head's weight and bias are re-drawn from U(-init_w, init_w), so the
    initial outputs are small.
    """

    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
        super().__init__()
        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        # Weight first, then bias: same draw order as a pair of .uniform_ calls.
        for head_tensor in (self.linear3.weight, self.linear3.bias):
            nn.init.uniform_(head_tensor, -init_w, init_w)

    def forward(self, state):
        """Return V(state), with a trailing output dimension of size 1."""
        hidden = state
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        return self.linear3(hidden)
class SoftQNetwork(nn.Module):
    """Soft action-value estimator Q(s, a): [state, action] -> two ReLU layers -> scalar.

    The output layer's weight and bias are drawn from U(-init_w, init_w) so
    initial Q estimates are small.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(SoftQNetwork, self).__init__()
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        """Return Q(state, action) with a trailing output dimension of size 1.

        `state` and `action` must share all leading (batch) dimensions.
        """
        # Concatenate on the last (feature) axis. For (batch, features) inputs this
        # equals dim=1; it also supports a single unbatched (features,) pair and
        # extra leading dimensions.
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)
class PolicyNetwork(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    The network outputs the mean and clamped log standard deviation of a
    Gaussian over a pre-squash variable z; actions are tanh(z), so every
    component lies in [-1, 1].
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()
        # Clamp range for log_std applied in forward().
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)
        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Return (mean, log_std) of the pre-tanh Gaussian for `state`."""
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def evaluate(self, state, epsilon=1e-6):
        """Sample squashed actions for `state` and return their log-probabilities.

        z is drawn with ``Normal.sample()`` (not ``rsample``), so no gradient flows
        through z itself; gradients reach the policy via ``log_prob``'s dependence
        on mean and std. The log-density includes the tanh change-of-variables term
        log(1 - tanh(z)^2); `epsilon` keeps that log finite as |action| -> 1.

        Returns:
            action, log_prob (summed over the last axis, keepdim=True), z, mean, log_std
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.sample()
        action = torch.tanh(z)
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)
        return action, log_prob, z, mean, log_std

    def get_action(self, state, deterministic=False):
        """Return one action for a single, unbatched observation as a NumPy array.

        Args:
            state: array-like observation for one step.
            deterministic: if True, return tanh(mean) instead of sampling
                (e.g. for evaluation runs). Defaults to sampling.
        """
        # Place the input wherever this network's weights live, not on a global device.
        param_device = next(self.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32, device=param_device).unsqueeze(0)
        # Acting needs no autograd graph.
        with torch.no_grad():
            mean, log_std = self.forward(state)
            if deterministic:
                z = mean
            else:
                z = Normal(mean, log_std.exp()).sample()
            action = torch.tanh(z)
        return action.cpu().numpy()[0]
class SAC:
    """Soft Actor-Critic agent with a separate state-value network.

    Owns a value network V and a target copy of it, one soft Q network, a
    tanh-Gaussian policy, an Adam optimizer per trained network, and a replay
    buffer. `update()` performs one training step; `save()` / `load()` persist
    the policy, value and Q networks under ./SAC_model/.
    """

    def __init__(self):
        self.gamma = 0.99  # discount factor
        # Weights of the mean / log-std / pre-tanh-z penalties added to the policy loss.
        self.mean_lambda = 1e-3
        self.std_lambda = 1e-3
        self.z_lambda = 0.0
        # Polyak rate: target <- (1 - soft_tau) * target + soft_tau * value, once per update().
        self.soft_tau = 1e-2
        self.value_net = ValueNetwork(state_dim, hidden_dim).to(device)
        self.target_value_net = ValueNetwork(state_dim, hidden_dim).to(device)
        self.soft_q_net = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
        # The target starts as an exact copy of the online value network.
        self._hard_update_target()
        self.value_criterion = nn.MSELoss()
        self.soft_q_criterion = nn.MSELoss()
        self.value_lr = 3e-4
        self.soft_q_lr = 3e-4
        self.policy_lr = 3e-4
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=self.value_lr)
        self.soft_q_optimizer = optim.Adam(self.soft_q_net.parameters(), lr=self.soft_q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=self.policy_lr)
        self.replay_buffer_size = 1000
        self.replay_buffer = ReplayBuffer(self.replay_buffer_size)

    def select_action(self, state):
        """Sample an action for a single observation from the current policy."""
        return self.policy_net.get_action(state)

    def update(self):
        """Run one gradient step on the Q, value and policy networks.

        Then move the target value network toward the value network.
        Samples `batch_size` transitions from the replay buffer. Returns None.
        """
        state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)

        expected_q_value = self.soft_q_net(state, action)
        expected_value = self.value_net(state)
        new_action, log_prob, z, mean, log_std = self.policy_net.evaluate(state)

        # Q target: r + gamma * (1 - done) * V_target(s'); a fixed regression target.
        with torch.no_grad():
            target_value = self.target_value_net(next_state)
            next_q_value = reward + (1 - done) * self.gamma * target_value
        q_value_loss = self.soft_q_criterion(expected_q_value, next_q_value)

        # V target: Q(s, a_new) - log pi(a_new | s), with a_new sampled from the policy.
        expected_new_q_value = self.soft_q_net(state, new_action)
        next_value = expected_new_q_value - log_prob
        value_loss = self.value_criterion(expected_value, next_value.detach())

        # Score-function-style policy loss (z was sampled without reparameterisation):
        # pushes log pi toward Q - V; the bracketed factor is treated as a constant.
        log_prob_target = expected_new_q_value - expected_value
        policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean()
        mean_loss = self.mean_lambda * mean.pow(2).mean()
        std_loss = self.std_lambda * log_std.pow(2).mean()
        z_loss = self.z_lambda * z.pow(2).sum(1).mean()
        policy_loss += mean_loss + std_loss + z_loss

        self.soft_q_optimizer.zero_grad()
        q_value_loss.backward()
        self.soft_q_optimizer.step()

        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Without this the target V stays at its initial weights forever.
        self._soft_update_target()
        print('update successed')

    def _hard_update_target(self):
        """Copy value_net's weights into target_value_net."""
        self.target_value_net.load_state_dict(self.value_net.state_dict())

    def _soft_update_target(self):
        """Polyak-average value_net into target_value_net at rate self.soft_tau."""
        tau = self.soft_tau
        with torch.no_grad():
            for target_param, param in zip(self.target_value_net.parameters(),
                                           self.value_net.parameters()):
                target_param.copy_(target_param * (1.0 - tau) + param * tau)

    def save(self):
        """Write policy, value and soft-Q weights to ./SAC_model/, creating it if needed."""
        import os  # local import keeps this method self-contained
        os.makedirs('./SAC_model', exist_ok=True)
        torch.save(self.policy_net.state_dict(), './SAC_model/policy_net.pth')
        torch.save(self.value_net.state_dict(), './SAC_model/value_net.pth')
        torch.save(self.soft_q_net.state_dict(), './SAC_model/soft_q_net.pth')
        print("====================================")
        print("Model has been saved...")
        print("====================================")

    def load(self):
        """Restore weights written by save() and re-sync the target value network.

        Note: torch.load unpickles the file; only load checkpoints you trust.
        """
        # map_location lets checkpoints written on a GPU be restored on a CPU-only machine.
        self.policy_net.load_state_dict(torch.load('./SAC_model/policy_net.pth', map_location=device))
        self.value_net.load_state_dict(torch.load('./SAC_model/value_net.pth', map_location=device))
        self.soft_q_net.load_state_dict(torch.load('./SAC_model/soft_q_net.pth', map_location=device))
        # The target network is not checkpointed; align it with the restored value net.
        self._hard_update_target()
        print("====================================")
        print("model has been loaded...")
        print("====================================")