最近有点烦啊,也有点无聊,去年研究德州扑克失败,后面知道AlphaZero都用了蒙特卡洛搜索树,那估计俺方向错误了?如是准备学习下这个东东,为深度学习攻克德州扑克做技术准备工作。这个东东理论上的介绍网络上实在是太多了,大部分也没有什么问题。但没有代码的实现的东西,感觉不是踏实,不靠谱。我想用什么方法来验证下我是否真正理解了这个东西了,那就做一个地球人都知道的五子棋来验证我的对这个算法的理解吧!
我从网上把理论介绍最最关键的部分摘录如下(有我自己的修改)
对于MCTS的树结构,如果是最简单的方法,只需要在节点上保存状态对应的历史胜负记录。在每条边上保存采样的动作。这样MCTS的搜索需要走4步,如下图(图来自维基百科):
第一步是选择(Selection):这一步会从根节点开始,每次都选一个“最值得搜索的子节点”,一般使用UCT选择分数最高的节点,直到来到一个“存在未扩展的子节点”的节点,如图中的 3/3 节点。之所以叫做“存在未扩展的子节点”,是因为这个局面存在未走过的后续着法,也就是MCTS中没有后续的动作可以参考了。这时我们进入第二步。
本人标注:选择过程中,因为如果第一层是对方,那么第二个层就是自己,接下来第三层就是对方,在选择的过程都是选择比值最大的点(假设对手也是很厉害啊)
第二步是扩展(Expansion),在这个搜索到的存在未扩展的子节点,加上一个0/0的子节点,表示没有历史记录参考。这时我们进入第三步。
本人标注:扩展的过程中要记得你动作不能重复(对五子棋就是不能再下已经有落点的地方),我是一次扩展完毕。
第三步是仿真(simulation),从上面这个没有试过的着法开始,用一个简单策略比如快速走子策略(Rollout policy)走到底,得到一个胜负结果。快速走子策略一般适合选择走子很快可能不是很精确的策略。因为如果这个策略走得慢,结果虽然会更准确,但由于耗时多了,在单位时间内的模拟次数就少了,所以不一定会棋力更强,有可能会更弱。这也是为什么我们一般只模拟一次,因为如果模拟多次,虽然更准确,但更慢。
本人标注:模拟过程,完全是随机走,走到底。
第四步是回溯(backpropagation), 将我们最后得到的胜负结果回溯加到MCTS树结构上。注意除了之前的MCTS树要回溯外,新加入的节点也要加上一次胜负历史记录,如上图最右边所示。
本人标注:赢了加1,输掉子增加访问节点数量
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Windows.Forms;
namespace WindowsFormsApp1
{
public partial class Form1 : Form
{
public AutoResetEvent waitEvent = new AutoResetEvent(false);
const int Row =8;
const int Col =8;
short[,] m_board = new short[Row, Col];
short[,] m_boardSearch = new short[Row, Col];
public class Pos
{
public short x;
public short y;
public Button btn=null;
public Pos(short ax, short ay)
{
x = ax;
y = ay;
}
}
class TreeNode
{
public TreeNode parentNode;
public List m_ChildNodes=new List();
public double m_win_count = 0;
public double m_visit_count = 0;
public Pos m_pos;
public bool IsAI;
}
public Form1()
{
InitializeComponent();
CreateBoardUI();
}
Dictionary xy_pos_dic = new Dictionary();
void CreateBoardUI()
{
int x_s=100, y_s=100;
for(byte i=0;i= 5)
return true;// "—"
}
else
h = 0;
}
for (int m = i + 1; m < Row; m++)
{
// V方向|
if (board[m, j] == who)
{
v += 1;
if (v >= 5)
return true;//, "|"
}
else
v = 0;
}
}
}
Dictionary kb = new Dictionary();
for (int i = -Col; i < Col; i++)// '/'
kb.Add(i, -i);
foreach (var k in kb.Keys)
{//one line
p = 0;
for (int x = 0; x < Col; x++)
{
int y =x + kb[k];
if (x >= 0 && x < Col
&& y >= 0 && y < Row)
{
if (board[x, y] == who)
{
++p;
if (p >= 5)
return true;
}
else
p = 0;
}
}
}
kb.Clear();
for (int i =0; i < Col*2; i++)// '\'
kb.Add(-i, i);
foreach (var k in kb.Keys)
{//one line
l = 0;
for (int x = 0; x < Col; x++)
{
int y = -x + kb[k];
if (x >= 0 && x < Col
&& y >= 0 && y < Row)
{
if (board[x, y] == who)
{
++l;
if (l >= 5)
return true;
}
else
l = 0;
}
}
}
return false;
}
TreeNode ExpandNodeOld(TreeNode node)//扩展,模拟
{
for (int i = 0; i <= 80; ++i)//数量如何控制
{
TreeNode oneNode = new TreeNode();
oneNode.m_pos = GetAnEmptyPos();
if (oneNode.m_pos.x == -1)
break;
m_boardSearch[oneNode.m_pos.x, oneNode.m_pos.y] = 1;
oneNode.parentNode = node;
if (node.m_ChildNodes == null)
node.m_ChildNodes = new List();
node.m_ChildNodes.Add(oneNode);
}
return node.m_ChildNodes.Count > 0 ? node.m_ChildNodes[0] : null;
}
TreeNode ExpandNode(TreeNode node)//扩展,模拟
{
for (short i = 0; i < Row; ++i)
for (short j = 0; j < Col;++j)
if (m_boardSearch[i, j] == 0)
{
TreeNode oneNode = new TreeNode();
oneNode.IsAI = !node.IsAI;
oneNode.m_pos = xy_pos_dic[i + "-" + j];
m_boardSearch[i,j] = 1;
oneNode.parentNode = node;
node.m_ChildNodes.Add(oneNode);
}
return node.m_ChildNodes.Count > 0 ? node.m_ChildNodes[0] : null;
}
Random m_rnd = new Random();
Pos GetAnEmptyPos(int x1, int x2, int y1, int y2)
{
short x = -1, y = -1;
bool cando = false;
for (int i = x1; i < x2; i++)
{
for (int j = y1; j < y2; j++)
{
if (m_boardSearch[i, j] == 0)
{
cando = true;
break;
}
}
if (cando)
break;
}
if(!cando)
return new Pos(-1, -1);
while (true)
{
var res = from m in m_boardSearch.Cast() where m == 0 select m;
if (res.Count() == 0)
{
return new Pos(-1, -1);
}
x = (short)m_rnd.Next(x1, x2);
y = (short)m_rnd.Next(y1, y2);
if (m_boardSearch[x, y] == 0)
break;
}
return xy_pos_dic[x + "-" + y];
}
Pos GetAnEmptyPos()
{
short x = -1, y = -1;
var res = from m in m_boardSearch.Cast() where m == 0 select m;
if (res.Count() == 0)
{
return new Pos(-1, -1);
}
Random m_rnd = new Random();
while (true)
{
x = (short)m_rnd.Next(0, Row);
y = (short)m_rnd.Next(0, Col);
if (m_boardSearch[x, y] == 0)
break;
}
return xy_pos_dic[x + "-" + y];
}
TreeNode mcts_select(bool isMe,TreeNode pNode)//选择少叶子节点再扩展
{
if(pNode.m_ChildNodes.Count==0)
{
return pNode;
}
double max_score = 0;
bool isFirst = true;
TreeNode bestNode = null;
List canSelectList = new List();
foreach (var node in pNode.m_ChildNodes)
{//探索最优点
if(node.m_visit_count==0)//优先选择未采用过的节点
{
bestNode = node;
canSelectList.Clear();
canSelectList.Add(node);
break;
}
double score = node.m_win_count / node.m_visit_count + 1.414 * Math.Sqrt(Math.Log(pNode.m_visit_count) / node.m_visit_count);
score = Math.Round(score, 4);
// if (isMe)
{
if (max_score < score || isFirst)
{
max_score = score;
bestNode = node;
canSelectList.Clear();
canSelectList.Add(bestNode);
isFirst = false;
}
else if (max_score == score)
{
canSelectList.Add(node);
}
}
//else
//{
// if (max_score > score || isFirst)
// {
// max_score = score;
// canSelectList.Clear();
// bestNode = node;
// canSelectList.Add(bestNode);
// isFirst = false;
// }
// else if (max_score == score)
// canSelectList.Add(node);
//}
}
bestNode = canSelectList[m_rnd.Next(0, canSelectList.Count)];//这里导致expand失败。
m_boardSearch[bestNode.m_pos.x, bestNode.m_pos.y] = isMe ?(short) 1 : (short)2;
return mcts_select(!isMe, bestNode);
}
void GotoNext(TreeNode rootNode)
{
TreeNode bestNode=null;
foreach(var node in rootNode.m_ChildNodes)
{
if (bestNode == null)
bestNode = node;
else if ( bestNode.m_visit_count < node.m_visit_count)
// else if( bestNode.m_win_count/ bestNode.m_visit_count < node.m_win_count/ node.m_visit_count)
bestNode = node;
}
if (bestNode.m_pos.x <= 1 || bestNode.m_pos.x >= 7)
Console.WriteLine("aaaa");
if (bestNode.m_pos.y <= 1 || bestNode.m_pos.y >= 7)
Console.WriteLine("bbbb");
m_board[bestNode.m_pos.x, bestNode.m_pos.y] = 1;
UpdateBoardUI(bestNode.m_pos);
}
private void FillBoard(bool isMe, TreeNode leafNode)
{
m_boardSearch[leafNode.m_pos.x, leafNode.m_pos.y] = isMe ? (short)1 : (short)2;
if (leafNode.parentNode != null
&& leafNode.parentNode.m_pos!=null)
FillBoard(!isMe, leafNode.parentNode);
}
//private void FillBoard(bool isMe, TreeNode leafNode)
//{
// if (leafNode.parentNode == null)
// return;
// m_boardSearch[leafNode.m_pos.x, leafNode.m_pos.y] = isMe ? (short)1 : (short)2;
// FillBoard(!isMe, leafNode.parentNode);
//}
private void StartBtn_Click(object sender, EventArgs e)
{
for (int i = 0; i < Row; ++i)
for (int j = 0; j < Col; ++j)
m_board[i, j] = 0;
foreach (var one in xy_pos_dic.Values)
one.btn.Text = "";
Thread th = new Thread(this.AIMainThread);
th.IsBackground = true;
th.Start();
}
private bool m_AI_thinking = false;
private void AIMainThread()
{
waitEvent.Reset();//后手
while (true)
{
waitEvent.WaitOne();
if (GameOver(m_board, 2))
{
UpdateMsg("恭喜,你赢啦!");
break;
}
m_AI_thinking = true;
UpdateMsg("AI思考中...");
TreeNode rootNode = new TreeNode();
rootNode.m_visit_count = 1;
rootNode.IsAI = false;
int count = 9000;// 1500+9655;//
for (int i = 0; i < Row; ++i)
for (int j = 0; j < Col; ++j)
m_boardSearch[i, j] = m_board[i, j];
DateTime dtStart = DateTime.Now;
// while (--count > 0)//模拟次数
//while((--count > 0))//||
while((DateTime.Now- dtStart).TotalMinutes<=0.33)
{
// Console.WriteLine("\n1 " + DateTime.Now + ":" + DateTime.Now.Millisecond);
for (int i = 0; i < Row; ++i)
for (int j = 0; j < Col; ++j)
m_boardSearch[i, j] = m_board[i, j];
TreeNode leafNode = mcts_select(true,rootNode);//在m_boardSearch copy m_board的基础上模拟,一条直线下来。
// Console.WriteLine("\n2 " + DateTime.Now + ":" + DateTime.Now.Millisecond);
TreeNode expNode = leafNode.m_ChildNodes.Count == 0&& leafNode.m_visit_count>0 ? ExpandNodeOld(leafNode) : leafNode;
if (expNode == null)
break;
for (int i = 0; i < Row; ++i)
for (int j = 0; j < Col; ++j)
m_boardSearch[i, j] = m_board[i, j];
// Console.WriteLine("\n3 " + DateTime.Now + ":" + DateTime.Now.Millisecond);
FillBoard(true, expNode);//回填下棋步骤(直线回传)
// Console.WriteLine("\n4 " + DateTime.Now + ":" + DateTime.Now.Millisecond);
int res = StartSimulate(expNode);
// Console.WriteLine("\n5 " + DateTime.Now + ":" + DateTime.Now.Millisecond);
BackUp(res == 1, res == 0, expNode);
// Console.WriteLine("\n6 " + DateTime.Now + ":" + DateTime.Now.Millisecond);
}
GotoNext(rootNode);//真正选择最佳策略.
if (GameOver(m_board, 1))
{
UpdateMsg("AI赢啦!");
break;
}
if (GameOver(m_board, 2))
{
UpdateMsg("恭喜,你赢啦!");
break;
}
UpdateMsg("请开始选择...");
m_AI_thinking = false;
}
m_AI_thinking = false;
}
public void UpdateMsg(string str)
{
this.BeginInvoke(new Action(() =>
{
SYlabel.Text = str;
}));
}
public void UpdateBoardUI(Pos p)
{
this.BeginInvoke(new Action(() =>
{
p.btn.Text = "1";
p.btn.ForeColor = Color.Black;
p.btn.Focus();
}));
}
}
}
可以看到在8*8 ,10*10 ,搜索时间在20秒内,AI基本能下棋正确!