【tensorflow2.x】训练 muzero 玩五子棋 (Gomoku)

【深度强化学习】tensorflow2.x 训练 muzero 玩五子棋 (Gomoku)


github代码地址:https://github.com/NickNameHaveBeenSwallowed/muzero-tensorflow2.x


参考资料
[1]ColinFred. 蒙特卡洛树搜索(MCTS)代码详解【python】. 2019-03-23 23:37:09.
[2]饼干Japson 深度强化学习实验室.【论文深度研读报告】MuZero算法过程详解.2021-01-19.
[3]Tangarf. Muzero算法研读报告. 2020-08-31 11:40:20 .
[4]带带弟弟好吗. AlphaGo版本三——MuZero. 2020-08-30.
[5]Google原论文:Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model.
[6]参考GitHub代码1.
[7]参考GitHub代码2.

【tensorflow2.x】训练 muzero 玩五子棋 (Gomoku)_第1张图片


  • 网上使用 muzero 训练五子棋的程序很少,GitHub 上的代码又写得很难读,这里提供一个易读、简单的版本。
  • 这里的代码是单线程的,收集 self play data 的效率比较低,后面会更新多线程的训练方法。
  • 这里因为个人电脑的原因没有完成完整的训练,具体效果如有条件还请自行训练查看。
  • 之前有人写的 alphazero 版本,在 8x8 大小的棋盘上大概训练 3000 多轮就可以得到很好的模型。如果有人完成了训练,还请在评论区告知一下:因为没有验证,博主也不知道这份代码是否 work。
  • 和之前一样使用的是 gym 库完成的五子棋游戏环境的搭建。

更新

  • (2022.10.20) :将网络替换成 Resnet ,加入了多线程训练。

resnet_model.py

import tensorflow as tf
import numpy as np

# Number of ResidualBlock layers stacked in each network built in this module.
num_blocks = 6
# The builders always create a first block and then loop over the remainder,
# so at least one block is required (the message says the same, in Chinese).
assert num_blocks >= 1 , "残差块的数量必须大于等于1"
# Coefficient handed to tf.keras.regularizers.l2 for every regularized kernel below.
l2 = 1e-4

def hidden_state_norm(x):
    """Min-max rescale `x` over axes 1 and 2 (reduced with keepdims=True).

    The minimum is lowered and the maximum raised by 1e-6, so the
    denominator is strictly positive even when a slice is constant.
    """
    lower = tf.reduce_min(x, axis=(1, 2), keepdims=True) - 1e-6
    upper = tf.reduce_max(x, axis=(1, 2), keepdims=True) + 1e-6
    shifted = x - lower
    span = upper - lower
    return shifted / span

class ResidualBlock(tf.keras.Model):
    """Two 3x3 convolution + batch-norm stages with a skip connection.

    The skip path is added to the second stage's output before the final
    ReLU. When the stride is not 1 or the channel counts differ, the skip
    path is a 1x1 convolution followed by batch-norm; otherwise it is the
    identity.
    """

    # Multiplier applied to out_channels when sizing the skip projection.
    expansion = 1

    def __init__(self, in_channels, out_channels, strides=1):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(
            out_channels,
            kernel_size=3,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(l2),
        )
        self.bn1 = tf.keras.layers.BatchNormalization()

        self.conv2 = tf.keras.layers.Conv2D(
            out_channels,
            kernel_size=3,
            strides=1,
            padding="same",
            use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(l2),
        )
        self.bn2 = tf.keras.layers.BatchNormalization()

        # Shortcut between the block input and the residual branch; the two
        # are merged with a sum in call().
        needs_projection = strides != 1 or in_channels != self.expansion * out_channels
        if needs_projection:
            projection = tf.keras.layers.Conv2D(
                self.expansion * out_channels,
                kernel_size=1,
                strides=strides,
                use_bias=False,
                kernel_regularizer=tf.keras.regularizers.l2(l2),
            )
            projection_bn = tf.keras.layers.BatchNormalization()
            self.shortcut = tf.keras.Sequential([projection, projection_bn])
        else:
            # Same call shape as the Sequential above: (inputs, training).
            self.shortcut = lambda inputs, _training: inputs

    def call(self, x, training=False):
        """Apply the block to `x`; `training` is forwarded to the batch-norm layers."""
        residual = self.conv1(x)
        residual = self.bn1(residual, training=training)
        residual = tf.nn.relu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual, training=training)
        merged = residual + self.shortcut(x, training)
        return tf.nn.relu(merged)

class representation:
    """Network that maps an observation to a normalized hidden state.

    The Keras model is `num_blocks` ResidualBlocks followed by
    hidden_state_norm. `trainable_variables` is captured from the model
    once, at construction time.
    """

    def __init__(self, observation_shape, hidden_state_channel):
        observation = tf.keras.Input(shape=observation_shape)

        # Input width of each block: the first one adapts the observation's
        # channel count, the rest keep hidden_state_channel.
        input_widths = [observation_shape[-1]] + [hidden_state_channel] * (num_blocks - 1)
        features = observation
        for width in input_widths:
            block = ResidualBlock(in_channels=width, out_channels=hidden_state_channel)
            features = block(features)

        hidden_state = hidden_state_norm(features)

        self.model = tf.keras.Model(inputs=observation, outputs=hidden_state)
        self.trainable_variables = self.model.trainable_variables

    def predict(self, observation):
        """Run the model on one observation and return its hidden state as an ndarray."""
        batch = np.array([observation])
        return np.array(self.model(batch)[0])

