【深度强化学习】使用 TensorFlow 2.x 训练 MuZero 玩五子棋 (Gomoku)
github代码地址:https://github.com/NickNameHaveBeenSwallowed/muzero-tensorflow2.x
参考资料:
[1]ColinFred. 蒙特卡洛树搜索(MCTS)代码详解【python】. 2019-03-23 23:37:09.
[2]饼干Japson 深度强化学习实验室.【论文深度研读报告】MuZero算法过程详解.2021-01-19.
[3]Tangarf. Muzero算法研读报告. 2020-08-31 11:40:20.
[4]带带弟弟好吗. AlphaGo版本三——MuZero. 2020-08-30.
[5]原论文(Google DeepMind): Schrittwieser J, et al. Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model. Nature, 2020.
[6]参考GitHub代码1.
[7]参考GitHub代码2.
import tensorflow as tf
import numpy as np
# Number of residual blocks stacked in each of the three sub-networks.
num_blocks = 6
# Validate with an explicit raise: an `assert` would be stripped under `python -O`.
if num_blocks < 1:
    raise ValueError("残差块的数量必须大于等于1")
# L2 factor handed to every tf.keras.regularizers.l2(...) kernel regularizer.
l2 = 1e-4
def hidden_state_norm(x):
    """Min-max rescale `x` over axes 1 and 2, separately for every other index.

    For the (batch, height, width, channels) feature maps this file feeds in,
    each channel of each sample is mapped into (0, 1). The 1e-6 margins keep
    the denominator non-zero when a channel is constant.
    """
    lower = tf.reduce_min(x, axis=(1, 2), keepdims=True) - 1e-6
    upper = tf.reduce_max(x, axis=(1, 2), keepdims=True) + 1e-6
    return (x - lower) / (upper - lower)
