The parameters of the policy network are first initialized from human experience via Behavior Cloning.
After the network parameters are randomly initialized, imitation learning is carried out on move sequences from human games, treated as a classification task: each board position is the input and the human move is the label, with cross-entropy as the loss function for the parameter updates.
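As a minimal sketch of one such update, assuming a hypothetical policy-only network `model` that maps encoded 19x19x17 board tensors to 361 move logits (the `states` and `moves` tensors are placeholders, not data from this post):

import tensorflow as tf

def behavior_cloning_step(model, optimizer, states, moves):
    # states: batch of encoded board tensors; moves: human moves as flat indices in [0, 361)
    with tf.GradientTape() as tape:
        logits = model(states, training=True)
        # cross-entropy between the human move (label) and the predicted move distribution
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=moves, logits=logits))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss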
Each Monte Carlo Tree Search (MCTS) pass consists of selection, expansion, evaluation, and backup, as sketched below.
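A rough sketch of one simulation: the tree is walked down with the PUCT selection rule, the leaf is expanded and evaluated with the network, and the value is backed up along the visited path. The `Node` bookkeeping and the `game` / `net` interfaces below are assumptions for illustration, not code from this post.

import math

class Node:
    def __init__(self, prior):
        self.prior = prior        # P(s, a) from the policy head
        self.visits = 0           # N(s, a)
        self.value_sum = 0.0      # W(s, a)
        self.children = {}        # action -> Node

    def q(self):
        # mean action value Q(s, a)
        return self.value_sum / self.visits if self.visits else 0.0

def select_child(node, c_puct=1.0):
    # PUCT rule: argmax_a  Q(s, a) + c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))
    total_visits = sum(child.visits for child in node.children.values())
    return max(node.children.items(),
               key=lambda kv: kv[1].q()
               + c_puct * kv[1].prior * math.sqrt(total_visits) / (1 + kv[1].visits))

def simulate(root, game, net):
    # 1. selection: descend until a leaf (a node with no children) is reached
    path, node = [root], root
    while node.children:
        action, node = select_child(node)
        game.apply(action)
        path.append(node)
    # 2. expansion + evaluation: the network gives priors and a value for the leaf
    priors, value = net.evaluate(game)
    for action, prior in priors.items():
        node.children[action] = Node(prior)
    # 3. backup: propagate the value to every node on the path,
    #    flipping the sign to alternate between the two players' perspectives
    for node in reversed(path):
        node.visits += 1
        node.value_sum += value
        value = -value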
AlphaGo Zero uses MCTS to train the policy network. The implementation below builds the shared residual tower with a policy head and a state-value head:
# -*- coding: utf-8 -*-
# @Time : 2022/4/1 13:47
# @Author : CyrusMay WJ
# @FileName: resnet.py
# @Software: PyCharm
# @Blog :https://blog.csdn.net/Cyrus_May
import tensorflow as tf
import logging
import sys
import os

# add Graphviz to PATH so that tf.keras.utils.plot_model can render the network
os.environ["PATH"] += os.pathsep + r'D:\software_root\Anoconda3\envs\AI\Graphviz\bin'


class ResidualNet():
    def __init__(self, input_dim, output_dim, net_struct, l2_reg=0, logger=None):
        """
        :param input_dim: shape of the input tensor, e.g. [19, 19, 17]
        :param output_dim: number of policy logits, e.g. 19 * 19
        :param net_struct: a list describing the residual network; net_struct[0]
            configures the first CNN layer applied to the inputs, and each
            remaining entry configures one residual block, e.g.
            net_struct = [
                {"filters": 64, "kernel_size": (3, 3)},
                {"filters": 128, "kernel_size": (3, 3)},
                {"filters": 128, "kernel_size": (3, 3)},
            ]
        :param l2_reg: L2 regularization coefficient applied to all weights
        :param logger: optional logger
        """
        self.logger = logger
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.l2_reg = l2_reg
        self.__build_model(net_struct)

    def conv_layer(self, x, filters, kernel_size):
        # one convolutional unit: Conv2D -> BatchNorm -> LeakyReLU
        x = tf.keras.layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   activation="linear",
                                   padding="same",
                                   data_format="channels_last",
                                   kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                   bias_regularizer=tf.keras.regularizers.l2(self.l2_reg))(x)
        x = tf.keras.layers.BatchNormalization(axis=-1)(x)
        x = tf.keras.layers.LeakyReLU()(x)
        return x

    def residual_block(self, inputs, filters, kernel_size):
        # two convolutions plus a skip connection from the block input
        x = self.conv_layer(inputs, filters, kernel_size)
        x = tf.keras.layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   activation="linear",
                                   padding="same",
                                   data_format="channels_last",
                                   kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                   bias_regularizer=tf.keras.regularizers.l2(self.l2_reg))(x)
        x = tf.keras.layers.BatchNormalization(axis=-1)(x)
        if inputs.shape[-1] == filters:
            x = tf.keras.layers.add([inputs, x])
        else:
            # channel counts differ, so project the input with a 1x1 convolution
            inputs = tf.keras.layers.Conv2D(filters=filters,
                                            kernel_size=(1, 1),
                                            activation="linear",
                                            padding="same",
                                            data_format="channels_last",
                                            kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                            bias_regularizer=tf.keras.regularizers.l2(self.l2_reg))(inputs)
            x = tf.keras.layers.add([inputs, x])
        x = tf.keras.layers.LeakyReLU()(x)
        return x

    def policy_head(self, inputs):
        # policy head: 1x1 convolution, then a dense layer with one logit per move
        x = tf.keras.layers.Conv2D(filters=2,
                                   kernel_size=(1, 1),
                                   activation="linear",
                                   padding="same",
                                   data_format="channels_last",
                                   kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                   bias_regularizer=tf.keras.regularizers.l2(self.l2_reg))(inputs)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(units=self.output_dim,
                                  activation="linear",
                                  kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                  bias_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                  name="policy_head")(x)
        return x

    def state_value_head(self, inputs):
        # value head: 1x1 convolution, then a dense layer producing one scalar state value
        x = tf.keras.layers.Conv2D(filters=2,
                                   kernel_size=(1, 1),
                                   activation="linear",
                                   padding="same",
                                   data_format="channels_last",
                                   kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                   bias_regularizer=tf.keras.regularizers.l2(self.l2_reg))(inputs)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(units=1,
                                  activation="linear",
                                  kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                  bias_regularizer=tf.keras.regularizers.l2(self.l2_reg),
                                  name="state_value_head")(x)
        return x

    def __build_model(self, net_struct):
        # shared residual tower feeding both heads
        input_layer = tf.keras.layers.Input(shape=self.input_dim, name="inputs")
        x = self.conv_layer(input_layer, net_struct[0]["filters"], net_struct[0]["kernel_size"])
        for i in range(1, len(net_struct)):
            x = self.residual_block(x, net_struct[i]["filters"], net_struct[i]["kernel_size"])
        v_output = self.state_value_head(x)
        p_output = self.policy_head(x)
        self.model = tf.keras.models.Model(inputs=input_layer, outputs=[p_output, v_output])
        tf.keras.utils.plot_model(self.model, to_file="./AlphaZero.png")
        # combined loss: softmax cross-entropy on the policy logits plus MSE on the state value
        self.model.compile(optimizer=tf.optimizers.Adam(),
                           loss={"policy_head": tf.nn.softmax_cross_entropy_with_logits,
                                 "state_value_head": "mean_squared_error"},
                           loss_weights={"policy_head": 0.5, "state_value_head": 0.5})


if __name__ == '__main__':
    logger = logging.getLogger(name="ResidualNet")
    logger.setLevel(logging.INFO)
    screen_handler = logging.StreamHandler(sys.stdout)
    screen_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(module)s.%(funcName)s:%(lineno)d - %(levelname)s - %(message)s')
    screen_handler.setFormatter(formatter)
    logger.addHandler(screen_handler)
    residual_net = ResidualNet(logger=logger,
                               input_dim=[19, 19, 17],
                               output_dim=19 * 19,
                               net_struct=[
                                   {"filters": 64, "kernel_size": (3, 3)},
                                   {"filters": 128, "kernel_size": (3, 3)},
                                   {"filters": 128, "kernel_size": (3, 3)},
                                   {"filters": 64, "kernel_size": (3, 3)},
                                   {"filters": 64, "kernel_size": (3, 3)},
                               ])
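For completeness, a hedged example of how the compiled two-head model can be fit on self-play data; the arrays below are random stand-ins for encoded positions, MCTS visit distributions (the policy targets) and game outcomes (the value targets), not real training data.

import numpy as np

# random stand-ins for one batch of self-play training data
states = np.random.rand(32, 19, 19, 17).astype("float32")               # encoded board positions
pi = np.random.dirichlet(np.ones(19 * 19), size=32).astype("float32")   # MCTS visit distributions
z = np.random.choice([-1.0, 1.0], size=(32, 1)).astype("float32")       # game outcomes in {-1, +1}

residual_net.model.fit(states,
                       {"policy_head": pi, "state_value_head": z},
                       batch_size=8, epochs=1)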
Refer to my other blog post.
Parts of this article are notes based on study videos from Bilibili!
by CyrusMay 2022 04 04
However many wishes there were back then,
there were just as many left unfulfilled.
————Mayday, "Step by Step" (步步)————