本文作为学习ml-agents的开篇,主要是参考ml-agents的官方文档,结合Unity和ml-agents创建一个简单的游戏作为强化学习算法的训练环境。本文中用的是官方最新发布的稳定版ml-agents release_1(Version 1.0.0)。整个DIY过程亲测可运行, 并且经过大约半小时的训练小球就足够“智能”,可以很快地找到Target, 最终效果如下:
在Hierarchy窗口中单击鼠标右键-> 3D Object->分别创建Plane、Sphere和Cube对象。分别在三个对象的Inspector中Reset Transform,如下图
单击选中Ball->在Inspector窗口的最下面单击Add Component按钮->搜索Rigidbody并添加
PS: 添加Rigidbody使Object具备物理特性,例如重力、作用力等。
git clone --branch release_1 https://github.com/Unity-Technologies/ml-agents.git
pip3 install mlagents
运行完成后可以通过运行"mlagents-learn --help"命令验证是否成功安装。
单击选中Ball->在Inspector窗口的最下面单击Add Component按钮->搜索script, 选择New script, 命名为"BallAgent"-> 点击Create and Add按钮
创建完成后可以在Project窗口的Assets目录下找到新建的BallAgent脚本
双击BallAgent脚本, 在编辑器中打开, 在首行添加
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
把基类MonoBehaviour改为Agent。删除Update()方法,保留Start()方法。更新后结果如下:
初始化方法主要是两个:Start(), OnEpisodeBegin()
初始化代码如下:
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
public class BallAgent : Agent
{
Rigidbody rBody;
void Start () {
// 获取Object中的Rigidbody组件,本例中即Ball的Rigidbody组件
rBody = GetComponent<Rigidbody>();
}
// 获取Target对象
// Script中的public属性会显示在Object的该Script组件中,可以通过拖拽来指定要关联的对象
public Transform Target;
public override void OnEpisodeBegin()
{
// this指代Ball Object,if语句用于判断Ball的y坐标是否为负(即Ball是否掉落Floor)
if (this.transform.localPosition.y < 0)
{
// 如果Ball掉落,则在新一轮游戏开始时重置Ball到Floor中央
this.rBody.angularVelocity = Vector3.zero;
this.rBody.velocity = Vector3.zero;
this.transform.localPosition = new Vector3(0, 0.5f, 0);
}
// 在新一轮游戏开始时,把Target放置在Floor上的一个随机位置
Target.localPosition = new Vector3(Random.value * 8 - 4,
0.5f,
Random.value * 8 - 4);
}
}
通过CollectObservations()方法来指定Observation中所包含的信息。在本例中Observation中包含8个值,分别是Ball的位置坐标(x, y, z), Target的位置坐标(x, y, z), Ball的速度(x, z)。
注意: Unity中的坐标轴和一般用法有些不同。在Unity中X和Z坐标定义的平面是水平面,Y坐标是垂直坐标轴,如下图。所以标识速度是用(x, z)坐标。
Observation代码实现如下:
public override void CollectObservations(VectorSensor sensor)
{
// Target和Agent的位置信息
sensor.AddObservation(Target.localPosition);
sensor.AddObservation(this.transform.localPosition);
// Agent的速度信息
sensor.AddObservation(rBody.velocity.x);
sensor.AddObservation(rBody.velocity.z);
}
通过OnActionReceived()方法来接收Action, 并根据执行Action后的状态变化指定获得的Reward。实现代码如下:
// speed默认为10,可以在script组件中设置
public float speed = 10;
public override void OnActionReceived(float[] vectorAction)
{
// Actions的size = 2, 指示Ball在X轴和Z轴方向(即水平面)上的移动信号
Vector3 controlSignal = Vector3.zero;
controlSignal.x = vectorAction[0];
controlSignal.z = vectorAction[1];
// 通过在Rigidbody上作用力来使Ball移动
rBody.AddForce(controlSignal * speed);
// 获取Ball移动后与Target的距离
float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition);
// 通过Ball和Target之间的距离来判断Ball是否碰触了Target
if (distanceToTarget < 1.42f)
{
// 如果Ball碰触了Target, 获得1.0的奖励,并结束本次Episode
SetReward(1.0f);
EndEpisode();
}
// 判断Ball是否掉落Floor,如果掉落则之间结束本次Episode
if (this.transform.localPosition.y < 0)
{
EndEpisode();
}
}
强化学习的几个要素Observation, Action, Reward都在代码中设置好了,接下来需要在Unity Editor中进行配置,把这些要素都联系起来。
回到Unity窗口选中Ball对象,在Inspector中通过Add Component分别添加"Decision Requester"和"Behavior Parameters", 这两个其实是新版ML Agents包中已经封装好的脚本,我们只需要在Inspector配置就可以。
在正式训练之前我们都需要先对游戏进行一下测试,主要方法是手动玩一下游戏,以保证游戏的运行符合预期。可以通过在BallAgent中实现Heuristic()方法来实现这一功能,代码实现如下:
// 用于测试游戏,可以实现手动控制Ball的滚动
public override void Heuristic(float[] actionsOut)
{
actionsOut[0] = Input.GetAxis("Horizontal");
actionsOut[1] = Input.GetAxis("Vertical");
}
在BallAgent中添加完代码后,还需要在Inspector中把Behavior Parameters->Behavior Type改为"Heuristic Only",然后在Unity主窗口中点击Play按钮,进入Game窗口就可以通过键盘方向键来控制小球的滚动,从而测试游戏效果是否符合预期。
在ml-agents/config目录下新建配置文件rollingball_config.yaml内容如下。其中首行的“Ball”就是设置的Behavior Name
Ball:
trainer: ppo
batch_size: 10
beta: 5.0e-3
buffer_size: 100
epsilon: 0.2
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4
learning_rate_schedule: linear
max_steps: 5.0e5
memory_size: 128
normalize: false
num_epoch: 3
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 10000
use_recurrent: false
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
mlagents-learn config/rollingball_config.yaml --run-id=RollingBall
在Unity Console中: Couldn’t connect to trainer on port 5004 using API version 1.0.0. Will perform inference instead.
本机的命令行中:“The Unity environment took too long to respond. Make sure that :\n”
mlagents_envs.exception.UnityTimeOutException: The Unity environment took too long to respond. Make sure that :
The environment does not need user interaction to launch
The Agents are linked to the appropriate Brains
The environment and the Python interface have compatible versions.
如果在点击Play后出现上述报错信息可能是如下原因:
(1) ml-agents安装不完全。重新用git clone命令下载 ml-agents到本地进行安装,切记不要通过在github中用“Clone or download”按钮下载ml-agents到本地的形式安装!!!
(2) 防火墙阻止了5004端口的通信。关闭防火墙再重试。