一个基于蒙特卡洛搜索树的五子棋实现

          最近有点烦啊,也有点无聊,去年研究德州扑克失败,后面知道AlphaZero都用了蒙特卡洛搜索树,那估计俺方向错误了?如是准备学习下这个东东,为深度学习攻克德州扑克做技术准备工作。这个东东理论上的介绍网络上实在是太多了,大部分也没有什么问题。但没有代码的实现的东西,感觉不是踏实,不靠谱。我想用什么方法来验证下我是否真正理解了这个东西了,那就做一个地球人都知道的五子棋来验证我的对这个算法的理解吧!

                我从网上把理论介绍最最关键的部分摘录如下(有我自己的修改)

    对于MCTS的树结构,如果是最简单的方法,只需要在节点上保存状态对应的历史胜负记录。在每条边上保存采样的动作。这样MCTS的搜索需要走4步,如下图(图来自维基百科):

    第一步是选择(Selection):这一步会从根节点开始,每次都选一个“最值得搜索的子节点”,一般使用UCT选择分数最高的节点,直到来到一个“存在未扩展的子节点”的节点,如图中的 3/3 节点。之所以叫做“存在未扩展的子节点”,是因为这个局面存在未走过的后续着法,也就是MCTS中没有后续的动作可以参考了。这时我们进入第二步。

          本人标注:选择过程中,因为如果第一层是对方,那么第二个层就是自己,接下来第三层就是对方,在选择的过程都是选择比值最大的点(假设对手也是很厉害啊)

    第二步是扩展(Expansion),在这个搜索到的存在未扩展的子节点,加上一个0/0的子节点,表示没有历史记录参考。这时我们进入第三步。

      本人标注:扩展的过程中要记得你动作不能重复(对五子棋就是不能再下已经有落点的地方),我是一次扩展完毕。

    第三步是仿真(simulation),从上面这个没有试过的着法开始,用一个简单策略比如快速走子策略(Rollout policy)走到底,得到一个胜负结果。快速走子策略一般适合选择走子很快可能不是很精确的策略。因为如果这个策略走得慢,结果虽然会更准确,但由于耗时多了,在单位时间内的模拟次数就少了,所以不一定会棋力更强,有可能会更弱。这也是为什么我们一般只模拟一次,因为如果模拟多次,虽然更准确,但更慢。

     本人标注:模拟过程,完全是随机走,走到底。

    第四步是回溯(backpropagation), 将我们最后得到的胜负结果回溯加到MCTS树结构上。注意除了之前的MCTS树要回溯外,新加入的节点也要加上一次胜负历史记录,如上图最右边所示。

      本人标注:赢了加1,输掉子增加访问节点数量

 

本来想用Python来实现,但考虑到用界面C# WinForm方便,就放弃了。到时候项目进展的时候再改成Python。代码如下,如果大家有问题欢迎指正,交流!

 

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Windows.Forms;

namespace WindowsFormsApp1
{
    public partial class Form1 : Form
    {

        public AutoResetEvent waitEvent = new AutoResetEvent(false);
        const int Row =8;
        const int Col =8;
        short[,] m_board = new short[Row, Col];
        short[,] m_boardSearch = new short[Row, Col];

        public class Pos
        {
            public short x;
            public short y;
            public Button btn=null;
            public Pos(short ax, short ay)
            {
                x = ax;
                y = ay;
            }
            
        }
        class TreeNode
        {
            public TreeNode parentNode;
            public List m_ChildNodes=new List();
            public double m_win_count = 0;
            public double m_visit_count = 0;
            public Pos m_pos;
            public bool IsAI;
        }
        public Form1()
        {
            InitializeComponent();
            CreateBoardUI();
        }

        Dictionary xy_pos_dic = new Dictionary();
        void CreateBoardUI()
        {
            int x_s=100, y_s=100;
            for(byte i=0;i= 5)
                                    return true;// "—"
                            }
                            else
                                h = 0;
                        }

