Monte Carlo Tree Search (MCTS) is a heuristic search algorithm for finding good decisions in sequential decision problems, and it is widely used in combinatorial games, artificial intelligence, and related fields. MCTS combines the broad coverage of random simulation with the precision of tree search, which allows it to explore and exploit effectively in complex decision spaces.
1. How MCTS Works
The core idea of MCTS is to iteratively build and refine a search tree in order to improve the decision. Each iteration of the algorithm consists of the following four steps:
(1) Selection: starting from the root node, repeatedly pick the most promising child (for example, according to the UCB1 rule) until a leaf node is reached.
(2) Expansion: if the leaf is not a terminal state, add one or more child nodes corresponding to the legal actions available there.
(3) Simulation (rollout): from the new node, play the game out with a default policy, usually random moves, until a terminal state is reached.
(4) Backpropagation: propagate the simulation result back along the visited path, updating the visit count and value estimate of every node on it.
The process above is repeated many times, and each iteration deepens and refines the search tree. Finally, based on the statistics stored at each node, the action with the most visits (or the highest estimated value) at the root is selected as the decision.
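During the selection step, a common criterion, and the one used in the example code below with exploration constant C = 1.4, is the UCB1 score:

score = child.value / child.visits + C * math.sqrt(math.log(parent.visits) / child.visits)

The first term favors children that have performed well so far (exploitation), while the second term favors children that have been tried relatively rarely (exploration).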
2. Extending MCTS to a General-Purpose AI Agent
The following example shows a general-purpose AI agent based on Monte Carlo Tree Search (MCTS). It uses the game of Tic-Tac-Toe as the environment to demonstrate the MCTS workflow, covering the environment definition, the MCTS algorithm implementation, and the agent's decision logic.
Example 5-4: A general-purpose AI agent based on MCTS (source path: codes\5\Meng.py)
The implementation code of the example file Meng.py is shown below.
import math
import random
from copy import deepcopy
import numpy as np

# ---------------- Environment definition ---------------- #
class TicTacToeEnv:
def __init__(self):
        self.board = np.zeros((3, 3), dtype=int)  # 0: empty, 1: player 1 (AI), -1: player 2
        self.current_player = 1  # player to move (1 or -1)
self.winner = None
self.done = False
def reset(self):
self.board = np.zeros((3, 3), dtype=int)
self.current_player = 1
self.winner = None
self.done = False
return self.board.copy()
def step(self, action):
row, col = action
if self.board[row, col] != 0:
            return self.board.copy(), -1, True, {}  # penalize an illegal move
self.board[row, col] = self.current_player
self.check_win()
reward = self.get_reward()
info = {}
if self.done:
info['winner'] = self.winner
else:
            self.current_player *= -1  # switch to the other player
return self.board.copy(), reward, self.done, info
def check_win(self):
player = self.current_player
        # Check rows
for row in range(3):
if np.all(self.board[row, :] == player):
self.winner = player
self.done = True
return
        # Check columns
for col in range(3):
if np.all(self.board[:, col] == player):
self.winner = player
self.done = True
return
        # Check the two diagonals
if (np.all(np.diag(self.board) == player) or
np.all(np.diag(np.fliplr(self.board)) == player)):
self.winner = player
self.done = True
return
        # Check for a draw (board full with no winner)
if np.all(self.board != 0):
self.done = True
self.winner = 0
    def get_reward(self):
        # The reward is scored from the perspective of self.current_player,
        # i.e. the player who has just placed a mark.
if self.winner == self.current_player:
return 1
elif self.winner == 0:
return 0.5
elif self.winner == -self.current_player:
return -1
else:
return 0
def get_valid_actions(self):
return [(row, col) for row in range(3) for col in range(3) if self.board[row, col] == 0]
def render(self):
symbols = {0: ' ', 1: 'X', -1: 'O'}
for row in range(3):
print("|".join([symbols[cell] for cell in self.board[row]]))
if row != 2:
print("-----")
# ---------------- MCTS core implementation ---------------- #
class MCTSNode:
def __init__(self, parent, action, state):
self.parent = parent
self.children = {}
self.action = action
self.state = state
        self.visits = 0    # number of times this node has been visited
        self.value = 0.0   # accumulated reward backed up through this node
def is_leaf(self):
return len(self.children) == 0
def is_root(self):
return self.parent is None
    def select_child(self):
        # Pick the child with the highest UCB1 score.
        C = 1.4  # exploration constant
        best_score = -float("inf")
        best_child = None
        for child in self.children.values():
            if child.visits == 0:
                return child  # always try unvisited children first
            # UCB1: average value (exploitation) plus an exploration bonus
            score = (child.value / child.visits) + C * math.sqrt(math.log(self.visits) / child.visits)
            if score > best_score:
                best_score = score
                best_child = child
        return best_child
    def expand(self, env):
        # Create one child node for every legal action from this state.
        # `env` is expected to already match self.state.
        valid_actions = env.get_valid_actions()
        for action in valid_actions:
            if action not in self.children:
                new_env = deepcopy(env)
                new_env.step(action)  # apply the move for whichever player is to move
                new_node = MCTSNode(parent=self, action=action, state=new_env.board.copy())
                self.children[action] = new_node
        return self.children.values()
    def backpropagate(self, reward):
        # Update statistics along the path back to the root. This simplified
        # variant backs up the reward unchanged; a full two-player implementation
        # would flip its sign at each level.
        self.visits += 1
        self.value += reward
        if self.parent:
            self.parent.backpropagate(reward)
class MCTS:
    def __init__(self, env, simulations=1000):
        self.env = env
        self.simulations = simulations
    def search(self, initial_state):
        root = MCTSNode(parent=None, action=None, state=initial_state)
        root_player = self.env.current_player  # the player MCTS is planning for
        for _ in range(self.simulations):
            node = root
            # Start each iteration from a fresh copy of the real environment.
            current_env = deepcopy(self.env)
            current_env.board = node.state.copy()
            # 1. Selection: descend the tree, replaying each chosen action.
            while not node.is_leaf():
                node = node.select_child()
                _, _, done, _ = current_env.step(node.action)
                if done:
                    break
            # 2. Expansion: add children for every legal move at this leaf.
            if node.is_leaf() and not current_env.done:
                node.expand(current_env)
            # 3. Simulation: play random moves until the game ends.
            current_env_sim = deepcopy(current_env)
            while not current_env_sim.done:
                action = random.choice(current_env_sim.get_valid_actions())
                current_env_sim.step(action)
            # Score the rollout from the root player's point of view.
            if current_env_sim.winner == root_player:
                reward = 1.0
            elif current_env_sim.winner == 0:
                reward = 0.5
            else:
                reward = -1.0
            # 4. Backpropagation: update statistics along the visited path.
            node.backpropagate(reward)
        # The most visited child of the root is the chosen move.
        best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
        return best_action
# ---------------- AI Agent ---------------- #
class MCTSAgent:
def __init__(self, env):
self.env = env
self.mcts = MCTS(env=env, simulations=1000)
def act(self, state):
return self.mcts.search(state)
# ---------------- Main program ---------------- #
def main():
env = TicTacToeEnv()
agent = MCTSAgent(env)
state = env.reset()
done = False
while not done:
action = agent.act(state)
print(f"AI选择动作: {action}")
next_state, reward, done, info = env.step(action)
state = next_state
env.render()
if done:
            if info.get('winner') == 1:
                print("AI wins!")
            elif info.get('winner') == -1:
                print("Player 2 wins!")
            else:
                print("Draw!")
break
if __name__ == "__main__":
main()
The code above is explained as follows:
(1) Environment (TicTacToeEnv): maintains the 3×3 board, applies a move and switches the current player in step(), detects wins and draws in check_win(), and exposes the remaining legal moves through get_valid_actions().
(2) MCTS core workflow: MCTSNode stores the visit count and accumulated value of each state; select_child() performs UCB1-based selection, expand() adds one child per legal move, and backpropagate() updates the statistics along the path back to the root. MCTS.search() repeats selection, expansion, random simulation, and backpropagation for a fixed number of iterations and finally returns the most visited root action.
(3) Agent decision: MCTSAgent.act() runs an MCTS search from the current board state and returns the recommended action.
This example clearly shows how MCTS is used inside a general-purpose AI agent: interacting with the environment, building the tree structure, and making statistics-based decisions. Because the rollouts are random, the exact moves can vary between runs; a typical run produces output such as:
AI chooses action: (2, 2)
| |
-----
| |
-----
| |X
AI chooses action: (0, 0)
O| |
-----
| |
-----
| |X
AI chooses action: (1, 2)
O| |
-----
| |X
-----
| |X
AI chooses action: (0, 1)
O|O|
-----
| |X
-----
| |X
AI chooses action: (0, 2)
O|O|X
-----
| |X
-----
| |X
AI wins!
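As a further illustration, the classes defined above can also be used to ask the agent for a single move from an arbitrary mid-game position instead of playing out a whole game. The snippet below is a minimal, hypothetical usage sketch (it is not part of Meng.py) and assumes the definitions above are available in the same file:

# Hypothetical usage sketch: query the agent for one move from a given position.
env = TicTacToeEnv()
env.board[1, 1] = -1                 # suppose player 2 already occupies the centre
agent = MCTSAgent(env)               # it is player 1's turn in `env`
move = agent.act(env.board.copy())   # run MCTS and return the recommended move
print("Recommended move for player 1:", move)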
In summary, MCTS is a powerful decision-making algorithm that is widely used in board games, autonomous driving, robot control, automated planning, and large language models. By combining random simulation with tree search, MCTS can find good action sequences in complex decision spaces. Although it can require substantial computation, its effectiveness and adaptability on complex problems make it a popular choice in many fields.