class dynamics:
    """Network that maps (hidden state, action plane) to the next hidden state.

    The action enters as a (num_chess, num_chess, 1) plane concatenated to
    the hidden state, which is why the first block takes
    hidden_state_channel + 1 input channels.
    """

    def __init__(self, hidden_state_shape, hidden_state_channel, num_chess):
        self.num_chess = num_chess
        hidden_state = tf.keras.Input(shape=hidden_state_shape)
        action = tf.keras.Input(shape=(num_chess, num_chess, 1))

        features = tf.keras.layers.concatenate([hidden_state, action])

        input_widths = [hidden_state_channel + 1] + [hidden_state_channel] * (num_blocks - 1)
        for width in input_widths:
            block = ResidualBlock(in_channels=width, out_channels=hidden_state_channel)
            features = block(features)

        next_hidden_state = hidden_state_norm(features)

        self.model = tf.keras.Model(inputs=[hidden_state, action], outputs=next_hidden_state)
        self.trainable_variables = self.model.trainable_variables

    def predict(self, hidden_state, action):
        """Return the next hidden state (ndarray) for one hidden state and a flat action index.

        The index is turned into a one-hot plane of shape
        (1, num_chess, num_chess, 1); an index outside the board yields an
        all-zero plane.
        """
        cell_count = self.num_chess ** 2
        hidden_batch = np.array([hidden_state])
        one_hot = np.array([int(cell == action) for cell in range(cell_count)])
        action_plane = one_hot.reshape((1, self.num_chess, self.num_chess, 1))
        next_batch = self.model([hidden_batch, action_plane])
        return np.array(next_batch[0])

class prediction:
    """Network that maps a hidden state to (policy, value).

    A tower of `num_blocks` ResidualBlocks feeds two heads built from the
    same recipe: the policy head ends in a softmax over num_chess ** 2
    cells, the value head in a single tanh unit.
    """

    def __init__(self, hidden_state_shape, hidden_state_channel, num_chess):
        hidden_state = tf.keras.Input(shape=hidden_state_shape)

        tower = hidden_state
        for _ in range(num_blocks):
            block = ResidualBlock(
                in_channels=hidden_state_channel,
                out_channels=hidden_state_channel,
            )
            tower = block(tower)

        def head_trunk(features):
            # Conv -> BatchNorm -> ReLU -> Flatten -> Dense(1024). Each call
            # creates fresh layers, so the two heads share no weights.
            h = tf.keras.layers.Conv2D(
                filters=32, kernel_size=3, strides=1, padding="SAME", use_bias=False,
                kernel_regularizer=tf.keras.regularizers.l2(l2))(features)
            h = tf.keras.layers.BatchNormalization()(h)
            h = tf.keras.layers.Activation('relu')(h)
            h = tf.keras.layers.Flatten()(h)
            h = tf.keras.layers.Dense(
                units=1024, activation='relu',
                kernel_regularizer=tf.keras.regularizers.l2(l2))(h)
            return h

        # The policy head is built completely before the value head so layers
        # are created in the same order as before.
        policy_features = head_trunk(tower)
        policy = tf.keras.layers.Dense(
            units=num_chess ** 2, activation='softmax',
            kernel_regularizer=tf.keras.regularizers.l2(l2))(policy_features)

        value_features = head_trunk(tower)
        value = tf.keras.layers.Dense(
            units=1, activation='tanh',
            kernel_regularizer=tf.keras.regularizers.l2(l2))(value_features)

        self.model = tf.keras.Model(inputs=hidden_state, outputs=[policy, value])
        self.trainable_variables = self.model.trainable_variables

    def predict(self, hidden_state):
        """Return (policy vector, scalar value) as ndarrays for one hidden state."""
        batch = np.array([hidden_state])
        policy_batch, value_batch = self.model(batch)
        return np.array(policy_batch[0]), np.array(value_batch[0][0])

class model:
    """Bundle of the representation, dynamics and prediction networks.

    `trainable_variables` is the concatenation of the three networks'
    variable lists, taken once at construction.
    """

    def __init__(self, observation_shape, hidden_state_channel, num_chess):
        self.representation = representation(observation_shape, hidden_state_channel)
        # The hidden state keeps the observation's first two dimensions.
        rows, cols = observation_shape[0], observation_shape[1]
        hidden_state_shape = (rows, cols, hidden_state_channel)
        self.dynamics = dynamics(hidden_state_shape, hidden_state_channel, num_chess)
        self.prediction = prediction(hidden_state_shape, hidden_state_channel, num_chess)
        variables = self.representation.trainable_variables
        variables = variables + self.dynamics.trainable_variables
        variables = variables + self.prediction.trainable_variables
        self.trainable_variables = variables

    def save_weights(self, path):
        """Write one weights file per network, using `path` as the common prefix."""
        targets = (
            (self.representation, "-representation.h5"),
            (self.dynamics, '-dynamics.h5'),
            (self.prediction, '-prediction.h5'),
        )
        for network, suffix in targets:
            network.model.save_weights(path + suffix)

    def load_weights(self, path):
        """Load the three files written by save_weights for the same `path` prefix."""
        sources = (
            (self.representation, "-representation.h5"),
            (self.dynamics, '-dynamics.h5'),
            (self.prediction, '-prediction.h5'),
        )
        for network, suffix in sources:
            network.model.load_weights(path + suffix)

    def copy_weights(self, target_model):
        """Overwrite this model's weights with those read from `target_model`."""
        pairs = (
            (self.representation, target_model.representation),
            (self.dynamics, target_model.dynamics),
            (self.prediction, target_model.prediction),
        )
        for own, other in pairs:
            own.model.set_weights(other.model.get_weights())

