启动 TicTacToe ,会进入网络训练或人机对弈。如果不设置启动参数 –human_play=1,便自动进入网络训练。即:
python main.py
开始训练前,要在 main.py 中做两件事情。第一,初始化游戏和网络:
if CFG.game == 0:
game = TicTacToeGame()
net = NeuralNetworkWrapper(game)
第二,初始化训练类,开始训练:
train = Train(game, net)
train.start()
类 Train 的初始化:
def __init__(self, game, net):
"""Initializes Train with the board state and neural network."""
self.game = game
self.net = net
self.eval_net = NeuralNetworkWrapper(game)
在 Train.start() 中,初始化了神经网络、估值网络和估值器:
current_mcts = MonteCarloTreeSearch(self.net)
eval_mcts = MonteCarloTreeSearch(self.eval_net)
evaluator = Evaluate(current_mcts=current_mcts, eval_mcts=eval_mcts, game=self.game)
综上所见,网络训练主要由以下模块和类协作实现:
这些模块和类,每一个都有丰富的知识内容,应该逐一设置专题判读。在此,仅仅关注 Train。
Train 的训练管理任务,主要由其 start 方法负责。直接用代码说事儿:
def start(self):
"""训练的主循环"""
for i in range(CFG.num_iterations): # 共计 4 轮次
print("迭代", i + 1) # 第 i + 1 轮自学
training_data = [] # 保存自学中的棋局, pis 和 vs
for j in range(CFG.num_games): # 每轮下棋30局
print("开始自学训练的棋局", j + 1) #
game = self.game.clone() # 对每局棋进行克隆
self.play_game(game, training_data) # 对弈生成棋局
# 保存当前网络模型
self.net.save_model()
# 把最近保存的模型装入估值网络
self.eval_net.load_model()
# 用自学棋局的估值训练网络
self.net.train(training_data)
# 初始化蒙特卡罗搜索树,以用于上述两种网络
current_mcts = MonteCarloTreeSearch(self.net)
eval_mcts = MonteCarloTreeSearch(self.eval_net)
# 初始化估值器
evaluator = Evaluate(current_mcts=current_mcts,
eval_mcts=eval_mcts, game=self.game)
# 算出胜负局数
wins, losses = evaluator.evaluate()
print("胜:", wins)
print("负:", losses)
# 计算胜率
num_games = wins + losses
if num_games == 0:
win_rate = 0
else:
win_rate = wins / num_games
print("胜率:", win_rate)
# 若胜率超过已有成绩
if win_rate > CFG.eval_win_rate:
# 保存当前模型为最佳模型
print("新模型存为最佳模型")
self.net.save_model("best_model")
# 若胜率未超已有成绩
else:
print("舍弃新模型,载入已有模型")
# 舍弃当前模型,启用已有最佳模型
self.net.load_model()
Train 还有2个函数(方法),分别是:
def play_game(self, game, training_data):
"""
每一棋局自学训练的循环
对每一棋局运行 MCTS ,并根据 MCTS 的输出进行下一步棋
在棋局结束时,终止循环并打印输出获胜一方
参量:
game: 棋局对象
training_data: 列表,保存着自学训练的棋局和 pis ,vs
"""
mcts = MonteCarloTreeSearch(self.net)
game_over = False
value = 0
self_play_data = []
count = 0
node = TreeNode()
# 在棋局终局前一直下棋
while not game_over:
# MCTS 模拟,得到最好的下一步棋
# 根据下棋的步数是否小于探测控制的阈值,作出选择
if count < CFG.temp_thresh:
best_child = mcts.search(game, node, CFG.temp_init)
else:
best_child = mcts.search(game, node, CFG.temp_final)
# 保存棋局、概率prob 和估值v
self_play_data.append([deepcopy(game.state),
deepcopy(best_child.parent.child_psas), 0])
action = best_child.action
game.play_action(action) # 棋局的下一步
count += 1
game_over, value = game.check_game_over(game.current_player)
best_child.parent = None
node = best_child # Make the child node the root node.
# 更新估值,作为棋局结果
for game_state in self_play_data:
value = -value
game_state[2] = value
self.augment_data(game_state, training_data, game.row, game.column)
def augment_data(self, game_state, training_data, row, column):
"""
每一局棋自学训练的循环
对每一棋局运行 MCTS ,并根据 MCTS 的输出进行下一步棋
在棋局结束时,终止循环并打印输出获胜一方
参量:
game_state: 以棋局、概率和估值为内容的对象
training_data: 列表,保存自学训练的棋局、概率和估值
row: 整数,表示棋盘横线数目
column: 整数,表示棋盘竖线数目
"""
# deepcopy 深复制,即将被复制对象完全再复制一遍作为独立的新个体单独存在。
# 所以改变原有被复制对象不会对已经复制出来的新对象产生影响。
state = deepcopy(game_state[0])
psa_vector = deepcopy(game_state[1])
if CFG.game == 2 or CFG.game == 1:
training_data.append([state, psa_vector, game_state[2]])
else: # 这个是 tic-tac-toe
psa_vector = np.reshape(psa_vector, (row, column))
# 棋盘旋转、翻转产生的参数
for i in range(4):
training_data.append([np.rot90(state, i),
np.rot90(psa_vector, i).flatten(),
game_state[2]])
training_data.append([np.fliplr(np.rot90(state, i)),
np.fliplr(np.rot90(psa_vector, i)).flatten(),
game_state[2]])