class ResidualBlock(tf.keras.Model):
    """Residual unit: conv3x3-BN-ReLU-conv3x3-BN, summed with a skip path, then ReLU.

    The skip path is the identity when the stride is 1 and the channel count
    is unchanged. Otherwise it is a 1x1 convolution plus batch normalization,
    so both branches have the same shape before they are added.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, strides=1):
        super(ResidualBlock, self).__init__()
        # NOTE(review): sub-layers are created in the same order as before
        # (conv1, bn1, conv2, bn2, shortcut). Weight lists and saved .h5 files
        # presumably follow this order, so do not reorder.
        self.conv1 = tf.keras.layers.Conv2D(
            out_channels,
            kernel_size=3,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(l2),
        )
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(
            out_channels,
            kernel_size=3,
            strides=1,
            padding="same",
            use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(l2),
        )
        self.bn2 = tf.keras.layers.BatchNormalization()

        skip_channels = self.expansion * out_channels
        if strides == 1 and in_channels == skip_channels:
            # Shapes already match: pass the input through unchanged.
            self.shortcut = lambda inputs, _training: inputs
        else:
            # Project the input so it can be summed with the residual branch.
            self.shortcut = tf.keras.Sequential([
                tf.keras.layers.Conv2D(
                    skip_channels,
                    kernel_size=1,
                    strides=strides,
                    use_bias=False,
                    kernel_regularizer=tf.keras.regularizers.l2(l2),
                ),
                tf.keras.layers.BatchNormalization(),
            ])

    def call(self, x, training=False):
        """Apply the block; `training` is forwarded to the batch-norm layers."""
        residual = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        residual = self.bn2(self.conv2(residual), training=training)
        merged = residual + self.shortcut(x, training)
        return tf.nn.relu(merged)
class representation:
    """Representation network: observation -> min-max normalized hidden state.

    A stack of `num_blocks` ResidualBlocks maps the observation channels to
    `hidden_state_channel` channels, keeping the spatial size.
    """

    def __init__(self, observation_shape, hidden_state_channel):
        observation = tf.keras.Input(shape=observation_shape)
        x = observation
        # Only the first block sees the raw observation channels.
        channels_in = observation_shape[-1]
        for _ in range(num_blocks):
            x = ResidualBlock(
                in_channels=channels_in,
                out_channels=hidden_state_channel,
            )(x)
            channels_in = hidden_state_channel
        hidden_state = hidden_state_norm(x)
        self.model = tf.keras.Model(inputs=observation, outputs=hidden_state)
        self.trainable_variables = self.model.trainable_variables

    def predict(self, observation):
        """Encode a single (unbatched) observation and return its hidden state as a NumPy array."""
        batch = np.array([observation])
        encoded = self.model(batch)
        return np.array(encoded[0])
class dynamics:
    """Dynamics network: (hidden state, action) -> next min-max normalized hidden state.

    The action is given as a one-hot plane of shape (num_chess, num_chess, 1).
    It is concatenated to the hidden state, which is why the first block takes
    `hidden_state_channel + 1` input channels.
    """

    def __init__(self, hidden_state_shape, hidden_state_channel, num_chess):
        self.num_chess = num_chess
        hidden_state = tf.keras.Input(shape=hidden_state_shape)
        action = tf.keras.Input(shape=(num_chess, num_chess, 1))
        x = tf.keras.layers.concatenate([hidden_state, action])
        # The first block sees the extra action channel; the rest do not.
        channels_in = hidden_state_channel + 1
        for _ in range(num_blocks):
            x = ResidualBlock(
                in_channels=channels_in,
                out_channels=hidden_state_channel,
            )(x)
            channels_in = hidden_state_channel
        next_hidden_state = hidden_state_norm(x)
        self.model = tf.keras.Model(inputs=[hidden_state, action], outputs=next_hidden_state)
        self.trainable_variables = self.model.trainable_variables

    def predict(self, hidden_state, action):
        """Advance one (unbatched) hidden state by the flat cell index `action`."""
        batch = np.array([hidden_state])
        cells = self.num_chess ** 2
        one_hot = np.array([1 if cell == action else 0 for cell in range(cells)])
        action_plane = np.reshape(one_hot, (1, self.num_chess, self.num_chess, 1))
        stepped = self.model([batch, action_plane])
        return np.array(stepped[0])
class prediction:
    """Prediction network: hidden state -> (policy, value).

    The policy is a softmax over the num_chess**2 board cells and the value is
    a single tanh output. Both heads share the residual trunk and each has its
    own conv / BN / ReLU / dense stack.
    """

    def __init__(self, hidden_state_shape, hidden_state_channel, num_chess):
        hidden_state = tf.keras.Input(shape=hidden_state_shape)
        x = hidden_state
        for _ in range(num_blocks):
            x = ResidualBlock(
                in_channels=hidden_state_channel,
                out_channels=hidden_state_channel,
            )(x)
        # Layers are created policy-head-first, matching the original order.
        policy = self._head_trunk(x)
        policy = tf.keras.layers.Dense(
            units=num_chess ** 2, activation='softmax',
            kernel_regularizer=tf.keras.regularizers.l2(l2),
        )(policy)
        value = self._head_trunk(x)
        value = tf.keras.layers.Dense(
            units=1, activation='tanh',
            kernel_regularizer=tf.keras.regularizers.l2(l2),
        )(value)
        self.model = tf.keras.Model(inputs=hidden_state, outputs=[policy, value])
        self.trainable_variables = self.model.trainable_variables

    @staticmethod
    def _head_trunk(features):
        """Shared head layout: 3x3 conv (32) -> BN -> ReLU -> flatten -> dense(1024, ReLU)."""
        y = tf.keras.layers.Conv2D(
            filters=32, kernel_size=3, strides=1, padding="SAME", use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(l2),
        )(features)
        y = tf.keras.layers.BatchNormalization()(y)
        y = tf.keras.layers.Activation('relu')(y)
        y = tf.keras.layers.Flatten()(y)
        return tf.keras.layers.Dense(
            units=1024, activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2),
        )(y)

    def predict(self, hidden_state):
        """Return (policy vector, 0-d value array) for one unbatched hidden state."""
        policy_batch, value_batch = self.model(np.array([hidden_state]))
        return np.array(policy_batch[0]), np.array(value_batch[0][0])
class model:
    """Bundle of the representation, dynamics and prediction networks.

    Weights are saved and loaded as three .h5 files that share the prefix `path`.
    """

    # (attribute name, file suffix) for each sub-network, in save/load order.
    _WEIGHT_SUFFIXES = (
        ("representation", "-representation.h5"),
        ("dynamics", '-dynamics.h5'),
        ("prediction", '-prediction.h5'),
    )

    def __init__(self, observation_shape, hidden_state_channel, num_chess):
        self.representation = representation(observation_shape, hidden_state_channel)
        # The hidden state keeps the board's spatial size.
        height, width = observation_shape[0], observation_shape[1]
        hidden_state_shape = (height, width, hidden_state_channel)
        self.dynamics = dynamics(hidden_state_shape, hidden_state_channel, num_chess)
        self.prediction = prediction(hidden_state_shape, hidden_state_channel, num_chess)
        self.trainable_variables = [
            *self.representation.trainable_variables,
            *self.dynamics.trainable_variables,
            *self.prediction.trainable_variables,
        ]

    def save_weights(self, path):
        """Write each sub-network's weights to `path` + its suffix."""
        for attr, suffix in self._WEIGHT_SUFFIXES:
            getattr(self, attr).model.save_weights(path + suffix)

    def load_weights(self, path):
        """Read each sub-network's weights from `path` + its suffix."""
        for attr, suffix in self._WEIGHT_SUFFIXES:
            getattr(self, attr).model.load_weights(path + suffix)

    def copy_weights(self, target_model):
        """Overwrite this bundle's weights with those of `target_model`."""
        for attr, _suffix in self._WEIGHT_SUFFIXES:
            source = getattr(target_model, attr).model
            getattr(self, attr).model.set_weights(source.get_weights())
import numpy as np
# Default exploration constants of the PUCT score in ucb_score_Atari / ucb_score_Chess.
PB_C_INIT = 1.25   # additive term of the exploration coefficient
PB_C_BASE = 19_652  # visit-count scale inside the log term
class MinMax:
    """Track the running range of values seen so far and rescale values into it."""

    def __init__(self):
        # Start with an inverted (empty) range so the first update sets both ends.
        self.maximum = -float("inf")
        self.minimum = float("inf")

    def update(self, value):
        """Widen the tracked range so that it contains `value`."""
        if value > self.maximum:
            self.maximum = value
        if value < self.minimum:
            self.minimum = value

    def normalize(self, value):
        """Map `value` linearly so the tracked min -> 0 and max -> 1.

        Until at least two distinct values have been seen there is no usable
        range, and `value` is returned unchanged.
        """
        if not self.maximum > self.minimum:
            return value
        return (value - self.minimum) / (self.maximum - self.minimum)