MCTS.py

import numpy as np

# Default values for the pb_c_init / pb_c_base parameters of the ucb_score_*
# functions below (they scale the prior-driven exploration term).
PB_C_INIT = 1.25
PB_C_BASE = 19652

class MinMax:
    """Running minimum and maximum of the values passed to update()."""

    def __init__(self):
        # Start inverted so the first update sets both bounds.
        self.maximum = float("-inf")
        self.minimum = float("inf")

    def update(self, value):
        """Widen the tracked range so that it contains `value`."""
        if value > self.maximum:
            self.maximum = value
        if value < self.minimum:
            self.minimum = value

    def normalize(self, value):
        """Rescale `value` with the tracked range.

        Returns `value` unchanged while the range is empty or degenerate
        (maximum not greater than minimum).
        """
        if not self.maximum > self.minimum:
            return value
        return (value - self.minimum) / (self.maximum - self.minimum)

class TreeNode:
    """One node of the search tree.

    Attributes are plain fields filled in by the expand_* / backpropagate_*
    functions of this module.
    """

    def __init__(self):
        self.parent = None        # None marks the root
        # Move leading from `parent` to this node. expand_* sets it on every
        # child; it is initialised here so the attribute exists on all nodes,
        # including the root.
        self.action = None
        self.prior = 1.0          # prior probability assigned by the parent's policy
        self.hidden_state = None  # filled in when the node is expanded
        self.children = {}        # action -> TreeNode
        self.visit_count = 0
        self.reward = 0
        self.Q = 0

    def is_leaf_Node(self):
        """Return True while the node has no children."""
        return not self.children

    def is_root_Node(self):
        """Return True when the node has no parent."""
        return self.parent is None

def add_exploration_noise(node, dirichlet_alpha=0.3, exploration_fraction=0.25):
    """Blend Dirichlet noise into the priors of `node`'s children, in place.

    Each child's prior becomes
    prior * (1 - exploration_fraction) + noise * exploration_fraction,
    with one noise sample drawn for all children from a symmetric
    Dirichlet(dirichlet_alpha).
    """
    children = list(node.children.values())
    noise = np.random.dirichlet([dirichlet_alpha] * len(children))
    for child, sample in zip(children, noise):
        child.prior = child.prior * (1 - exploration_fraction) + sample * exploration_fraction

def ucb_score_Atari(node, minmax, pb_c_init=PB_C_INIT, pb_c_base=PB_C_BASE):
    """Selection score of `node` as seen from its parent.

    The score is the min-max normalized Q of the node plus an exploration
    bonus that grows with the parent's visit count and the node's prior and
    shrinks with the node's own visit count.
    """
    parent_visits = node.parent.visit_count
    exploration = np.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    exploration *= np.sqrt(parent_visits) / (node.visit_count + 1)
    bonus = exploration * node.prior
    return minmax.normalize(node.Q) + bonus

def select_argmax_pUCB_child_Atari(node, minmax):
    """Return the (action, child) pair of `node` with the highest ucb_score_Atari."""

    def score(action_and_child):
        _action, child = action_and_child
        return ucb_score_Atari(child, minmax)

    return max(node.children.items(), key=score)

def expand_Atari(node, model):
    """Evaluate `node` with the model and attach one child per policy entry.

    For a non-root node the hidden state and reward are first obtained from
    model.dynamics.predict, which is unpacked into two values here. The
    node's Q is set to the predicted value, and each child receives its
    action index and prior.
    """
    is_root = node.parent is None
    if not is_root:
        node.hidden_state, node.reward = model.dynamics.predict(node.parent.hidden_state, node.action)
    policy, value = model.prediction.predict(node.hidden_state)
    node.Q = value
    for action in range(len(policy)):
        child = TreeNode()
        child.action = action
        child.prior = policy[action]
        child.parent = node
        node.children[action] = child

def backpropagate_Atari(node, minmax, discount):
    """Propagate the value of `node` up to the root.

    Each node on the path gets one more visit and its Q is reported to
    `minmax`. Walking up, the value is discounted and the reward of the node
    being left is added; the parent's Q becomes the running mean over its
    visits.
    """
    value = node.Q
    node.visit_count += 1
    minmax.update(node.Q)
    while not node.is_root_Node():
        value = node.reward + discount * value
        node = node.parent
        # Running mean: visit_count still holds the number of earlier visits.
        node.Q = (node.Q * node.visit_count + value) / (node.visit_count + 1)
        node.visit_count += 1
        minmax.update(node.Q)

