Alpha Zero 趣味学习:训练网络

启动 TicTacToe ,会进入网络训练或人机对弈。如果不设置启动参数 –human_play=1,便自动进入网络训练。即:

python main.py

开始训练前,要在 main.py 中做两件事情。第一,初始化游戏和网络:

if CFG.game == 0:
    game = TicTacToeGame()
net = NeuralNetworkWrapper(game)

第二,初始化训练类,开始训练:

train = Train(game, net)
train.start()

类 Train 的初始化:

def __init__(self, game, net):
    """Initializes Train with the board state and neural network."""
    self.game = game
    self.net = net
    self.eval_net = NeuralNetworkWrapper(game)

在 Train.start() 中,初始化了神经网络、估值网络和估值器:

current_mcts = MonteCarloTreeSearch(self.net)
eval_mcts = MonteCarloTreeSearch(self.eval_net)
evaluator = Evaluate(current_mcts=current_mcts, eval_mcts=eval_mcts, game=self.game)

综上所见,网络训练主要由以下模块和类协作实现:

  • train.py 类 Train 负责网络训练的管理
  • 目录 tic_tac_toe 中 tic_tac_toe_game.py 的类 TicTacToeGame ,实现游戏逻辑,管理棋局
  • neural_net.py 类 NeuralNetworkWrapper 是 类 NeuralNetwork 的包装或接口,负责神经网络的管理
  • tic_tac_toe.py 类 TicTacToeGame,是综合了棋盘表示、游戏规则、游戏逻辑等的物件
  • mcts.py 类 MonteCarloTreeSearch 负责蒙特卡罗搜索树管理
  • evaluate.py 类 Evaluate 是下棋策略和价值判断的估值器(残值网络)

这些模块和类,每一个都有丰富的知识内容,应该逐一设置专题判读。在此,仅仅关注 Train。

粗略判读 Train

Train 的训练管理任务,主要由其 start 方法负责。直接用代码说事儿:

def start(self):
    """训练的主循环"""
    for i in range(CFG.num_iterations): # 共计 4 轮次
        print("迭代", i + 1) # 第 i + 1 轮自学
        training_data = []  # 保存自学中的棋局, pis 和 vs
        for j in range(CFG.num_games): # 每轮下棋30局
            print("开始自学训练的棋局", j + 1) #
            game = self.game.clone()  # 对每局棋进行克隆
            self.play_game(game, training_data) # 对弈生成棋局

        # 保存当前网络模型
        self.net.save_model()
        # 把最近保存的模型装入估值网络
        self.eval_net.load_model()
        # 用自学棋局的估值训练网络
        self.net.train(training_data)
        # 初始化蒙特卡罗搜索树,以用于上述两种网络
        current_mcts = MonteCarloTreeSearch(self.net)
        eval_mcts = MonteCarloTreeSearch(self.eval_net)
        # 初始化估值器
        evaluator = Evaluate(current_mcts=current_mcts, 
                             eval_mcts=eval_mcts, game=self.game)
        # 算出胜负局数
        wins, losses = evaluator.evaluate()
        print("胜:", wins)
        print("负:", losses)
        # 计算胜率
        num_games = wins + losses
        if num_games == 0:
            win_rate = 0
        else:
            win_rate = wins / num_games
        print("胜率:", win_rate)
        # 若胜率超过已有成绩
        if win_rate > CFG.eval_win_rate:
            # 保存当前模型为最佳模型
            print("新模型存为最佳模型")
            self.net.save_model("best_model")
        # 若胜率未超已有成绩
        else:
            print("舍弃新模型,载入已有模型")
            # 舍弃当前模型,启用已有最佳模型
            self.net.load_model()

Train 还有2个函数(方法),分别是:

def play_game(self, game, training_data):
    """
    每一棋局自学训练的循环
    对每一棋局运行 MCTS ,并根据 MCTS 的输出进行下一步棋
    在棋局结束时,终止循环并打印输出获胜一方
    参量:
        game: 棋局对象
        training_data: 列表,保存着自学训练的棋局和 pis ,vs
    """
    mcts = MonteCarloTreeSearch(self.net)

    game_over = False
    value = 0
    self_play_data = []
    count = 0

    node = TreeNode()

    # 在棋局终局前一直下棋
    while not game_over:
        # MCTS 模拟,得到最好的下一步棋
        # 根据下棋的步数是否小于探测控制的阈值,作出选择
        if count < CFG.temp_thresh:
            best_child = mcts.search(game, node, CFG.temp_init)
        else:
            best_child = mcts.search(game, node, CFG.temp_final)

        # 保存棋局、概率prob 和估值v
        self_play_data.append([deepcopy(game.state),
                               deepcopy(best_child.parent.child_psas), 0])

        action = best_child.action
        game.play_action(action)  # 棋局的下一步
        count += 1

        game_over, value = game.check_game_over(game.current_player)

        best_child.parent = None
        node = best_child  # Make the child node the root node.

    # 更新估值,作为棋局结果
    for game_state in self_play_data:
        value = -value
        game_state[2] = value
        self.augment_data(game_state, training_data, game.row, game.column)
def augment_data(self, game_state, training_data, row, column):
    """
    每一局棋自学训练的循环
    对每一棋局运行 MCTS ,并根据 MCTS 的输出进行下一步棋
    在棋局结束时,终止循环并打印输出获胜一方
    参量:
        game_state: 以棋局、概率和估值为内容的对象
        training_data: 列表,保存自学训练的棋局、概率和估值
        row: 整数,表示棋盘横线数目
        column: 整数,表示棋盘竖线数目
    """
    # deepcopy 深复制,即将被复制对象完全再复制一遍作为独立的新个体单独存在。
    # 所以改变原有被复制对象不会对已经复制出来的新对象产生影响。
    state = deepcopy(game_state[0])
    psa_vector = deepcopy(game_state[1])

    if CFG.game == 2 or CFG.game == 1:
        training_data.append([state, psa_vector, game_state[2]])
    else: # 这个是 tic-tac-toe
        psa_vector = np.reshape(psa_vector, (row, column))

    # 棋盘旋转、翻转产生的参数
    for i in range(4):
        training_data.append([np.rot90(state, i),
                              np.rot90(psa_vector, i).flatten(),
                              game_state[2]])

        training_data.append([np.fliplr(np.rot90(state, i)),
                              np.fliplr(np.rot90(psa_vector, i)).flatten(),
                              game_state[2]])

你可能感兴趣的:(Alpha,Zero)