[Reinforcement Learning in Practice] Implementing MuZero with TensorFlow 2.0

References:
[1] ColinFred. Monte Carlo Tree Search (MCTS) code walkthrough [Python]. 2019-03-23.
[2] 饼干Japson, Deep Reinforcement Learning Lab. [Paper deep-dive report] A detailed walkthrough of the MuZero algorithm. 2021-01-19.
[3] Tangarf. A study report on the MuZero algorithm. 2020-08-31.
[4] 带带弟弟好吗. AlphaGo, version three: MuZero. 2020-08-30.
[5] Original paper (Google DeepMind): Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model.
[6] Reference GitHub implementation 1.
[7] Reference GitHub implementation 2.

from collections import deque
from tqdm import tqdm

import tensorflow as tf
import numpy as np
import gym

class MuZeroModels(object):
    def __init__(self, obs_shape: 'int', act_shape: 'int', obs_num=7, l2=1e-2):

        input_ops = tf.keras.Input(shape=(obs_num, obs_shape))
        x = tf.keras.layers.Flatten()(input_ops)
        x = tf.keras.layers.Dense(
            units=128,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2)
        )(x)
        x = tf.keras.layers.Dense(
            units=128,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2)
        )(x)
        hidden_state = tf.keras.layers.Dense(
            units=obs_shape,
            kernel_regularizer=tf.keras.regularizers.l2(l2),
            activation='sigmoid'
        )(x)
        self.representation = tf.keras.Model(inputs=input_ops, outputs=hidden_state)

        input_hidden_state_action = tf.keras.Input(shape=(obs_shape + act_shape,))
        x = tf.keras.layers.Dense(
            units=128,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2)
        )(input_hidden_state_action)

        x = tf.keras.layers.Dense(
            units=128,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2)
        )(x)
        next_hidden_state = tf.keras.layers.Dense(
            units=obs_shape,
            kernel_regularizer=tf.keras.regularizers.l2(l2),
            activation='sigmoid'
        )(x)
        reward = tf.keras.layers.Dense(
            units=1,
            kernel_regularizer=tf.keras.regularizers.l2(l2),
            activation='tanh'
        )(x)
        self.dynamics = tf.keras.Model(inputs=input_hidden_state_action, outputs=[next_hidden_state, reward])

        input_hidden_state = tf.keras.Input(shape=(obs_shape,))
        x = tf.keras.layers.Dense(
            units=128,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2)
        )(input_hidden_state)
        x = tf.keras.layers.Dense(
            units=128,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2)
        )(x)
        policy = tf.keras.layers.Dense(
            units=act_shape,
            activation='softmax',
            kernel_regularizer=tf.keras.regularizers.l2(l2)
        )(x)
        value = tf.keras.layers.Dense(
            units=1,
            activation='sigmoid',
            kernel_regularizer=tf.keras.regularizers.l2(l2)
        )(x)
        self.prediction = tf.keras.Model(inputs=input_hidden_state, outputs=[policy, value])

    def save_weights(self, save_path):
        self.representation.save_weights(save_path+'/representation.h5')
        self.dynamics.save_weights(save_path+'/dynamics.h5')
        self.prediction.save_weights(save_path+'/prediction.h5')

    def load_weights(self, load_path):
        self.representation.load_weights(load_path+'/representation.h5')
        self.dynamics.load_weights(load_path+'/dynamics.h5')
        self.prediction.load_weights(load_path+'/prediction.h5')
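
# Quick orientation (annotation added for clarity, not part of the original listing):
# the three models above play the roles of MuZero's h / g / f functions, and the MCTS
# code below chains them roughly like this:
#   s0 = model.representation(obs_stack)                    # h: stacked observations -> hidden state
#   policy, value = model.prediction(s0)                    # f: hidden state -> (action probs, value)
#   s1, r1 = model.dynamics(concat([s0, action_one_hot]))   # g: (hidden state, action) -> (next state, reward)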

class minmax(object):
    def __init__(self):
        self.maximum = -float("inf")
        self.minimum = float("inf")

    def update(self, value):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
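
# Note (added for clarity): as in the paper, Q values are normalized to [0, 1] using the
# minimum and maximum returns observed anywhere in the current search tree, so that the
# pUCT exploration term is not distorted by environment-specific value scales.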

class TreeNode(object):
    def __init__(
        self,
        parent,
        prior_p,
        hidden_state,
        reward,
        is_PVP: 'bool'=False,
        gamma=0.997
    ):
        self._parent = parent
        self._children = {}
        self._num_visits = 0
        self._Q = 0
        self._U = 0
        self._P = prior_p

        self._hidden_state = hidden_state
        self.reward = reward

        self._is_PVP = is_PVP
        self._gamma = gamma

    def expand(self, action_priorP_hiddenStates_reward):
        '''
        :param action_priorP_hiddenStates_reward: iterable of tuples
            (action, predicted prior probability of that action, hidden state, reward).
        Expands the tree by creating a new child node for every action not seen yet.
        '''
        for action, prob, hidden_state, reward in action_priorP_hiddenStates_reward:
            if action not in self._children.keys():
                self._children[action] = TreeNode(
                    parent=self,
                    prior_p=prob,
                    hidden_state=hidden_state,
                    reward=reward,
                    is_PVP=self._is_PVP,
                    gamma=self._gamma
                )

    def select(self, c_puct_1=1.25, c_puct_2=19652):
        '''
        :param c_puct_1: set to 1.25, following the paper
        :param c_puct_2: set to 19652, following the paper
        :return: the (action, child node) pair with the highest UCB score
        '''
        return max(
            self._children.items(),
            key=lambda node_tuple: node_tuple[1].get_value(c_puct_1, c_puct_2)
        )

    def _update(self, value, reward, minmax):
        '''
        :param value: value estimated at the final leaf node n_l, discounted by gamma ** (l - k)
        :param reward: rewards accumulated (with discounting) while backing up from the leaf n_l to this node n_k
        Note: this method is internal and never needs to be called from outside the class.
        '''
        _G = reward + value
        minmax.update(_G)
        _G = minmax.normalize(_G)
        self._Q = (self._num_visits * self._Q + _G) / (self._num_visits + 1)
        self._num_visits += 1

    def backward_update(self, minmax, value, backward_reward=0):
        '''
        :param backward_reward: sum of the rewards backed up from the leaf, each multiplied by the discount factor gamma
        :param value: value estimated at the final leaf node
        Note: call this only on the leaf node, never on internal nodes; the value estimate always refers to the final leaf state.
        '''
        self._update(value, backward_reward, minmax)
        if self._is_PVP:
            all_rewards = self.reward - self._gamma * backward_reward
        else:
            all_rewards = self.reward + self._gamma * backward_reward

        if self._parent:
            self._parent.backward_update(minmax, self._gamma * value, all_rewards)
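
    # Backup rule used above, spelled out (paraphrasing the paper): for a node n_k on the
    # search path that ends at leaf n_l,
    #   G_k = r_{k+1} + gamma * r_{k+2} + ... + gamma^(l-1-k) * r_l + gamma^(l-k) * v_l
    # i.e. the discounted rewards collected below the node plus the discounted leaf value.
    # backward_update passes gamma * value and the accumulated discounted reward up the tree
    # so every ancestor receives its own G_k; in the PVP case the accumulated reward is
    # negated at each level because the players alternate.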

    def get_value(self, c_puct_1=1.25, c_puct_2=19652):
        '''
        :param c_puct_1: set to 1.25, following the paper
        :param c_puct_2: set to 19652, following the paper
        :return: the UCB score of this node
        Note that the UCB score here is computed differently from AlphaZero.
        '''

        self._U = self._P *\
                  (np.sqrt(self._parent._num_visits)/(1 + self._num_visits)) *\
                  (
                    c_puct_1 + np.log(
                      (self._parent._num_visits + c_puct_2 + 1)/c_puct_2)
                  )

        return self._Q + self._U
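
    # The score above is the pUCT rule from the MuZero paper:
    #   score = Q(s, a) + P(s, a) * sqrt(N(s)) / (1 + N(s, a)) * (c1 + log((N(s) + c2 + 1) / c2))
    # with c1 = 1.25 and c2 = 19652, where N(s) is the parent's visit count and Q is the
    # min-max normalized value; AlphaZero instead uses a single exploration constant.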

    def is_leaf(self):
        return self._children == {}

    def is_root(self):
        return self._parent is None

class MCTS(object):
    def __init__(
        self,
        model: 'MuZeroModels',
        observations,
        action_num,
        reward,
        is_PVP: 'bool'=False,
        gamma=0.997,
        num_playout=50,
        c_puct_1=1.25,
        c_puct_2=19652,
    ):

        self._muzero_model = model
        self.observations = observations
        self.action_num = action_num
        self._minmax = minmax()
        self._root = TreeNode(
            parent=None,
            prior_p=1.0,
            hidden_state=self._muzero_model.representation(observations),
            reward=reward,
            is_PVP=is_PVP,
            gamma=gamma
        )
        self._c_puct_1 = c_puct_1
        self._c_puct_2 = c_puct_2
        self._num_playout = num_playout

    def _playout(self):
        node = self._root

        while True:
            if node.is_leaf():
                break
            _, node = node.select(self._c_puct_1, self._c_puct_2)

        action_probs, value = self._muzero_model.prediction(node._hidden_state)
        action_probs = action_probs[0]

        action_priorP_hiddenStates_reward = []

        for action, action_prob in enumerate(action_probs):
            # one-hot encode each candidate action and feed (hidden state, action) through the dynamics model
            action_one_hot = [1 if i == action else 0 for i in range(self.action_num)]
            state_action = list(node._hidden_state[0].numpy()) + action_one_hot
            next_hidden_state, reward = self._muzero_model.dynamics(np.array([state_action]))
            action_priorP_hiddenStates_reward.append((action, action_prob, next_hidden_state, reward[0][0]))

        node.expand(action_priorP_hiddenStates_reward)

        node.backward_update(minmax=self._minmax, value=value[0][0])

    def get_action_prob(self):

        for _ in range(self._num_playout + 1):
            self._playout()

        actions = []
        visits = []
        for action, node in self._root._children.items():
            actions.append(action)
            visits.append(node._num_visits)

        exp_visits = np.exp(np.array(visits) - np.max(visits))  # subtract the max for numerical stability

        return actions, exp_visits / np.sum(exp_visits)
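
    # Note (added): the paper converts visit counts into action probabilities with
    # pi(a) proportional to N(a) ** (1 / T) for a temperature T, whereas this implementation
    # applies a softmax to the raw visit counts; both favour the most-visited actions, the
    # softmax is simply much sharper.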

    def re_play(self):
        node = self._root

        node_list = [self._root]
        action_list = []

        while True:
            if node.is_leaf():
                break
            action, node = node.select(self._c_puct_1, self._c_puct_2)
            node_list.append(node)
            action_list.append([1 if i == action else 0 for i in range(self.action_num)])

        return node_list[:-1], action_list
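
    # re_play retraces the currently best path through the tree (greedy pUCT selection from
    # the root) and returns the visited nodes, excluding the final leaf, together with the
    # one-hot actions chosen along the way; train() below uses them as unroll targets.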

    def __str__(self):
        return "MuZero_MCTS"

class MuZero:
    def __init__(
        self,
        env,
        obs_shape: 'int',
        act_shape: 'int',

        is_PVP: 'bool'=False,
        obs_num=7,
        gamma=0.997,
        num_playout=50,
        c_puct_1=1.25,
        c_puct_2=19652,
        play_steps=500,

        l2=1e-2,
        lr=1e-3
    ):

        self._env = env
        self.obs_shape = obs_shape
        self.act_shape = act_shape
        self.obs_num = obs_num

        self.initialize()

        self.is_PVP = is_PVP
        self.gamma = gamma
        self.num_playout = num_playout
        self.c_puct_1 = c_puct_1
        self.c_puct_2 = c_puct_2
        self.play_steps = play_steps

        self.model = MuZeroModels(obs_shape, act_shape, obs_num, l2)
        self.opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.replay_buffer = deque(maxlen=10000)

    def get_action(self, state, reward):

        self._obs_queue.append(state)

        mcts = MCTS(
            model=self.model,
            observations=np.array([self._obs_queue]),
            action_num=self.act_shape,
            reward=reward,
            is_PVP=self.is_PVP,
            gamma=self.gamma,
            num_playout=self.num_playout,
            c_puct_1=self.c_puct_1,
            c_puct_2=self.c_puct_2
        )
        actions, probs = mcts.get_action_prob()

        return mcts, np.random.choice(actions, p=probs)

    def self_play(self, self_play_num):

        for _ in tqdm(range(self_play_num)):
            state = self._env.reset()
            reward = 0
            for t in range(self.play_steps):
                self._env.render()

                mcts, action = self.get_action(state, reward)
                state, reward, done, info = self._env.step(action)

                choice_act_one_hot = [1 if action == i else 0 for i in range(self.act_shape)]
                next_reward = reward

                self.replay_buffer.append((mcts, choice_act_one_hot, next_reward))

                if done:
                    break

            self.initialize()

    def train(self):
        '''
        Each replay-buffer entry holds the MCTS tree built for one environment step,
        the one-hot action actually taken, and the reward observed after that step.
        '''
        for mcts, choice_action, reward in self.replay_buffer:
            nodelist, act_list = mcts.re_play()
            act_list[0] = choice_action
            observations = mcts.observations

            length = len(nodelist)

            for i in range(length):

                if i == 0:  # update at the root node
                    reward_true = reward
                    policy_true = act_list[i]
                    value_true = nodelist[i]._Q
                    with tf.GradientTape() as tape:
                        # forward pass: hidden state -> predicted reward, policy and value
                        reward_prob = self.model.dynamics(tf.keras.layers.concatenate([
                            self.model.representation(observations),
                            tf.constant([policy_true], dtype=tf.float32)
                        ]))[1]
                        policy_prob, value_prob = self.model.prediction(
                            self.model.representation(observations)
                        )

                        reward_loss = tf.keras.losses.mean_squared_error(
                            y_true=reward_true,
                            y_pred=reward_prob
                        )

                        policy_value_loss = tf.keras.losses.categorical_crossentropy(
                            y_true=[policy_true],
                            y_pred=policy_prob
                        ) + tf.keras.losses.mean_squared_error(
                            y_true=[[value_true]],
                            y_pred=value_prob
                        )

                    # compute and apply the gradients outside of the tape context
                    trainable_variables = (
                        self.model.representation.trainable_variables
                        + self.model.dynamics.trainable_variables
                        + self.model.prediction.trainable_variables
                    )
                    grad = tape.gradient(
                        reward_loss + policy_value_loss,
                        trainable_variables
                    )
                    self.opt.apply_gradients(zip(grad, trainable_variables))
                    # optional debugging output:
                    # print(reward_loss + policy_value_loss)
                    # print('policy:', policy_true, policy_prob)
                    # print('value:', value_true, value_prob)
                    # print('reward:', reward_true, reward_prob)
                else:
                    # Updating the remaining nodes along the unrolled path is left unimplemented in the original post.
                    pass
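                    # One possible way to finish this branch, following the paper's K-step
                    # unroll (a sketch, not part of the original code): start from
                    # representation(observations), repeatedly concatenate the previous
                    # hidden state with act_list[i] and pass it through self.model.dynamics
                    # to obtain the i-th hidden state, then regress dynamics' reward output
                    # on the reward stored in the corresponding tree node, prediction's
                    # value on nodelist[i]._Q, and prediction's policy on the visit-count
                    # distribution of nodelist[i]'s children, summing the losses over all
                    # unrolled steps before a single optimizer update.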

    def initialize(self):
        self._obs_queue = deque(maxlen=self.obs_num)

        for _ in range(self.obs_num):
            self._obs_queue.append([0 for __ in range(self.obs_shape)])

if __name__ == '__main__':

    env = gym.make('CartPole-v1')
    agent = MuZero(env=env, obs_shape=4, act_shape=2, num_playout=20)
    agent.self_play(1)

    agent.train()
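
If you want to keep the trained networks, the save_weights / load_weights helpers defined on MuZeroModels can be called through the agent. A minimal sketch continuing the __main__ block above (the directory name is my own example and must exist before saving):

    import os

    save_dir = './muzero_weights'          # arbitrary example path
    os.makedirs(save_dir, exist_ok=True)   # save_weights writes three .h5 files into this directory
    agent.model.save_weights(save_dir)     # representation.h5 / dynamics.h5 / prediction.h5
    agent.model.load_weights(save_dir)     # restore into a model built with the same obs/act shapes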

P.S.: The code is not fully finished; corrections are welcome if you spot any mistakes.