class MCTS_Atari:
    """Tree search driver using the *_Atari helpers (reward plus discounted value)."""

    def __init__(self, model, observation):
        self.root_Node = TreeNode()
        self.model = model
        # The root's hidden state comes straight from the representation network.
        self.root_Node.hidden_state = self.model.representation.predict(observation)
        self.minmax = MinMax()

    def simulations(self, num_simulation, discount, add_noise=True):
        """Run the search and return ({action: root child visit count}, root Q).

        The loop runs num_simulation + 1 times; the first pass expands the
        root itself and, when `add_noise` is set, adds exploration noise to
        the root's children.
        """
        for _ in range(num_simulation + 1):
            node = self.root_Node
            # Descend until an unexpanded node is reached.
            while not node.is_leaf_Node():
                _action, node = select_argmax_pUCB_child_Atari(node, self.minmax)
            expand_Atari(node, self.model)
            if node == self.root_Node and add_noise:
                add_exploration_noise(node)
            backpropagate_Atari(node, self.minmax, discount)
        action_visits = {
            action: child.visit_count
            for action, child in self.root_Node.children.items()
        }
        return action_visits, self.root_Node.Q

    def __str__(self):
        return "Muzero_MCTS_Atari"

def ucb_score_Chess(node, minmax, pb_c_init=PB_C_INIT, pb_c_base=PB_C_BASE):
    """Selection score of `node` as seen from its parent (board-game variant).

    Min-max normalized Q of the node plus a prior-weighted exploration bonus.
    """
    # NOTE(review): backpropagate_Chess flips the sign of the value at every
    # ply, so node.Q looks like it is stored from the point of view of the
    # player to move at `node`. Maximising it from the parent may then favour
    # the opponent. Confirm whether the value term should be negated here.
    parent_visits = node.parent.visit_count
    exploration = np.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    exploration *= np.sqrt(parent_visits) / (node.visit_count + 1)
    bonus = exploration * node.prior
    return minmax.normalize(node.Q) + bonus

def select_argmax_pUCB_child_Chess(node, minmax):
    """Return the (action, child) pair of `node` with the highest ucb_score_Chess."""

    def score(action_and_child):
        _action, child = action_and_child
        return ucb_score_Chess(child, minmax)

    return max(node.children.items(), key=score)

def expand_Chess(node, model):
    """Evaluate `node` with the model and attach one child per policy entry.

    For a non-root node the hidden state is first computed by
    model.dynamics.predict from the parent's hidden state and the node's
    action. The node's Q is set to the predicted value, and each child
    receives its action index and prior.
    """
    is_root = node.parent is None
    if not is_root:
        node.hidden_state = model.dynamics.predict(node.parent.hidden_state, node.action)
    policy, value = model.prediction.predict(node.hidden_state)
    node.Q = value
    for action in range(len(policy)):
        child = TreeNode()
        child.action = action
        child.prior = policy[action]
        child.parent = node
        node.children[action] = child

def backpropagate_Chess(node, minmax):
    """Propagate the value of `node` up to the root, flipping its sign at each ply.

    Each node on the path gets one more visit and its Q is reported to
    `minmax`; a parent's Q becomes the running mean over its visits.
    """
    value = node.Q
    node.visit_count += 1
    minmax.update(node.Q)
    while not node.is_root_Node():
        value = - value
        node = node.parent
        # Running mean: visit_count still holds the number of earlier visits.
        node.Q = (node.Q * node.visit_count + value) / (node.visit_count + 1)
        node.visit_count += 1
        minmax.update(node.Q)

class MCTS_Chess:
    """Tree search driver using the *_Chess helpers (sign-flipped values, no reward)."""

    def __init__(self, model, observation):
        self.root_Node = TreeNode()
        self.model = model
        # The root's hidden state comes straight from the representation network.
        self.root_Node.hidden_state = self.model.representation.predict(observation)
        self.minmax = MinMax()

    def simulations(self, num_simulation, add_noise=True):
        """Run the search and return {action: root child visit count}.

        The loop runs num_simulation + 1 times; the first pass expands the
        root itself and, when `add_noise` is set, adds exploration noise to
        the root's children.
        """
        for _ in range(num_simulation + 1):
            node = self.root_Node
            # Descend until an unexpanded node is reached.
            while not node.is_leaf_Node():
                _action, node = select_argmax_pUCB_child_Chess(node, self.minmax)
            expand_Chess(node, self.model)
            if node == self.root_Node and add_noise:
                add_exploration_noise(node)
            backpropagate_Chess(node, self.minmax)

        return {
            action: child.visit_count
            for action, child in self.root_Node.children.items()
        }

    def __str__(self):
        return "Muzero_MCTS_Chess"

game.py

from gym.envs.classic_control import rendering
import numpy as np
import gym

def check(filter, state, size, filter_w, filter_h):
    """Return True if some window of `state` matches `filter` with a sum of 5.

    Every filter_h x filter_w window inside the size x size board is
    multiplied element-wise with `filter`; a window counts as a match when
    the products sum to exactly 5.
    """
    window_sums = [
        np.sum(filter * state[row:row + filter_h, col:col + filter_w])
        for row in range(size - filter_h + 1)
        for col in range(size - filter_w + 1)
    ]
    return any(total == 5 for total in window_sums)

