Implementing Tic-Tac-Toe with Simple Reinforcement Learning

1. Introduction to Reinforcement Learning

Reinforcement learning can be understood as a process in which an Agent interacts with an Environment, learning and improving along the way. In Tic-Tac-Toe, one side can simply be regarded as the Agent and the other as the Environment. The interaction involves the following four elements:

  • State: a possible situation or position; in Tic-Tac-Toe, the arrangement of marks on the board together with whose turn it is.
  • Action: the transition from one state to another; in Tic-Tac-Toe, the next move to play.
  • Value: a measure of how good each state is; in Tic-Tac-Toe, the chance of winning from the current position. The mapping from states to values is called the value function.
  • Reward: the change in value brought about by an action; in Tic-Tac-Toe, how good a given move is.

Reinforcement learning essentially mimics how humans learn. Imagine a child learning Tic-Tac-Toe. At first he cannot play at all; he only knows that winning positions have the highest value and losing positions the lowest, so he can only play random moves (explore). After a few random games, he learns that the moves leading up to a win have relatively high value, while those leading up to a loss have relatively low value; through exploration he forms a new estimate of each state's value. Of course, he also uses the knowledge he already has (exploit): before each move, he considers all positions (states) the board could reach after his move, then picks the one most favorable to himself (the max value). Through exploitation, he uses his own knowledge to maximize his chance of winning.
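As a minimal sketch of this explore/exploit choice from the cross (maximizing) player's perspective (the function and parameter names here are illustrative, not part of the implementation in Section 3; unseen states default to a value of 0.5, as in the code below):

import random

# Hypothetical sketch of an epsilon-greedy move choice for the maximizing player.
# legal_moves is a list of empty-square indices; next_state_fn(state, m) returns
# the state reached by playing move m.
def choose_move(state, value_table, legal_moves, next_state_fn, epsilon=0.2):
    if random.random() < epsilon:
        # explore: play a random empty square
        return random.choice(legal_moves)
    # exploit: play the move whose resulting state has the highest value
    return max(legal_moves,
               key=lambda m: value_table.get(str(next_state_fn(state, m)), 0.5))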

2. The Tic-Tac-Toe Algorithm

2.1 Training

The training procedure for reinforcement-learning Tic-Tac-Toe is as follows:

repeat for epochs games:
    while true:
        if the game has ended in a win or a draw:
            return the result, break
        randomly choose between explore and exploit
        if explore was chosen:
            play a random empty square
        else (exploit was chosen):
            look up value_table and play the move leading to the state with the best value for the side to move
            update the value of the current state in value_table from the value of the new state

Here, “update the value of the current state in value_table from the value of the new state” is the crucial part: it determines the learning rule, i.e. whether the agent can actually learn. Since the state logic of Tic-Tac-Toe is very simple, the following simple update expression suffices:

$$V(S) \leftarrow V(S) + \alpha \left( V(S') - V(S) \right)$$

where $V$ denotes the value function, $S$ the current state, $S'$ the new state, $V(S)$ the value of $S$, and $\alpha$ the learning rate, a tunable hyperparameter.

The other parameters to control are the number of training games epochs and the probability $\epsilon$ of choosing explore.
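As a minimal sketch of this update rule (update_value is an illustrative helper, not part of the implementation in Section 3; states are keyed by their string form and unseen states default to 0.5, matching the code below):

def update_value(value_table, state_key, next_state_key, alpha):
    # V(S) <- V(S) + alpha * (V(S') - V(S))
    v_s = value_table.get(state_key, 0.5)
    v_next = value_table.get(next_state_key, 0.5)
    value_table[state_key] = v_s + alpha * (v_next - v_s)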

2.2 Testing

Testing consists of having the two sides repeatedly play each other according to the trained value table, counting one side's wins, losses, and draws. Since Tic-Tac-Toe has a strategy that never loses, a good model should be able to guarantee a draw. The testing procedure is as follows:

repeat for epochs games:
    while true:
        if the game has ended in a win or a draw:
            record the result, break
        else:
            from value_table, play the move leading to the state with the best value for the side to move
output the statistics

3. Implementing the Tic-Tac-Toe Training

# python 3.5
import math
import random

# 1 for cross, -1 for circle, 0 for empty

class State:
    def __init__(self, board, turn=1):
        self.board = board
        self.turn = turn

    def __str__(self):
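        # serialized board plus turn; used as the key into value_table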
        return str(self.board) + str(self.turn)

    def __eq__(self, other):
        return self.board == other.board and self.turn == other.turn


class Model:
    CROSS_WIN = 1
    CIRCLE_WIN = -1
    DRAW = 0
    NOT_FINISHED = 2333

    def __init__(self, learning_rate, explore_rate, training_epoch=1000):
        self.value_table = {}
        self.learning_rate = learning_rate
        self.explore_rate = explore_rate
        self.training_epoch = training_epoch

    @staticmethod
    def result(state):
        board = state.board
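        # each winning line (row, column, or diagonal) sums to +3 for cross
        # or -3 for circle; a full board with no winning line is a draw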
        if (board[0] + board[3] + board[6] == 3
                or board[1] + board[4] + board[7] == 3
                or board[2] + board[5] + board[8] == 3
                or board[0] + board[4] + board[8] == 3
                or board[2] + board[4] + board[6] == 3):
            return Model.CROSS_WIN

        if (board[0] + board[3] + board[6] == -3
                or board[1] + board[4] + board[7] == -3
                or board[2] + board[5] + board[8] == -3
                or board[0] + board[4] + board[8] == -3
                or board[2] + board[4] + board[6] == -3):
            return Model.CIRCLE_WIN

        if sum(map(abs, board)) == 9:
            return Model.DRAW

        return Model.NOT_FINISHED

    def get_next_states(self, state):
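        # indices of the empty squares, i.e. the legal moves from this state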
        next_state_ids = []
        for i in range(len(state.board)):
            if state.board[i] == 0:
                next_state_ids.append(i)
        return next_state_ids

    def next_state(self, state, i):
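        # copy the board, place the mover's mark at square i, and pass the turn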
        board = state.board[:]
        board[i] = state.turn
        return State(board, -state.turn)

    def explore(self, state):
        # get all states
        next_state_ids = self.get_next_states(state)
        if len(next_state_ids) == 0:
            return -1  # no legal moves

        # if state not in value table, set an initial value
        for i in next_state_ids:
            next_state = self.next_state(state, i)

            key = str(next_state)
            if key not in self.value_table:
                self.value_table[key] = 0.5

        # select next state randomly
        return random.choice(next_state_ids)

    def exploit(self, state):
        # select the state with highest (or lowest) value
        next_state_ids = self.get_next_states(state)
        if len(next_state_ids) == 0:
            return -1, -1

        if state.turn == 1:
            # cross turn, select the highest
            next_step = -1
            value = -math.inf
            for i in next_state_ids:
                next_state = self.next_state(state, i)
                key = str(next_state)

                # select the highest value
                if key in self.value_table:
                    if self.value_table[key] > value:
                        value = self.value_table[key]
                        next_step = i
                # unseen states get the default value 0.5 and compete like any other
                else:
                    self.value_table[key] = 0.5     # set initial value of new state
                    if 0.5 > value:
                        value = 0.5
                        next_step = i
            return next_step, value
        elif state.turn == -1:
            # circle turn, select lowest
            next_step = -1
            value = math.inf
            for i in next_state_ids:
                next_state = self.next_state(state, i)
                key = str(next_state)

                # select the lowest value
                if key in self.value_table:
                    if self.value_table[key] < value:
                        value = self.value_table[key]
                        next_step = i
                # unseen states get the default value 0.5 and compete like any other
                else:
                    self.value_table[key] = 0.5  # set initial value of new state
                    if 0.5 < value:
                        value = 0.5
                        next_step = i
            return next_step, value

    def train(self):
        for i in range(self.training_epoch):
            # get initial state, circle first
            board = [0 for _ in range(9)]
            board[random.randint(0, 8)] = -1
            state = State(board, 1)
            self.value_table[str(state)] = 0.5

            print("Train game %d: " % i, end="")

            # play one game
            while True:
                if self.result(state) == Model.CROSS_WIN:
                    self.value_table[str(state)] = 1
                    print("cross win", end=" ")
                    break
                elif self.result(state) == Model.CIRCLE_WIN:
                    self.value_table[str(state)] = 0
                    print("circle win", end=" ")
                    break
                elif self.result(state) == Model.DRAW:
                    self.value_table[str(state)] = 0.5
                    print("draw", end=" ")
                    break
                else:
                    print(str(state.board), end=" ")
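                    # with probability explore_rate, explore (no value update);
                    # otherwise exploit and back up the next state's value into the current state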
                    if random.uniform(0, 1) < self.explore_rate:
                        next_step = self.explore(state)
                        if next_step == -1:
                            break
                        state = self.next_state(state, next_step)
                    else:
                        next_step, value = self.exploit(state)
                        if next_step == -1:
                            break

                        self.value_table[str(state)] += \
                            self.learning_rate * (value - self.value_table[str(state)])
                        state = self.next_state(state, next_step)
            print("")

    def test(self, test_epochs=10000):
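        # pure exploitation: both sides follow the learned value table greedily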
        cnt_cross, cnt_circle, cnt_draw = 0, 0, 0
        for i in range(test_epochs):
            # get initial state, circle first
            board = [0 for _ in range(9)]
            board[random.randint(0, 8)] = -1
            state = State(board, 1)
            self.value_table[str(state)] = 0.5

            print("Test game %d: " % i, end="")

            while True:
                if self.result(state) == Model.CROSS_WIN:
                    self.value_table[str(state)] = 1
                    print("cross win", end=" ")
                    cnt_cross += 1
                    break
                elif self.result(state) == Model.CIRCLE_WIN:
                    self.value_table[str(state)] = 0
                    print("circle win", end=" ")
                    cnt_circle += 1
                    break
                elif self.result(state) == Model.DRAW:
                    self.value_table[str(state)] = 0.5
                    print("draw", end=" ")
                    cnt_draw += 1
                    break
                else:
                    print(str(state.board), end=" ")
                    next_step, value = self.exploit(state)
                    if next_step == -1:
                        break
                    state = self.next_state(state, next_step)
            print("")
        print("Cross win %d, Circle win %d, Draw %d" % (cnt_cross, cnt_circle, cnt_draw))
        print(repr(self.value_table))


model = Model(learning_rate=0.01, explore_rate=0.2, training_epoch=1000)
model.train()
print("Start Testing ...")
model.test(10000)

4. Testing the Code

With the learning rate set to 0.01, the explore rate to 0.2, and 1000 training games, all 10000 subsequent test games ended in a draw:

Cross win 0, Circle win 0, Draw 10000

This means the trained model has learned to play Tic-Tac-Toe properly; the test passes.
