ML-Agents Example: Food Collector

This example comes from the official ML-Agents samples (GitHub: https://github.com/Unity-Technologies/ml-agents); this post is a detailed companion walkthrough.

This post builds on two of my earlier articles and assumes some familiarity with ML-Agents; for details see: Using ML-Agents for Reinforcement Learning in Unity, and ML-Agents Commands and Configuration Reference.

My earlier related posts include:

ML-Agents Example: Crawler

ML-Agents Example: Push Block

ML-Agents Example: Wall Jump

Environment Overview

[Figure 1]

The environment contains multiple agents whose task is to collect as many green food balls as possible while avoiding the red ones: touching a green ball gives a reward of +1, touching a red ball gives -1. In addition, an agent can fire a laser to freeze other agents, letting it eat more food and raise its own score.

Observation space: a Grid Sensor is used; for details on this sensor, see the multi-agent mode in ML-Agents Example: Push Block.

In this example the sensor is attached directly to the agent, with 40 grid cells along the z axis (forward/back), 40 along the x axis (left/right), and 1 along the y axis (up/down). The detectable tags are food, bad food, other agents, frozen agents, and walls; adding the "nothing detected" case gives 40 * 40 * 6 = 9600 observation dimensions.

[Figure 2]

[Figure 3]
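
To make the dimension count above concrete, here is a quick sanity check of the arithmetic. The channel layout (one channel per detectable tag plus one "nothing detected" channel) and the class name are assumptions for illustration only:

public static class GridObservationSize
{
    public const int GridX = 40;         // cells along the x axis (left/right)
    public const int GridZ = 40;         // cells along the z axis (forward/back)
    public const int Channels = 5 + 1;   // food, badFood, agent, frozenAgent, wall + empty
    public const int ObservationSize = GridX * GridZ * Channels;  // 40 * 40 * 6 = 9600
}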

Action space: 3 continuous actions, corresponding to forward/backward movement, left/right movement, and rotation, plus 1 discrete action that decides whether to fire the laser.
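
For reference, a minimal sketch of how this action layout maps onto ML-Agents' ActionSpec. In the actual project these values are set on the agent's Behavior Parameters component in the Inspector, so the class name here is just an illustrative assumption:

using Unity.MLAgents.Actuators;

public static class FoodCollectorActionLayout
{
    // 3 continuous actions (forward/back, left/right, rotate) and
    // 1 discrete branch with 2 options (fire the laser or not)
    public static readonly ActionSpec Spec = new ActionSpec(3, new[] { 2 });
}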

Code Walkthrough

Food Spawning

The script below controls food spawning and is attached to an empty GameObject. It does not run on its own; it has to be invoked from another script in the scene (a sketch of one possible caller follows the listing).

using UnityEngine;
using Unity.MLAgentsExamples;

public class FoodCollectorArea : Area
{
    public GameObject food;
    public GameObject badFood;
    public int numFood;
    public int numBadFood;
    public bool respawnFood;
    public float range;

    // Spawn the food objects
    void CreateFood(int num, GameObject type)
    {
        for (int i = 0; i < num; i++)
        {
            GameObject f = Instantiate(type, new Vector3(Random.Range(-range, range), 1f,
                Random.Range(-range, range)) + transform.position,
                Quaternion.Euler(new Vector3(0f, Random.Range(0f, 360f), 90f)));
            f.GetComponent<FoodLogic>().respawn = respawnFood;
            f.GetComponent<FoodLogic>().myArea = this;
        }
    }
    // Reset the area: randomize the agents' positions and spawn both kinds of food
    public void ResetFoodArea(GameObject[] agents)
    {
        foreach (GameObject agent in agents)
        {
            if (agent.transform.parent == gameObject.transform)
            {
                agent.transform.position = new Vector3(Random.Range(-range, range), 2f,
                    Random.Range(-range, range))
                    + transform.position;
                agent.transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));
            }
        }

        CreateFood(numFood, food);
        CreateFood(numBadFood, badFood);
    }

    public override void ResetArea()
    {
    }
}
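
ResetFoodArea has to be triggered when the environment resets. Below is a minimal sketch of one possible caller, hooked into ML-Agents' environment-reset event; the class name and its Inspector-assigned fields are assumptions for illustration, not the official FoodCollectorSettings file:

using UnityEngine;
using Unity.MLAgents;

public class FoodCollectorSettingsSketch : MonoBehaviour
{
    public FoodCollectorArea[] areas;   // assign the areas in the Inspector (assumed field)
    public int totalScore;              // shared score the agents add to when 'contribute' is enabled

    void Awake()
    {
        // Run the reset whenever ML-Agents resets the environment
        Academy.Instance.OnEnvironmentReset += EnvironmentReset;
    }

    void EnvironmentReset()
    {
        totalScore = 0;
        // Active agents carry the "agent" tag (see Freeze/Unfreeze further down)
        var agents = GameObject.FindGameObjectsWithTag("agent");
        foreach (var area in areas)
        {
            area.ResetFoodArea(agents);
        }
    }
}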

Handling Food Being Eaten

This script is attached to each food object and, likewise, only does something when it is called (the call site is shown right after the listing):

using UnityEngine;

public class FoodLogic : MonoBehaviour
{
    public bool respawn;
    public FoodCollectorArea myArea;

    // Two options when eaten: either respawn at a random position or be destroyed
    public void OnEaten()
    {
        if (respawn)
        {
            transform.position = new Vector3(Random.Range(-myArea.range, myArea.range),
                3f,
                Random.Range(-myArea.range, myArea.range)) + myArea.transform.position;
        }
        else
        {
            Destroy(gameObject);
        }
    }
}
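
For context, the call site is the agent's collision handler, shown in full later in this post:

// Inside FoodCollectorAgent.OnCollisionEnter: notify the food object that it was eaten
collision.gameObject.GetComponent<FoodLogic>().OnEaten();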

The agent's main file, FoodCollectorAgent.cs:

Initialization:

public override void Initialize()
{
    m_AgentRb = GetComponent<Rigidbody>();
    m_MyArea = area.GetComponent<FoodCollectorArea>();
    m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();
    // Read parameters from the environment parameters (set via the training configuration)
    m_ResetParams = Academy.Instance.EnvironmentParameters;
    // Apply the parameters
    SetResetParameters();
}
// Laser length
public void SetLaserLengths()
{
    m_LaserLength = m_ResetParams.GetWithDefault("laser_length", 1.0f);
}
// Set the agent's scale
public void SetAgentScale()
{
    float agentScale = m_ResetParams.GetWithDefault("agent_scale", 1.0f);
    gameObject.transform.localScale = new Vector3(agentScale, agentScale, agentScale);
}

public void SetResetParameters()
{
    SetLaserLengths();
    SetAgentScale();
}
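
The agent code in this post is shown as method-level fragments. For readability, here is a sketch of the fields those fragments assume; the names are taken from the snippets themselves, while the default values and exact declarations are assumptions rather than the complete official file:

// Fields referenced by the FoodCollectorAgent snippets in this post (sketch)
public GameObject area;                  // the area object carrying FoodCollectorArea
public bool contribute;                  // whether eating food changes the shared totalScore
public bool useVectorObs;                // enable the vector observations below
public bool useVectorFrozenFlag;         // alternatively, observe only the frozen flag
public float moveSpeed = 2f;             // assumed default
public float turnSpeed = 300f;           // assumed default
public GameObject myLaser;               // laser beam visual, scaled while shooting
public Material normalMaterial;
public Material badMaterial;
public Material goodMaterial;
public Material frozenMaterial;

Rigidbody m_AgentRb;
FoodCollectorArea m_MyArea;
FoodCollectorSettings m_FoodCollecterSettings;
EnvironmentParameters m_ResetParams;
float m_LaserLength;
float m_FrozenTime;
float m_EffectTime;
bool m_Frozen;
bool m_Poisoned;
bool m_Satiated;
bool m_Shoot;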

Observation input. Whether these vector observations are added can be configured via useVectorObs and useVectorFrozenFlag; in this example neither flag is enabled, so the agent relies on the grid sensor alone:

public override void CollectObservations(VectorSensor sensor)
{
    if (useVectorObs)
    {
        var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity);
        // Add the two horizontal velocity components
        sensor.AddObservation(localVelocity.x);
        sensor.AddObservation(localVelocity.z);
        // Add whether the agent is frozen and whether it is shooting
        sensor.AddObservation(m_Frozen);
        sensor.AddObservation(m_Shoot);
    }
    else if (useVectorFrozenFlag)
    {
        // Only add whether the agent is frozen
        sensor.AddObservation(m_Frozen);
    }
}

Action output:

// Entry point for executing actions; the actual logic is wrapped in MoveAgent
public override void OnActionReceived(ActionBuffers actionBuffers)
{
    MoveAgent(actionBuffers);
}

public void MoveAgent(ActionBuffers actionBuffers)
{
    m_Shoot = false;
    // Unfreeze once the freeze duration has elapsed
    if (Time.time > m_FrozenTime + 4f && m_Frozen)
    {
        Unfreeze();
    }
    // After the effect duration, clear the poisoned state / let the satiated state wear off
    if (Time.time > m_EffectTime + 0.5f)
    {
        if (m_Poisoned)
        {
            Unpoison();
        }
        if (m_Satiated)
        {
            Unsatiate();
        }
    }

    var dirToGo = Vector3.zero;
    var rotateDir = Vector3.zero;

    var continuousActions = actionBuffers.ContinuousActions;
    var discreteActions = actionBuffers.DiscreteActions;

    // Actions can only be executed while the agent is not frozen
    if (!m_Frozen)
    {
        // Read the three continuous outputs
        var forward = Mathf.Clamp(continuousActions[0], -1f, 1f);
        var right = Mathf.Clamp(continuousActions[1], -1f, 1f);
        var rotate = Mathf.Clamp(continuousActions[2], -1f, 1f);

        dirToGo = transform.forward * forward;
        dirToGo += transform.right * right;
        rotateDir = -transform.up * rotate;

        // Read the single discrete output
        var shootCommand = discreteActions[0] > 0;
        if (shootCommand)
        {
            // Slow down while shooting
            m_Shoot = true;
            dirToGo *= 0.5f;
            m_AgentRb.velocity *= 0.75f;
        }
        // Apply the movement
        m_AgentRb.AddForce(dirToGo * moveSpeed, ForceMode.VelocityChange);
        transform.Rotate(rotateDir, Time.fixedDeltaTime * turnSpeed);
    }

    // Damp the velocity once it gets too fast
    if (m_AgentRb.velocity.sqrMagnitude > 25f) // slow it down
    {
        m_AgentRb.velocity *= 0.95f;
    }

    // Shooting logic
    if (m_Shoot)
    {
        var myTransform = transform;
        myLaser.transform.localScale = new Vector3(1f, 1f, m_LaserLength);
        var rayDir = 25.0f * myTransform.forward;
        // Draw a debug ray: start position, direction vector (its length sets the ray length), color, duration, depth test
        Debug.DrawRay(myTransform.position, rayDir, Color.red, 0f, true);
        RaycastHit hit;
        // Sphere cast: origin, sphere radius, direction, hit info (out), max distance
        if (Physics.SphereCast(transform.position, 2f, rayDir, out hit, 25f))
        {
            // If the cast hits another agent, that agent gets frozen
            if (hit.collider.gameObject.CompareTag("agent"))
            {
                hit.collider.gameObject.GetComponent<FoodCollectorAgent>().Freeze();
            }
        }
    }
    else
    {
        myLaser.transform.localScale = new Vector3(0f, 0f, 0f);
    }
}
// Freeze logic: change the tag, record the time, and swap to the frozen material
void Freeze()
{
    gameObject.tag = "frozenAgent";
    m_Frozen = true;
    m_FrozenTime = Time.time;
    gameObject.GetComponentInChildren<Renderer>().material = frozenMaterial;
}
// Unfreeze logic
void Unfreeze()
{
    m_Frozen = false;
    gameObject.tag = "agent";
    gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}

// State after eating bad food
void Poison()
{
    m_Poisoned = true;
    m_EffectTime = Time.time;
    gameObject.GetComponentInChildren<Renderer>().material = badMaterial;
}
// Clear the poisoned state
void Unpoison()
{
    m_Poisoned = false;
    gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}
// State after eating good food
void Satiate()
{
    m_Satiated = true;
    m_EffectTime = Time.time;
    gameObject.GetComponentInChildren<Renderer>().material = goodMaterial;
}
// Satiated state wears off
void Unsatiate()
{
    m_Satiated = false;
    gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}

Collision logic with food:

void OnCollisionEnter(Collision collision)
{
    // Eating good food: +1 reward and enter the satiated state
    if (collision.gameObject.CompareTag("food"))
    {
        Satiate();
        collision.gameObject.GetComponent<FoodLogic>().OnEaten();
        AddReward(1f);
        if (contribute)
        {
            m_FoodCollecterSettings.totalScore += 1;
        }
    }
    // Eating bad food: -1 reward and enter the poisoned state
    if (collision.gameObject.CompareTag("badFood"))
    {
        Poison();
        collision.gameObject.GetComponent<FoodLogic>().OnEaten();

        AddReward(-1f);
        if (contribute)
        {
            m_FoodCollecterSettings.totalScore -= 1;
        }
    }
}

Logic executed at the start of each episode:

public override void OnEpisodeBegin()
{
    // Clear all abnormal states
    Unfreeze();
    Unpoison();
    Unsatiate();
    m_Shoot = false;
    // Zero the velocity and re-randomize the position
    m_AgentRb.velocity = Vector3.zero;
    myLaser.transform.localScale = new Vector3(0f, 0f, 0f);
    transform.position = new Vector3(Random.Range(-m_MyArea.range, m_MyArea.range),
                                     2f, Random.Range(-m_MyArea.range, m_MyArea.range))
        + area.transform.position;
    transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));
    // Re-apply the reset parameters (laser length, agent scale)
    SetResetParameters();
}

You can drive the agent yourself through the Heuristic method below (used when the agent's Behavior Type is set to Heuristic Only):

 public override void Heuristic(in ActionBuffers actionsOut)
 {
     var continuousActionsOut = actionsOut.ContinuousActions;
     if (Input.GetKey(KeyCode.D))
     {
         continuousActionsOut[2] = 1;
     }
     if (Input.GetKey(KeyCode.W))
     {
         continuousActionsOut[0] = 1;
     }
     if (Input.GetKey(KeyCode.A))
     {
         continuousActionsOut[2] = -1;
     }
     if (Input.GetKey(KeyCode.S))
     {
         continuousActionsOut[0] = -1;
     }
     var discreteActionsOut = actionsOut.DiscreteActions;
     discreteActionsOut[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
 }

Configuration Files

The configuration files are standard PPO and SAC setups; the two YAML blocks below are alternatives for the GridFoodCollector behavior (PPO first, then SAC):

behaviors:
  GridFoodCollector:
    trainer_type: ppo
    hyperparameters:
      batch_size: 1024
      buffer_size: 10240
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 256
      num_layers: 1
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 2000000
    time_horizon: 64
    summary_freq: 10000

behaviors:
  GridFoodCollector:
    trainer_type: sac
    hyperparameters:
      learning_rate: 0.0003
      learning_rate_schedule: constant
      batch_size: 256
      buffer_size: 2048
      buffer_init_steps: 0
      tau: 0.005
      steps_per_update: 10.0
      save_replay_buffer: false
      init_entcoef: 0.05
      reward_signal_steps_per_update: 10.0
    network_settings:
      normalize: false
      hidden_units: 256
      num_layers: 1
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 2000000
    time_horizon: 64
    summary_freq: 60000
    threaded: false

Afterword

Although the scene contains multiple agents, they act independently and simply compete with one another, so this is still a single-agent example rather than a cooperative multi-agent one. Compared with the earlier examples, what is new here is that the agents have a means of attack: they can fire a laser to freeze opponents, and learning when and how to use that attack becomes part of what the agents have to work out during training.