class TreeNode:
    """A node of the MCTS tree built over learned hidden states."""

    def __init__(self):
        self.parent = None        # node this one was expanded from; None for the root
        self.prior = 1.0          # prior probability assigned by the parent's policy
        self.hidden_state = None  # filled in when the node is expanded
        self.children = {}        # action index -> child TreeNode
        self.visit_count = 0
        self.reward = 0
        self.Q = 0                # running value estimate, updated by backpropagation

    def is_leaf_Node(self):
        """True while the node has not been expanded (it has no children)."""
        return not self.children

    def is_root_Node(self):
        """True for the node at the top of the tree (it has no parent)."""
        return self.parent is None
def add_exploration_noise(node, dirichlet_alpha=0.3, exploration_fraction=0.25):
    """Blend Dirichlet noise into the priors of `node`'s children, in place.

    Each child's new prior is
    ``prior * (1 - exploration_fraction) + noise * exploration_fraction``,
    with one noise sample per child from Dirichlet(dirichlet_alpha, ...).
    """
    children = node.children
    noise = np.random.dirichlet([dirichlet_alpha] * len(children))
    keep = 1 - exploration_fraction
    for child, eps in zip(children.values(), noise):
        child.prior = child.prior * keep + eps * exploration_fraction
def ucb_score_Atari(node, minmax, pb_c_init=PB_C_INIT, pb_c_base=PB_C_BASE):
    """PUCT score of `node` as seen from its parent.

    score = normalize(node.Q)
            + prior * (log((N_parent + pb_c_base + 1) / pb_c_base) + pb_c_init)
                    * sqrt(N_parent) / (N_node + 1)

    NOTE(review): backpropagate_Atari folds ``reward + discount * value`` into
    the parent's Q, but the value term here uses only node.Q (no reward,
    no discount). Confirm whether node.reward should be included.
    """
    n_parent = node.parent.visit_count
    pb_c = np.log((n_parent + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c = pb_c * (np.sqrt(n_parent) / (node.visit_count + 1))
    exploration_bonus = pb_c * node.prior
    return minmax.normalize(node.Q) + exploration_bonus
def select_argmax_pUCB_child_Atari(node, minmax):
    """Return the ``(action, child)`` pair of `node` with the highest ucb_score_Atari.

    Ties go to the first child in iteration order; raises ValueError if `node`
    has no children.
    """
    def score(item):
        _action, child = item
        return ucb_score_Atari(child, minmax)

    return max(node.children.items(), key=score)
def expand_Atari(node, model):
    """Evaluate `node` with the learned model and attach one child per action.

    Non-root nodes first get their hidden state (and reward) from the dynamics
    network applied to the parent's hidden state and this node's action. The
    prediction network then sets node.Q to the value estimate, and each
    child's prior comes from the policy.

    NOTE(review): this unpacks two values from ``model.dynamics.predict``. The
    ``dynamics.predict`` defined alongside this file returns only the next
    hidden state, so this variant presumably expects a dynamics network that
    also predicts reward.
    """
    parent = node.parent
    if parent is not None:
        node.hidden_state, node.reward = model.dynamics.predict(parent.hidden_state, node.action)
    policy, value = model.prediction.predict(node.hidden_state)
    node.Q = value
    for action in range(len(policy)):
        child = TreeNode()
        child.action = action
        child.prior = policy[action]
        child.parent = node
        node.children[action] = child
def backpropagate_Atari(node, minmax, discount):
    """Propagate the leaf's value estimate from `node` up to the root.

    Every node on the path gets one more visit and has its Q recorded in
    `minmax`. Moving up one level turns the value into
    ``reward + discount * value``, and the parent's Q becomes the running
    mean over its visits including this new sample.
    """
    value = node.Q
    current = node
    while True:
        current.visit_count += 1
        minmax.update(current.Q)
        if current.is_root_Node():
            return
        value = current.reward + discount * value
        current = current.parent
        current.Q = (current.Q * current.visit_count + value) / (current.visit_count + 1)
class MCTS_Atari:
    """Single-agent MuZero-style tree search in the learned model (discounted rewards)."""

    def __init__(self, model, observation):
        self.root_Node = TreeNode()
        self.model = model
        self.root_Node.hidden_state = self.model.representation.predict(observation)
        self.minmax = MinMax()

    def simulations(self, num_simulation, discount, add_noise=True):
        """Run the search; return ({action: root child visit count}, root Q).

        One extra iteration runs because the first one only expands the root.
        Dirichlet noise is added to the root's priors right after that
        expansion when `add_noise` is set.
        """
        root = self.root_Node
        for _ in range(num_simulation + 1):
            leaf = root
            while not leaf.is_leaf_Node():
                _action, leaf = select_argmax_pUCB_child_Atari(leaf, self.minmax)
            expand_Atari(leaf, self.model)
            if leaf == root and add_noise:
                add_exploration_noise(leaf)
            backpropagate_Atari(leaf, self.minmax, discount)
        visits = {action: child.visit_count for action, child in root.children.items()}
        return visits, root.Q

    def __str__(self):
        return "Muzero_MCTS_Atari"
def ucb_score_Chess(node, minmax, pb_c_init=PB_C_INIT, pb_c_base=PB_C_BASE):
    """PUCT score the parent of `node` uses to rank `node` among its siblings.

    score = value_term
            + prior * (log((N_parent + pb_c_base + 1) / pb_c_base) + pb_c_init)
                    * sqrt(N_parent) / (N_node + 1)

    Perspective of the value term:
    - Value targets are +1 for the side to move that goes on to win, because
      self_play_worker labels the last mover +1 and alternates the sign.
    - expand_Chess stores the value head's output in node.Q, and
      backpropagate_Chess negates the value at every step up the tree.
    - So each node's Q is from the point of view of the player to move
      *at that node*, i.e. the opponent of the player choosing at node.parent.
    - The child's Q must therefore be inverted. Ranking by the raw Q (as
      before) made the search prefer the moves that are best for the opponent.
    """
    n_parent = node.parent.visit_count
    pb_c = np.log((n_parent + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= np.sqrt(n_parent) / (node.visit_count + 1)
    prior_score = pb_c * node.prior
    # MinMax collects raw node.Q values, so 1 - normalize(Q) is this child's
    # normalized value from the parent's (the chooser's) point of view.
    value_score = 1.0 - minmax.normalize(node.Q)
    return value_score + prior_score
def select_argmax_pUCB_child_Chess(node, minmax):
    """Return the ``(action, child)`` pair of `node` with the highest ucb_score_Chess.

    Ties go to the first child in iteration order; raises ValueError if `node`
    has no children.
    """
    def score(item):
        _action, child = item
        return ucb_score_Chess(child, minmax)

    return max(node.children.items(), key=score)
def expand_Chess(node, model):
    """Evaluate `node` with the learned model and attach one child per board cell.

    For non-root nodes the hidden state comes from the dynamics network applied
    to the parent's hidden state and this node's action. node.Q is then set to
    the prediction network's value, and each child's prior comes from the
    policy. Illegal (occupied) cells are not filtered here.
    """
    parent = node.parent
    if parent is not None:
        node.hidden_state = model.dynamics.predict(parent.hidden_state, node.action)
    policy, value = model.prediction.predict(node.hidden_state)
    node.Q = value
    for action in range(len(policy)):
        child = TreeNode()
        child.action = action
        child.prior = policy[action]
        child.parent = node
        node.children[action] = child
def backpropagate_Chess(node, minmax):
    """Propagate the leaf's value estimate from `node` up to the root (two players).

    Every node on the path gets one more visit and has its Q recorded in
    `minmax`. The value flips sign at each level because the players
    alternate. The parent's Q becomes the running mean over its visits
    including this sample.
    """
    value = node.Q
    current = node
    while True:
        current.visit_count += 1
        minmax.update(current.Q)
        if current.is_root_Node():
            return
        value = -value
        current = current.parent
        current.Q = (current.Q * current.visit_count + value) / (current.visit_count + 1)
class MCTS_Chess:
    """Two-player MuZero-style tree search in the learned model.

    Illegal moves are not masked during the search. play_game.choice_action
    zeroes their visit counts afterwards.
    """

    def __init__(self, model, observation):
        self.root_Node = TreeNode()
        self.model = model
        self.root_Node.hidden_state = self.model.representation.predict(observation)
        self.minmax = MinMax()

    def simulations(self, num_simulation, add_noise=True):
        """Run the search and return {action: root child visit count}.

        One extra iteration runs because the first one only expands the root.
        Dirichlet noise is added to the root's priors right after that
        expansion when `add_noise` is set.
        """
        root = self.root_Node
        for _ in range(num_simulation + 1):
            leaf = root
            while not leaf.is_leaf_Node():
                _action, leaf = select_argmax_pUCB_child_Chess(leaf, self.minmax)
            expand_Chess(leaf, self.model)
            if leaf == root and add_noise:
                add_exploration_noise(leaf)
            backpropagate_Chess(leaf, self.minmax)
        return {action: child.visit_count for action, child in root.children.items()}

    def __str__(self):
        return "Muzero_MCTS_Chess"
from gym.envs.classic_control import rendering
import numpy as np
import gym
def check(filter, state, size, filter_w, filter_h):
    """Return True if some window of `state` matches `filter` with a weighted sum of exactly 5.

    The window is filter_h rows x filter_w columns and slides over the top-left
    (size x size) region of `state`. With 0/1 stone planes and the 0/1 line
    filters used by Gomoku.done, a sum of 5 means five stones in a row.
    """
    row_starts = range(size - filter_h + 1)
    col_starts = range(size - filter_w + 1)
    return any(
        np.sum(filter * state[r:r + filter_h, c:c + filter_w]) == 5
        for r in row_starts
        for c in col_starts
    )
class Gomoku(gym.Env):
    """Two-player Gomoku (five in a row) on a ``num_chess`` x ``num_chess`` board.

    Observations are ``(3, num_chess, num_chess)`` float arrays:
    - plane 0 holds player 0's stones;
    - plane 1 holds player 1's stones;
    - plane 2 is all ones when player 1 is to move and all zeros when player 0 is.

    ``reset`` and ``step`` return ``(observation, winner)``. ``winner`` is the
    index of the player who completed five in a row, or None.
    """

    def __init__(self, num_chess, block_size):
        if num_chess < 5:
            raise ValueError("The minimum checkerboard is 5.")
        self.board = None
        self.num_chess = num_chess
        self.winner = None
        self.block_size = block_size  # pixel size of one board cell when rendering
        self.viewer = None
        self.player = None

    def reset(self):
        """Clear the board, give the first move to player 0 and return ``(observation, None)``."""
        self.board = np.zeros([3, self.num_chess, self.num_chess])
        self.player = 0
        # Forget the previous game's result; otherwise a reused environment
        # would immediately report the old winner.
        self.winner = None
        # Return a copy: step() writes into self.board in place, which would
        # silently change any observation the caller has kept.
        return self.board.copy(), self.winner

    def render(self, mode="human"):
        """Draw the grid and stones with gym's classic-control viewer."""
        if self.viewer is None:
            self.viewer = rendering.Viewer(
                self.num_chess * self.block_size,
                self.num_chess * self.block_size
            )
        self.viewer.geoms.clear()
        self.viewer.onetime_geoms.clear()
        side = self.num_chess * self.block_size
        for i in range(self.num_chess - 1):
            offset = (i + 1) * self.block_size
            # One horizontal, then one vertical grid line.
            for start, end in (((0, offset), (side, offset)), ((offset, 0), (offset, side))):
                line = rendering.Line(start, end)
                line.set_color(0, 0, 0)
                self.viewer.add_geom(line)
        # Player 0 stones in green, player 1 stones in orange.
        stone_colors = (
            (0, (0 / 255, 139 / 255, 0 / 255)),
            (1, (238 / 255, 118 / 255, 33 / 255)),
        )
        for plane, color in stone_colors:
            for i in range(self.num_chess):
                for j in range(self.num_chess):
                    if self.board[plane][j][i] == 1:
                        circle = rendering.make_circle(0.35 * self.block_size)
                        circle.set_color(*color)
                        move = rendering.Transform(
                            translation=(
                                (i + 0.5) * self.block_size,
                                (self.num_chess - j - 0.5) * self.block_size
                            )
                        )
                        circle.add_attr(move)
                        self.viewer.add_geom(circle)
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def done(self):
        """Return True if either player has five stones in a row in any direction."""
        patterns = (
            (np.array([1, 1, 1, 1, 1]), 5, 1),             # horizontal
            (np.array([[1], [1], [1], [1], [1]]), 1, 5),    # vertical
            (np.eye(5), 5, 5),                              # main diagonal
            (np.eye(5)[::-1], 5, 5),                        # anti-diagonal
        )
        return any(
            check(pattern, self.board[plane], self.num_chess, width, height)
            for pattern, width, height in patterns
            for plane in (0, 1)
        )

    def step(self, action: int):
        """Place the current player's stone at flat (row-major) index `action` and pass the turn.

        Raises:
            ValueError: if `action` is off the board or the cell is occupied.
        """
        if not 0 <= action < self.num_chess ** 2:
            raise ValueError("Action error, {} is outside the board".format(action))
        i = int(action / self.num_chess)
        j = action % self.num_chess
        if self.board[0][i][j] == 1 or self.board[1][i][j] == 1:
            raise ValueError("Action error, there are pieces here")
        self.board[self.player][i][j] = 1
        if self.done():
            self.winner = self.player
        # Switch sides; plane 2 encodes whose turn it is (ones = player 1).
        if self.player == 0:
            self.board[2] = np.ones([self.num_chess, self.num_chess])
            self.player = 1
        else:
            self.board[2] = np.zeros([self.num_chess, self.num_chess])
            self.player = 0
        # Return a copy so stored observations are not changed by later moves.
        return self.board.copy(), self.winner
from game import Gomoku
from MCTS import MCTS_Chess
import numpy as np
import time
class play_game:
    """Play one Gomoku self-play game with MCTS_Chess and record the trajectory."""

    def __init__(self, num_chess, block_size, model, num_simulations, render):
        self.num_chess = num_chess
        self.env = Gomoku(num_chess, block_size)
        self.render = render
        self.max_step = num_chess ** 2
        self.valid_action = list(range(num_chess ** 2))
        self.model = model
        self.mcts = MCTS_Chess
        self.num_simulations = num_simulations

    def choice_action(self, observation, T=1.0):
        """Search from `observation` and sample a legal action.

        The policy is proportional to visit_count ** (1 / T) over legal cells.
        If every legal cell got zero visits, the policy is uniform over the
        legal cells. Returns ``(action, policy)`` and removes `action` from
        the legal set.
        """
        mcts = self.mcts(self.model, observation)
        visit_count = mcts.simulations(self.num_simulations)
        legal = set(self.valid_action)
        # The search does not know which cells are occupied; drop them here.
        for k in visit_count:
            if k not in legal:
                visit_count[k] = 0
        action_visits = np.array(list(visit_count.values()))
        if np.any(action_visits):
            scaled = action_visits ** (1 / T)
            policy = scaled / np.sum(scaled)
        else:
            policy = np.array([
                1 / len(self.valid_action) if i in legal else 0
                for i in range(self.num_chess ** 2)
            ])
        action = np.random.choice(len(policy), p=policy)
        self.valid_action.remove(action)
        return action, policy

    @staticmethod
    def _to_hwc(board):
        """Return a (H, W, C) *copy* of a (C, H, W) board.

        np.transpose alone returns a view, and the Gomoku environment writes
        moves into its board in place. Storing views would make every recorded
        state equal to the final position.
        """
        return np.array(np.transpose(board, (1, 2, 0)))

    def run(self):
        """Play one full game and return ``(trajectory, winner)``.

        `trajectory` holds one ``[state, action_onehot, policy]`` entry per move:
        - `state` is the (H, W, 3) observation before the move;
        - `action_onehot` is an (H, W, 1) one-hot plane;
        - `policy` is the search policy over all H*W cells.

        `winner` is the winning player's index, or None.
        """
        # Each game starts on an empty board; without this, a second run() on
        # the same object would begin with the previous game's cells removed.
        self.valid_action = list(range(self.num_chess ** 2))
        trajectory = []
        state, winner = self.env.reset()
        state = self._to_hwc(state)
        if self.render:
            self.env.render()
        for _ in range(self.max_step):
            action, policy = self.choice_action(state)
            action_onehot = np.reshape(
                [1 if i == action else 0 for i in range(self.num_chess ** 2)],
                (self.num_chess, self.num_chess, 1),
            )
            trajectory.append([state, action_onehot, policy])
            state, winner = self.env.step(action)
            if self.render:
                self.env.render()
            state = self._to_hwc(state)
            if winner is not None:
                break
        return trajectory, winner
# class human_play:
# def __init__(self, num_chess, block_size, render):
# self.num_chess = num_chess
# self.env = Gomoku(num_chess, block_size)
# self.render = render
# self.max_step = num_chess ** 2
#
# def run(self):
# trajectory = []
# state, winner = self.env.reset()
# state = np.reshape(state, newshape=(self.num_chess, self.num_chess, 3))
# if self.render:
# self.env.render()
# for step in range(self.max_step):
# action = int(input())
# policy = [1 if i == action else 0 for i in range(self.num_chess ** 2)]
# action_onehot = np.reshape(policy, newshape=(self.num_chess, self.num_chess, 1))
# policy = np.array(policy)
# trajectory.append([state, action_onehot, policy])
# state, winner = self.env.step(action)
# if self.render:
# self.env.render()
# state = np.reshape(state, newshape=(self.num_chess, self.num_chess, 3))
# if winner is not None:
# break
# last_action = np.reshape([0 for _ in range(self.num_chess ** 2)], newshape=(self.num_chess, self.num_chess, 1))
# last_policy = np.array([0 for _ in range(self.num_chess ** 2)])
# trajectory.append([state, last_action, last_policy])
# return trajectory, winner
from tensorflow.keras import optimizers, losses
from collections import deque
import numpy as np
import tensorflow as tf
import random
class ReplayBuffer():
    """Bounded FIFO store of self-play trajectories, augmented with board symmetries on insertion."""

    def __init__(self, max_memory):
        # The oldest trajectories are discarded once max_memory is reached.
        self.memory = deque(maxlen=max_memory)

    def save_memory(self, trajectory):
        """Augment `trajectory` with the 8 board symmetries and store it."""
        self.memory.append(self.data_augmentation(trajectory))

    def sample(self, sample_size):
        """Return up to `sample_size` distinct stored trajectories, chosen uniformly at random."""
        batch_size = min(sample_size, len(self.memory))
        return random.sample(self.memory, batch_size)

    @staticmethod
    def _dihedral_views(x):
        """Return the 8 symmetric copies of an (H, W, C) array.

        Order: identity, flip, rot90, flip(rot90), rot180, flip(rot180),
        rot270, flip(rot270). Rotations are counter-clockwise in the (H, W)
        plane and flips mirror the W axis, the same as tf.image.rot90 and
        tf.image.flip_left_right.
        """
        views = []
        for k in range(4):
            rotated = np.rot90(x, k, axes=(0, 1))
            views.append(rotated)
            views.append(np.flip(rotated, axis=1))
        return views

    @staticmethod
    def data_augmentation(trajectory):
        """Expand every step of `trajectory` into its 8 symmetric variants.

        Each input step is ``[state, action, policy, winner]``:
        - `state` is (H, W, C);
        - `action` is an (H, W) or (H, W, 1) one-hot;
        - `policy` is a flat H*W vector.

        Returns one entry per step, ``[states, actions, policies, winners]``,
        with shapes (8, H, W, C), (8, H, W), (8, H*W) and (8,). The same
        symmetry is applied to state, action and policy, so they stay aligned.
        """
        new_t = []
        for state, action, policy, winner in trajectory:
            rows, cols = action.shape[0], action.shape[1]
            # Lay the flat policy out like the board so it can be rotated.
            policy = np.reshape(policy, (rows, cols, 1))
            action = np.reshape(action, (rows, cols, 1))
            states = ReplayBuffer._dihedral_views(state)
            actions = ReplayBuffer._dihedral_views(action)
            policies = ReplayBuffer._dihedral_views(policy)
            new_t.append([
                np.array(states),
                np.array([a.reshape(a.shape[0], a.shape[1]) for a in actions]),
                np.array([p.reshape(p.shape[0] * p.shape[1]) for p in policies]),
                np.array([winner] * len(states)),
            ])
        return new_t

    @property
    def len(self):
        """Number of trajectories currently stored.

        Previously a snapshot taken in __init__ that stayed 0; kept as an
        attribute name for compatibility.
        """
        return len(self.memory)
class Trainer():
    """Adam trainer for the three MuZero networks, fed from a ReplayBuffer."""

    def __init__(self, lr=1e-3, max_save_memory=int(1e6)):
        self.optimizer = optimizers.Adam(lr)
        self.replay_buffer = ReplayBuffer(max_save_memory)

    @staticmethod
    def roll_to_end(traj, model, policy_targets, value_targets, policy_predicts, value_predicts):
        """Unroll the learned model from traj[0] to the end of `traj`, appending to the four lists.

        Each element of `traj` is one augmented step
        ``[states, actions, policies, winners]`` with a leading axis over the
        symmetric copies. Only the first state is encoded with the
        representation network. Later hidden states come from the dynamics
        network applied to the recorded actions. The (mutated) lists are
        also returned.
        """
        first_state = traj[0][0]
        hidden_state = model.representation.model(first_state)
        for step in range(len(traj)):
            p_pred, v_pred = model.prediction.model(hidden_state)
            act = traj[step][1]
            hidden_state = model.dynamics.model([hidden_state, act])
            p_tar = traj[step][2]
            v_tar = np.reshape(traj[step][3], (-1, 1))
            policy_targets.append(p_tar)
            value_targets.append(v_tar)
            policy_predicts.append(p_pred)
            value_predicts.append(v_pred)
        return policy_targets, value_targets, policy_predicts, value_predicts

    @staticmethod
    def _regularization_loss(model):
        """Sum of the kernel-regularizer penalties of all three sub-networks.

        Must be called inside the GradientTape so the penalties are differentiated.
        """
        penalties = []
        for part in (model.representation, model.dynamics, model.prediction):
            penalties.extend(part.model.losses)
        if not penalties:
            return tf.constant(0.0)
        return tf.add_n(penalties)

    def run_train(self, batch_size, model):
        """Take one optimizer step per sampled trajectory.

        Returns the mean policy cross-entropy, the mean value MSE and the mean
        entropy of the predicted policies (first symmetric copy). The returned
        means exclude the regularization penalty.

        NOTE(review): the sub-models are called without ``training=True``, so
        BatchNormalization presumably runs in inference mode here. Confirm
        that is intended.
        """
        train_data = self.replay_buffer.sample(batch_size)
        policys_losses, value_losses, entropys = [], [], []
        for data in train_data:
            with tf.GradientTape() as tape:
                policy_targets, value_targets = [], []
                policy_predicts, value_predicts = [], []
                # Unroll from every position of the game to its end.
                for i in range(len(data)):
                    policy_targets, value_targets, policy_predicts, value_predicts = self.roll_to_end(
                        data[i:], model,
                        policy_targets,
                        value_targets,
                        policy_predicts,
                        value_predicts
                    )
                policy_loss = losses.categorical_crossentropy(
                    y_pred=policy_predicts,
                    y_true=policy_targets
                )
                value_loss = losses.mean_squared_error(
                    y_pred=value_predicts,
                    y_true=value_targets
                )
                # tape.gradient on a non-scalar target differentiates its sum,
                # so reduce_sum leaves the data-term gradient unchanged.
                loss = tf.reduce_sum(policy_loss + value_loss)
                # kernel_regularizer penalties only reach the optimizer via
                # model.losses; a custom loop must add them explicitly.
                loss = loss + self._regularization_loss(model)
            grad = tape.gradient(loss, model.trainable_variables)
            self.optimizer.apply_gradients(zip(grad, model.trainable_variables))
            entropy = [-np.sum(p[0] * np.log(p[0] + 1e-6)) for p in policy_predicts]
            policys_losses.append(np.mean(policy_loss))
            value_losses.append(np.mean(value_loss))
            entropys.append(np.mean(entropy))
        return np.mean(policys_losses), np.mean(value_losses), np.mean(entropys)
from resnet_model import model
from self_play import play_game
from trainer import Trainer
import multiprocessing
import threading
import datetime
import time
NUM_CHESS = 9                                  # board side length (9 x 9)
RENDER_BLOCK_SIZE = 50                         # pixels per cell (workers do not render)
OBSERVATION_SHAPE = (NUM_CHESS,) * 2 + (3,)    # (H, W, stone/turn planes)
HIDDEN_STATE_CHANNEL = 32
NUM_SIMULATIONS = 400                          # MCTS simulations per move
BUFFER_SIZE = 10 ** 6                          # NOTE(review): not passed to Trainer() below
BATCH_SIZE = 512                               # trajectories sampled per run_train call
NUM_WORKERS = 8                                # self-play processes
def self_play_worker(pipe):
    """Worker-process loop: receive weights, play one game, send the labelled trajectory back."""
    worker_model = model(OBSERVATION_SHAPE, HIDDEN_STATE_CHANNEL, NUM_CHESS)
    while True:
        weights = pipe.recv()
        worker_model.representation.model.set_weights(weights[0])
        worker_model.dynamics.model.set_weights(weights[1])
        worker_model.prediction.model.set_weights(weights[2])
        game = play_game(NUM_CHESS, RENDER_BLOCK_SIZE, worker_model, NUM_SIMULATIONS, render=False)
        trajectory, winner = game.run()
        # Label every step with the outcome for the player who moved at it.
        # The final mover (the winner, if any) gets +1, the move before -1,
        # and so on. A game without a winner labels every step 0.
        outcome = 1.0 if winner is not None else 0.0
        for record in reversed(trajectory):
            record.append(outcome)
            outcome = -outcome
        pipe.send(trajectory)
def save_model():
    """Every 20 minutes, save global_model's weights under ./model/."""
    import os

    global global_model
    path = "./model/gomoku_{}X{}".format(NUM_CHESS, NUM_CHESS)
    while True:
        time.sleep(60 * 20)
        # save_weights writes three .h5 files next to `path`. Create the
        # directory first so a missing ./model does not kill this thread.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        global_model.save_weights(path)
        print('\n save model at {}'.format(datetime.datetime.now()))
def training(trainer):
    """Train global_model forever, printing the latest losses after every pass."""
    global episode
    global global_model
    while True:
        if not trainer.replay_buffer.memory:
            # No self-play game has arrived yet. run_train would only average
            # empty lists (NaN + RuntimeWarning) and spin the CPU.
            time.sleep(1)
            continue
        t = time.time()
        policy_loss, value_loss, entropy = trainer.run_train(BATCH_SIZE, global_model)
        print("\r episode: {}, policy_loss: {}, value_loss: {}, losses: {}, entropy: {}, num_trajectory: {}, train time: {} s".format(
            episode, policy_loss, value_loss, policy_loss + value_loss, entropy, len(trainer.replay_buffer.memory), int(time.time() - t)),
            end="")
def communication(trainer, pipe_dict):
    """Loop forever: push the current weights to every worker, then collect one trajectory from each.

    Each received trajectory is stored in the trainer's replay buffer and
    increments the global `episode` counter.
    """
    global episode
    global global_model

    def current_weights():
        # Fetched afresh for every worker, as the training thread keeps updating them.
        return [
            global_model.representation.model.get_weights(),
            global_model.dynamics.model.get_weights(),
            global_model.prediction.model.get_weights()
        ]

    while True:
        for parent_end, _child_end in pipe_dict.values():
            parent_end.send(current_weights())
        for parent_end, _child_end in pipe_dict.values():
            trainer.replay_buffer.save_memory(parent_end.recv())
            episode += 1
if __name__ == '__main__':
    # Shared state read by the training / communication / saving threads.
    global_model = model(OBSERVATION_SHAPE, HIDDEN_STATE_CHANNEL, NUM_CHESS)
    trainer = Trainer()
    episode = 0
    # global_model.load_weights("./model/gomoku_{}X{}".format(NUM_CHESS, NUM_CHESS))
    train_thread = threading.Thread(target=training, args=[trainer])
    train_thread.start()

    # One duplex pipe per self-play worker: end [0] stays here, end [1] goes to the worker.
    worker_names = ["worker_{}".format(str(w)) for w in range(NUM_WORKERS)]
    pipe_dict = {name: multiprocessing.Pipe() for name in worker_names}
    process = [
        multiprocessing.Process(target=self_play_worker, args=(pipe_dict[name][1],))
        for name in worker_names
    ]
    for p in process:
        p.start()

    communication_thread = threading.Thread(target=communication, args=[trainer, pipe_dict])
    communication_thread.start()
    savemodel_thread = threading.Thread(target=save_model)
    savemodel_thread.start()
from resnet_model import model
from self_play import play_game
import os
# Hide CUDA GPUs (presumably to run this demo on CPU).
# NOTE(review): this runs after resnet_model (and so TensorFlow) is imported;
# it only takes effect if TensorFlow has not initialized its GPUs yet.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
NUM_CHESS = 9                                  # board side length (9 x 9)
RENDER_BLOCK_SIZE = 50                         # pixels per board cell in the viewer
OBSERVATION_SHAPE = (NUM_CHESS,) * 2 + (3,)    # (H, W, stone/turn planes)
HIDDEN_STATE_CHANNEL = 32
NUM_SIMULATIONS = 400                          # MCTS simulations per move
def self_play(model, num_simulations):
    """Play and render one game with `model`, then print the winner (None when nobody won)."""
    game = play_game(NUM_CHESS, RENDER_BLOCK_SIZE, model, num_simulations, True)
    _trajectory, winner = game.run()
    print(winner)
if __name__ == '__main__':
    # Load the checkpoint written by the training script and watch one game.
    weights_prefix = "./model/gomoku_{}X{}".format(NUM_CHESS, NUM_CHESS)
    gomoku_model = model(OBSERVATION_SHAPE, HIDDEN_STATE_CHANNEL, NUM_CHESS)
    gomoku_model.load_weights(weights_prefix)
    self_play(gomoku_model, NUM_SIMULATIONS)