class Gomoku(gym.Env):
    """Two-player five-in-a-row on a num_chess x num_chess board.

    `board` has shape (3, num_chess, num_chess): plane 0 and plane 1 hold
    the stones of player 0 and player 1, and plane 2 is filled with the
    index of the player who moves next. `winner` is the index of the player
    who completed five in a row, or None while nobody has.
    """

    # Stone colour per player plane, as RGB fractions.
    _STONE_COLORS = (
        (0 / 255, 139 / 255, 0 / 255),
        (238 / 255, 118 / 255, 33 / 255),
    )

    def __init__(self, num_chess, block_size):
        """Create the environment.

        Args:
            num_chess: board side length in cells; must be at least 5.
            block_size: side length of one cell in pixels, used by render().

        Raises:
            ValueError: if num_chess is smaller than 5.
        """
        if num_chess < 5:
            raise ValueError("The minimum checkerboard is 5.")

        self.board = None
        self.num_chess = num_chess
        self.winner = None

        self.block_size = block_size

        self.viewer = None

        self.player = None

    def reset(self):
        """Start a new game and return (board copy, winner), where winner is None."""
        self.board = np.zeros([3, self.num_chess, self.num_chess])
        self.player = 0
        # Forget the previous game's result; otherwise a reused environment
        # would report the old winner before any stone is placed.
        self.winner = None

        # Return a copy: step() mutates self.board in place, so handing out
        # the live array would silently rewrite states a caller has stored.
        return self.board.copy(), self.winner

    def render(self, mode="human"):
        """Draw the grid and all stones; returns what the viewer's render() returns."""
        side = self.num_chess * self.block_size
        if self.viewer is None:
            self.viewer = rendering.Viewer(side, side)
        # Rebuild the scene on every call. Clearing only when the viewer is
        # created would add another copy of the grid and stones each frame.
        self.viewer.geoms.clear()
        self.viewer.onetime_geoms.clear()

        for index in range(1, self.num_chess):
            offset = index * self.block_size
            endpoints = (
                ((0, offset), (side, offset)),   # horizontal grid line
                ((offset, 0), (offset, side)),   # vertical grid line
            )
            for start, end in endpoints:
                line = rendering.Line(start, end)
                line.set_color(0, 0, 0)
                self.viewer.add_geom(line)

        for plane, color in enumerate(self._STONE_COLORS):
            for i in range(self.num_chess):
                for j in range(self.num_chess):
                    if self.board[plane][j][i] == 1:
                        circle = rendering.make_circle(0.35 * self.block_size)
                        circle.set_color(*color)
                        # Board row j is drawn from the top, hence the flipped y.
                        move = rendering.Transform(
                            translation=(
                                (i + 0.5) * self.block_size,
                                (self.num_chess - j - 0.5) * self.block_size
                            )
                        )
                        circle.add_attr(move)
                        self.viewer.add_geom(circle)

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def done(self):
        """Return True if either player has five in a row (row, column or diagonal)."""
        # (filter, window width, window height) for the four line directions.
        patterns = (
            (np.array([1, 1, 1, 1, 1]), 5, 1),
            (np.array([[1], [1], [1], [1], [1]]), 1, 5),
            (np.eye(5), 5, 5),
            (np.eye(5)[::-1], 5, 5),
        )
        return any(
            check(pattern, self.board[plane], self.num_chess, width, height)
            for pattern, width, height in patterns
            for plane in (0, 1)
        )

    def step(self, action: int):
        """Place a stone for the current player on flat cell index `action`.

        Returns:
            (board copy, winner); winner stays None until a player has five
            in a row.

        Raises:
            ValueError: if `action` is outside the board or the cell is
                already occupied.
        """
        if not 0 <= action < self.num_chess ** 2:
            # A negative index would otherwise wrap around and land on an
            # unintended cell.
            raise ValueError("Action error, the action is outside the board")
        i, j = divmod(action, self.num_chess)
        if self.board[0][i][j] == 1 or self.board[1][i][j] == 1:
            raise ValueError("Action error, there are pieces here")
        self.board[self.player][i][j] = 1

        if self.done():
            self.winner = self.player

        # Hand the turn over; plane 2 always carries the index of the player
        # to move (all ones for player 1, all zeros for player 0).
        self.player = 1 - self.player
        self.board[2] = self.player
        return self.board.copy(), self.winner

self_play.py

from game import Gomoku
from MCTS import MCTS_Chess
import numpy as np
import time