                        for (int m = i + 1; m < Row; m++)
                        {
                            // V方向|
                            if (board[m, j] == who)
                            {
                                v += 1;
                                if (v >= 5)
                                    return true;//, "|"
                            }
                            else
                                v = 0;
                        }                       
                    }
                }

            Dictionary kb = new Dictionary();

            for (int i = -Col; i < Col; i++)// '/'
                kb.Add(i, -i);
            foreach (var k in kb.Keys)
            {//one line 
                p = 0;
                for (int x = 0; x < Col; x++)
                {
                  
                    int y =x + kb[k];

                    if (x >= 0 && x < Col
                        && y >= 0 && y < Row)
                    {
                        if (board[x, y] == who)
                        {
                            ++p;
                            if (p >= 5)
                                return true;
                        }
                        else
                            p = 0;
                    }
                }
            }
            kb.Clear();
            for (int i =0; i < Col*2; i++)// '\'
                kb.Add(-i, i);
            foreach (var k in kb.Keys)
            {//one line 
                l = 0;
                for (int x = 0; x < Col; x++)
                {
                    int y = -x + kb[k];

                    if (x >= 0 && x < Col
                        && y >= 0 && y < Row)
                    {
                        if (board[x, y] == who)
                        {
                            ++l;
                            if (l >= 5)
                                return true;
                        }
                        else
                            l = 0;
                    }
                }
            }
            return false;
        }

       
        TreeNode ExpandNodeOld(TreeNode node)//扩展,模拟
        {
            

            for (int i = 0; i <= 80; ++i)//数量如何控制
            {
                TreeNode oneNode = new TreeNode();
                oneNode.m_pos = GetAnEmptyPos();
                if (oneNode.m_pos.x == -1)
                    break;
                m_boardSearch[oneNode.m_pos.x, oneNode.m_pos.y] = 1;
                oneNode.parentNode = node;
                if (node.m_ChildNodes == null)
                    node.m_ChildNodes = new List();
                node.m_ChildNodes.Add(oneNode);
            }
            return node.m_ChildNodes.Count > 0 ? node.m_ChildNodes[0] : null;
        }
        TreeNode ExpandNode(TreeNode node)//扩展,模拟
        {
            for (short i = 0; i < Row; ++i)
                for (short j = 0; j < Col;++j)
                    if (m_boardSearch[i, j] == 0)
                    {
                        TreeNode oneNode = new TreeNode();
                        oneNode.IsAI = !node.IsAI;
                        oneNode.m_pos = xy_pos_dic[i + "-" + j];
                        m_boardSearch[i,j] = 1;
                        oneNode.parentNode = node;                       
                        node.m_ChildNodes.Add(oneNode);
                    }
            return node.m_ChildNodes.Count > 0 ? node.m_ChildNodes[0] : null;
        }
        Random m_rnd = new Random();
        Pos GetAnEmptyPos(int x1, int x2, int y1, int y2)
        {
            short x = -1, y = -1;
            bool cando = false;
           
            for (int i = x1; i < x2; i++)
            {
                for (int j = y1; j < y2; j++)
                {
                    if (m_boardSearch[i, j] == 0)
                    {
                        cando = true;
                        break;
                    }
                }
                if (cando)
                    break;
            }

            if(!cando)
                return new Pos(-1, -1);
            while (true)
            {
                var res = from m in m_boardSearch.Cast() where m == 0 select m;
                if (res.Count() == 0)
                {
                    return new Pos(-1, -1);
                }
                x = (short)m_rnd.Next(x1, x2);
                y = (short)m_rnd.Next(y1, y2);
                if (m_boardSearch[x, y] == 0)
                    break;
            }
            return xy_pos_dic[x + "-" + y];

        }
       

        Pos GetAnEmptyPos()
        {           
            short x = -1, y = -1;
            var res = from m in m_boardSearch.Cast() where m == 0 select m;
            if (res.Count() == 0)
            {
                return new Pos(-1, -1);
            }
            Random m_rnd = new Random();
            while (true)
            {               
                x = (short)m_rnd.Next(0, Row);
                y = (short)m_rnd.Next(0, Col);
                if (m_boardSearch[x, y] == 0)
                    break;
            }
            return xy_pos_dic[x + "-" + y];

        }
        TreeNode mcts_select(bool isMe,TreeNode pNode)//选择少叶子节点再扩展
        {          
            if(pNode.m_ChildNodes.Count==0)
            {               
                return pNode;
            }

            double max_score = 0;
            bool isFirst = true;
            TreeNode bestNode = null;
            List canSelectList = new List();
            foreach (var node in pNode.m_ChildNodes)
            {//探索最优点
                if(node.m_visit_count==0)//优先选择未采用过的节点
                {
                    bestNode = node;
                    canSelectList.Clear();
                    canSelectList.Add(node);
                    break;
                }
                double score = node.m_win_count / node.m_visit_count + 1.414 * Math.Sqrt(Math.Log(pNode.m_visit_count) / node.m_visit_count);
                score = Math.Round(score, 4);
               // if (isMe)
                {
                    if (max_score < score || isFirst)
                    {
                        max_score = score;
                        bestNode = node;
                        canSelectList.Clear();
                        canSelectList.Add(bestNode);
                        isFirst = false;
                    }
                    else if (max_score == score)
                    {
                        canSelectList.Add(node);
                    }
                }
                //else
                //{
                //    if (max_score > score || isFirst)
                //    {
                //        max_score = score;
                //        canSelectList.Clear();
                //        bestNode = node;
                //        canSelectList.Add(bestNode);
                //        isFirst = false;
                //    }
                //    else if (max_score == score)
                //        canSelectList.Add(node);
                //}
            }
            bestNode = canSelectList[m_rnd.Next(0, canSelectList.Count)];//这里导致expand失败。

            m_boardSearch[bestNode.m_pos.x, bestNode.m_pos.y] = isMe ?(short) 1 : (short)2;
            return  mcts_select(!isMe, bestNode);
        }
        void GotoNext(TreeNode rootNode)
        {
            TreeNode bestNode=null;
            foreach(var node in rootNode.m_ChildNodes)
            {
                if (bestNode == null)
                    bestNode = node;
                 else if ( bestNode.m_visit_count < node.m_visit_count)
                // else if( bestNode.m_win_count/ bestNode.m_visit_count < node.m_win_count/ node.m_visit_count)
                    bestNode = node;
            }
            if (bestNode.m_pos.x <= 1 || bestNode.m_pos.x >= 7)
                Console.WriteLine("aaaa");
            if (bestNode.m_pos.y <= 1 || bestNode.m_pos.y >= 7)
                Console.WriteLine("bbbb");
            m_board[bestNode.m_pos.x, bestNode.m_pos.y] = 1;
            UpdateBoardUI(bestNode.m_pos);
        }
        private void FillBoard(bool isMe, TreeNode leafNode)
        {
            m_boardSearch[leafNode.m_pos.x, leafNode.m_pos.y] = isMe ? (short)1 : (short)2;
            if (leafNode.parentNode != null
                && leafNode.parentNode.m_pos!=null)
                FillBoard(!isMe, leafNode.parentNode);
        }
        //private void FillBoard(bool isMe, TreeNode leafNode)
        //{
        //    if (leafNode.parentNode == null)
        //        return;
        //    m_boardSearch[leafNode.m_pos.x, leafNode.m_pos.y] = isMe ? (short)1 : (short)2;
        //    FillBoard(!isMe, leafNode.parentNode);
        //}

        private void StartBtn_Click(object sender, EventArgs e)
        {
            for (int i = 0; i < Row; ++i)
                for (int j = 0; j < Col; ++j)
                    m_board[i, j] = 0;
            foreach (var one in xy_pos_dic.Values)
                one.btn.Text = "";
            Thread th = new Thread(this.AIMainThread);
            th.IsBackground = true;
            th.Start();
        }
        private bool m_AI_thinking = false;
        private void AIMainThread()
        {
            waitEvent.Reset();//后手
            while (true)
            {
                waitEvent.WaitOne();
                if (GameOver(m_board, 2))
                {
                    UpdateMsg("恭喜,你赢啦!");
                    break;
                }
                m_AI_thinking = true;
                UpdateMsg("AI思考中...");
                TreeNode rootNode = new TreeNode();
                rootNode.m_visit_count = 1;
                rootNode.IsAI = false;
                int count = 9000;// 1500+9655;//

                for (int i = 0; i < Row; ++i)
                    for (int j = 0; j < Col; ++j)
                        m_boardSearch[i, j] = m_board[i, j];
                
                DateTime dtStart = DateTime.Now;

              //  while (--count > 0)//模拟次数
                //while((--count > 0))//|| 
                while((DateTime.Now- dtStart).TotalMinutes<=0.33)
                {
                  //  Console.WriteLine("\n1    " + DateTime.Now + ":" + DateTime.Now.Millisecond);
                    for (int i = 0; i < Row; ++i)
                        for (int j = 0; j < Col; ++j)
                            m_boardSearch[i, j] = m_board[i, j];

                    TreeNode leafNode = mcts_select(true,rootNode);//在m_boardSearch copy m_board的基础上模拟,一条直线下来。
                  //  Console.WriteLine("\n2    " + DateTime.Now + ":" + DateTime.Now.Millisecond);

                    TreeNode expNode = leafNode.m_ChildNodes.Count == 0&& leafNode.m_visit_count>0 ? ExpandNodeOld(leafNode) : leafNode;
                    if (expNode == null)
                        break;
                    for (int i = 0; i < Row; ++i)
                        for (int j = 0; j < Col; ++j)
                            m_boardSearch[i, j] = m_board[i, j];
                  //  Console.WriteLine("\n3    " + DateTime.Now + ":" + DateTime.Now.Millisecond);
                    FillBoard(true, expNode);//回填下棋步骤(直线回传)
                  //  Console.WriteLine("\n4    " + DateTime.Now + ":" + DateTime.Now.Millisecond);
                    int res = StartSimulate(expNode);
                  //  Console.WriteLine("\n5    " + DateTime.Now + ":" + DateTime.Now.Millisecond);
                    BackUp(res == 1, res == 0, expNode);
                 //   Console.WriteLine("\n6    " + DateTime.Now + ":" + DateTime.Now.Millisecond);
                }
                GotoNext(rootNode);//真正选择最佳策略.
                if (GameOver(m_board, 1))
                {
                    UpdateMsg("AI赢啦!");
                    break;
                }
                if (GameOver(m_board, 2))
                {
                    UpdateMsg("恭喜,你赢啦!");
                    break;
                }
                UpdateMsg("请开始选择...");
                m_AI_thinking = false;
            }
            m_AI_thinking = false;
           
        }

        public void UpdateMsg(string str)
        {
            this.BeginInvoke(new Action(() =>
            {
                SYlabel.Text = str;
            }));
        }
        public void UpdateBoardUI(Pos p)
        {
            this.BeginInvoke(new Action(() =>
            {
                p.btn.Text = "1";
                p.btn.ForeColor = Color.Black;
                p.btn.Focus();
            }));
        }
    }
}

一个基于蒙特卡洛搜索树的五子棋实现_第1张图片

可以看到在8*8 ,10*10 ,搜索时间在20秒内,AI基本能下棋正确!

你可能感兴趣的:(C#,AI,TensorFlow)