class play_game:
    """Plays one Gomoku game against itself, choosing every move with MCTS_Chess."""

    def __init__(self, num_chess, block_size, model, num_simulations, render):
        """Set up the environment and search settings.

        Args:
            num_chess: board side length in cells.
            block_size: cell size in pixels, forwarded to the environment.
            model: object handed to the search (its sub-networks are used there).
            num_simulations: number of simulations per move.
            render: when true, the board is rendered after every move.
        """
        self.num_chess = num_chess
        self.env = Gomoku(num_chess, block_size)
        self.render = render
        self.max_step = num_chess ** 2
        self.valid_action = list(range(num_chess ** 2))
        self.model = model
        self.mcts = MCTS_Chess
        self.num_simulations = num_simulations

    def choice_action(self, observation, T=1.0):
        """Pick a move for `observation` and return (action, policy).

        Visit counts of cells that are no longer playable are zeroed, the
        rest are raised to 1 / T and normalized into `policy`, and the
        action is sampled from it. If no playable cell was visited, the
        policy is uniform over the playable cells. The chosen cell is
        removed from `valid_action`.
        """
        mcts = self.mcts(self.model, observation)
        visit_count = mcts.simulations(self.num_simulations)
        # The search does not know which cells are occupied, so mask them here.
        for cell in visit_count:
            if cell not in self.valid_action:
                visit_count[cell] = 0

        action_visits = np.array(list(visit_count.values()))
        if np.any(action_visits):
            scaled = action_visits ** (1 / T)
            policy = scaled / np.sum(scaled)
        else:
            policy = np.array([1 / len(self.valid_action) if i in self.valid_action else 0 for i in range(self.num_chess ** 2)])

        action = np.random.choice(len(policy), p=policy)
        self.valid_action.remove(action)
        return action, policy

    def run(self):
        """Play one full game and return (trajectory, winner).

        Each trajectory entry is [state, one-hot action plane, policy], where
        state is the board before the move in (row, column, plane) layout.
        """
        # Every cell is playable again at the start of a game; without this a
        # second run() on the same instance would find no valid action left.
        self.valid_action = list(range(self.num_chess ** 2))
        trajectory = []
        state, winner = self.env.reset()
        # Copy after transposing: np.transpose returns a view, and an
        # environment that updates its board in place would otherwise rewrite
        # every state already stored in the trajectory.
        state = np.transpose(state, (1, 2, 0)).copy()
        if self.render:
            self.env.render()
        for step in range(self.max_step):
            action, policy = self.choice_action(state)
            action_onehot = np.reshape(
                [1 if i == action else 0 for i in range(self.num_chess ** 2)],
                (self.num_chess, self.num_chess, 1),
            )
            trajectory.append([state, action_onehot, policy])
            state, winner = self.env.step(action)

            if self.render:
                self.env.render()
            state = np.transpose(state, (1, 2, 0)).copy()
            if winner is not None:
                break

        return trajectory, winner

# class human_play:
#     def __init__(self, num_chess, block_size, render):
#         self.num_chess = num_chess
#         self.env = Gomoku(num_chess, block_size)
#         self.render = render
#         self.max_step = num_chess ** 2
#
#     def run(self):
#         trajectory = []
#         state, winner = self.env.reset()
#         state = np.reshape(state, newshape=(self.num_chess, self.num_chess, 3))
#         if self.render:
#             self.env.render()
#         for step in range(self.max_step):
#             action = int(input())
#             policy = [1 if i == action else 0 for i in range(self.num_chess ** 2)]
#             action_onehot = np.reshape(policy, newshape=(self.num_chess, self.num_chess, 1))
#             policy = np.array(policy)
#             trajectory.append([state, action_onehot, policy])
#             state, winner = self.env.step(action)
#             if self.render:
#                 self.env.render()
#             state = np.reshape(state, newshape=(self.num_chess, self.num_chess, 3))
#             if winner is not None:
#                 break
#         last_action = np.reshape([0 for _ in range(self.num_chess ** 2)], newshape=(self.num_chess, self.num_chess, 1))
#         last_policy = np.array([0 for _ in range(self.num_chess ** 2)])
#         trajectory.append([state, last_action, last_policy])
#         return trajectory, winner

trainer.py

from tensorflow.keras import optimizers, losses
from collections import deque

import numpy as np
import tensorflow as tf
import random

class ReplayBuffer():
    """Bounded FIFO store of augmented trajectories."""

    def __init__(self, max_memory):
        self.memory = deque(maxlen=max_memory)
        # Snapshot taken at construction (always 0); it is not updated later.
        self.len = len(self.memory)

    def save_memory(self, trajectory):
        """Augment `trajectory` and append it; the oldest entry is dropped when full."""
        augmented = self.data_augmentation(trajectory)
        self.memory.append(augmented)

    def sample(self, sample_size):
        """Return up to `sample_size` distinct stored trajectories, chosen at random."""
        count = min(sample_size, len(self.memory))
        return random.sample(self.memory, count)

    @staticmethod
    def data_augmentation(trajectory):
        """Expand every step of `trajectory` into its eight board symmetries.

        Each step is (state, action plane, policy, winner). For rotations of
        0, 90, 180 and 270 degrees the rotated sample and its left-right
        mirror are emitted, in that order. The result holds, per step, a list
        of four arrays with leading dimension 8: states, 2-D action planes,
        flattened policies and the winner value repeated.
        """
        augmented_steps = []
        for state, action, policy, winner in trajectory:
            rows, cols = action.shape[0], action.shape[1]
            policy = np.reshape(policy, (rows, cols, 1))
            action = np.reshape(action, (rows, cols, 1))

            states, actions, policies = [], [], []
            for quarter_turns in range(4):
                if quarter_turns == 0:
                    # The unrotated sample is used as is.
                    rotated = (state, action, policy)
                else:
                    rotated = tuple(
                        tf.image.rot90(plane, k=quarter_turns)
                        for plane in (state, action, policy)
                    )
                mirrored = tuple(tf.image.flip_left_right(plane) for plane in rotated)
                for s, a, p in (rotated, mirrored):
                    states.append(s)
                    actions.append(np.reshape(a, (a.shape[0], a.shape[1])))
                    policies.append(np.reshape(p, (p.shape[0] * p.shape[1])))

            augmented_steps.append([
                np.array(states),
                np.array(actions),
                np.array(policies),
                np.array([winner] * 8),
            ])
        return augmented_steps


class Trainer():
    """Owns the optimizer and replay buffer and performs gradient updates."""

    def __init__(self, lr=1e-3, max_save_memory=int(1e6)):
        self.optimizer = optimizers.Adam(lr)
        self.replay_buffer = ReplayBuffer(max_save_memory)

    @staticmethod
    def roll_to_end(traj, model, policy_targets, value_targets, policy_predicts, value_predicts):
        """Unroll the model along `traj` and collect targets and predictions.

        The first step's state is encoded once; every later hidden state
        comes from the dynamics network fed with the recorded action. For
        each step the recorded policy and value (reshaped to a column) are
        appended to the target lists, and the model outputs to the
        prediction lists. The four lists are extended in place and returned.
        """
        hidden_state = model.representation.model(traj[0][0])
        for entry in traj:
            p_pred, v_pred = model.prediction.model(hidden_state)
            hidden_state = model.dynamics.model([hidden_state, entry[1]])

            policy_targets.append(entry[2])
            value_targets.append(np.reshape(entry[3], (-1, 1)))

            policy_predicts.append(p_pred)
            value_predicts.append(v_pred)
        return policy_targets, value_targets, policy_predicts, value_predicts

    def run_train(self, batch_size, model):
        """Sample trajectories and apply one optimizer step per trajectory.

        For every sampled trajectory the model is unrolled from each starting
        position to the end. The loss is categorical cross-entropy on the
        policies plus mean squared error on the values. Returns the mean
        policy loss, mean value loss and mean policy entropy over the batch.
        """
        # NOTE(review): the networks are called without training=True and the
        # layers' kernel_regularizer losses are not added to `loss`. Confirm
        # that both are intended.
        batch = self.replay_buffer.sample(batch_size)
        policys_losses, value_losses, entropys = [], [], []
        for game in batch:
            with tf.GradientTape() as tape:
                policy_targets, value_targets = [], []
                policy_predicts, value_predicts = [], []
                for start in range(len(game)):
                    policy_targets, value_targets, policy_predicts, value_predicts = self.roll_to_end(
                        game[start:], model,
                        policy_targets,
                        value_targets,
                        policy_predicts,
                        value_predicts
                    )
                # Entropy of the first sample of each prediction; monitoring only.
                entropy = [
                    - np.sum(predicted[0] * np.log(predicted[0] + 1e-6))
                    for predicted in policy_predicts
                ]

                policy_loss = losses.categorical_crossentropy(
                    y_true=policy_targets,
                    y_pred=policy_predicts
                )
                value_loss = losses.mean_squared_error(
                    y_true=value_targets,
                    y_pred=value_predicts
                )

                loss = policy_loss + value_loss

            variables = model.trainable_variables
            grad = tape.gradient(loss, variables)
            self.optimizer.apply_gradients(zip(grad, variables))
            policys_losses.append(np.mean(policy_loss))
            value_losses.append(np.mean(value_loss))
            entropys.append(np.mean(entropy))

        return np.mean(policys_losses), np.mean(value_losses), np.mean(entropys)

run_training.py (主训练函数)

from resnet_model import model
from self_play import play_game
from trainer import Trainer

import multiprocessing
import threading
import datetime
import time

# Board side length in cells.
NUM_CHESS = 9
# Cell size in pixels, handed to the game environment for rendering.
RENDER_BLOCK_SIZE = 50
# Observation layout: (rows, columns, planes).
OBSERVATION_SHAPE = (NUM_CHESS, NUM_CHESS, 3)
# Channel count of the hidden state.
HIDDEN_STATE_CHANNEL = 32
# Simulations per move during self-play.
NUM_SIMULATIONS = 400

# NOTE(review): BUFFER_SIZE is not referenced in the visible code. Trainer()
# is built with its own default of the same value. Confirm whether it should
# be passed through.
BUFFER_SIZE = int(1e6)
# Number of trajectories requested from the replay buffer per training call.
BATCH_SIZE = 512

# Number of self-play worker processes.
NUM_WORKERS = 8

def self_play_worker(pipe):
    """Worker loop: receive weights, play one game, send the labelled trajectory.

    `pipe` delivers a list of three weight lists (representation, dynamics,
    prediction). Each trajectory step gets a value label appended: 1.0 for
    the last step when the game had a winner, alternating in sign towards
    the start; 0.0 throughout when there was no winner.
    """
    worker_model = model(OBSERVATION_SHAPE, HIDDEN_STATE_CHANNEL, NUM_CHESS)
    networks = (
        worker_model.representation,
        worker_model.dynamics,
        worker_model.prediction,
    )
    while True:
        weights = pipe.recv()
        for index, network in enumerate(networks):
            network.model.set_weights(weights[index])

        game = play_game(NUM_CHESS, RENDER_BLOCK_SIZE, worker_model, NUM_SIMULATIONS, render=False)
        trajectory, winner = game.run()

        outcome = 0.0 if winner is None else 1.0
        # Walk backwards from the final move, flipping the sign each ply.
        for step in reversed(trajectory):
            step.append(outcome)
            outcome *= -1
        pipe.send(trajectory)

def save_model():
    """Save the global model's weights every 20 minutes, forever."""
    import os

    global global_model
    # Create the checkpoint directory up front; it is not created anywhere
    # else, and saving into a missing directory would fail.
    os.makedirs("./model", exist_ok=True)
    while True:
        time.sleep(60 * 20)
        global_model.save_weights("./model/gomoku_{}X{}".format(NUM_CHESS, NUM_CHESS))
        print('\n save model at {}'.format(datetime.datetime.now()))

def training(trainer):
    """Training loop: repeatedly run one training pass and print its statistics."""
    global episode
    global global_model
    while True:
        if not trainer.replay_buffer.memory:
            # Nothing to learn from yet. Wait for self-play data instead of
            # spinning on an empty batch, which only yields nan statistics.
            time.sleep(1)
            continue
        t = time.time()
        policy_loss, value_loss, entropy = trainer.run_train(BATCH_SIZE, global_model)
        print("\r episode: {}, policy_loss: {}, value_loss: {}, losses: {}, entropy: {}, num_trajectory: {}, train time: {} s".format(
            episode, policy_loss, value_loss, policy_loss + value_loss, entropy, len(trainer.replay_buffer.memory), int(time.time() - t)),
        end="")

def communication(trainer, pipe_dict):
    """Exchange loop between the trainer and the self-play workers.

    Each round sends the global model's current weights to every worker,
    then collects one trajectory from every worker into the replay buffer
    and increments the global episode counter.
    """
    global episode
    global global_model
    while True:
        for parent_end, _worker_end in pipe_dict.values():
            # Weights are read per worker, so each gets the latest snapshot.
            snapshot = [
                global_model.representation.model.get_weights(),
                global_model.dynamics.model.get_weights(),
                global_model.prediction.model.get_weights()
            ]
            parent_end.send(snapshot)

        for parent_end, _worker_end in pipe_dict.values():
            finished_game = parent_end.recv()
            trainer.replay_buffer.save_memory(finished_game)
        episode += 1

if __name__ == '__main__':
    # Launch the self-play workers before the model is built or any thread is
    # started in this process. Where processes are created by forking, a
    # child forked after that point inherits a copy of a process that already
    # has running threads and an initialised model, which is a known source
    # of hangs.
    pipe_dict = {}
    for w in range(NUM_WORKERS):
        pipe_dict["worker_{}".format(str(w))] = multiprocessing.Pipe()

    process = []
    for w in range(NUM_WORKERS):
        self_play_process = multiprocessing.Process(
            target=self_play_worker,
            args=(
                pipe_dict["worker_{}".format(str(w))][1],
            )
        )
        process.append(self_play_process)
    for p in process:
        p.start()

    global_model = model(OBSERVATION_SHAPE, HIDDEN_STATE_CHANNEL, NUM_CHESS)
    trainer = Trainer()

    episode = 0

    # global_model.load_weights("./model/gomoku_{}X{}".format(NUM_CHESS, NUM_CHESS))
    train_thread = threading.Thread(target=training, args=[trainer])
    train_thread.start()

    communication_thread = threading.Thread(target=communication, args=[trainer, pipe_dict])
    communication_thread.start()

    savemodel_thread = threading.Thread(target=save_model)
    savemodel_thread.start()

test.py (AI vs AI 测试函数)

from resnet_model import model
from self_play import play_game

import os
# NOTE(review): presumably hides all CUDA devices so the test runs on CPU.
# Confirm, and confirm it takes effect although it is set after the imports above.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Board side length in cells.
NUM_CHESS = 9
# Cell size in pixels, handed to the game environment for rendering.
RENDER_BLOCK_SIZE = 50
# Observation layout: (rows, columns, planes).
OBSERVATION_SHAPE = (NUM_CHESS, NUM_CHESS, 3)
# Channel count of the hidden state.
HIDDEN_STATE_CHANNEL = 32
# Simulations per move.
NUM_SIMULATIONS = 400

def self_play(model, num_simulations):
    """Play one rendered self-play game with `model` and print the winner."""
    game = play_game(NUM_CHESS, RENDER_BLOCK_SIZE, model, num_simulations, True)
    _trajectory, winner = game.run()
    print(winner)

if __name__ == '__main__':
    # Load the saved weights for this board size, then watch one game.
    weights_prefix = "./model/gomoku_{}X{}".format(NUM_CHESS, NUM_CHESS)
    gomoku_model = model(OBSERVATION_SHAPE, HIDDEN_STATE_CHANNEL, NUM_CHESS)
    gomoku_model.load_weights(weights_prefix)
    self_play(gomoku_model, NUM_SIMULATIONS)

你可能感兴趣的:(机器学习,深度学习,强化学习,tensorflow,深度学习,python,强化学习)