Unity MlAgent 使用介绍

分享github 工程:https://github.com/IsaWinding/mlagent02.git

1. 本文介绍如何在一场战斗中,使用 ML-Agents 来训练英雄的 AI 行为操作。

Unity MlAgent 使用介绍_第1张图片
图中为英雄在不同训练次数下生成的各个 .onnx 文件。

Unity MlAgent 使用介绍_第2张图片
训练次数较少时,英雄只会在四周的墙边随机移动;
随着训练次数增多,英雄会直接奔向敌方小兵所在的位置,并发起攻击。
英雄行为脚本:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

/// <summary>
/// ML-Agents agent that controls a hero <see cref="Unit"/> on the same GameObject.
/// Observations: own local position, own HP, whether a target exists, and the target's
/// local position and HP (zeros when there is no living target).
/// Actions: three continuous values - X move, Z move, and an attack trigger.
/// </summary>
public class UnitAgent : Agent
{
    // Local-space limits beyond which the episode is restarted.
    // The original check tested |x| twice; the Z limit is assumed to equal the X limit.
    private const float MaxAbsX = 20f;
    private const float MaxAbsY = 2f;
    private const float MaxAbsZ = 20f;

    private Unit player;

    /// <summary>Caches the controlled unit and sets the episode step limit.</summary>
    public override void Initialize()
    {
        player = this.GetComponent<Unit>();
        MaxStep = 20000;
    }

    /// <summary>Resolves <see cref="player"/> lazily; returns false when no Unit is attached.</summary>
    private bool EnsurePlayer()
    {
        if (player == null)
            player = this.GetComponent<Unit>();
        return player != null;
    }

    /// <summary>
    /// Writes the observation vector. Nothing is written when no <see cref="Unit"/>
    /// component can be found on this GameObject.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        if (!EnsurePlayer())
            return;

        sensor.AddObservation(player.transform.localPosition);
        sensor.AddObservation(player.curHp);

        var target = player.Target;
        // Explicit bool instead of relying on UnityEngine.Object's implicit bool conversion.
        sensor.AddObservation(target != null);
        if (target != null && !target.IsDead())
        {
            sensor.AddObservation(target.transform.localPosition);
            sensor.AddObservation(target.curHp);
        }
        else
        {
            // Keep the observation layout fixed when there is no living target.
            sensor.AddObservation(Vector3.zero);
            sensor.AddObservation(0);
        }
    }

    /// <summary>
    /// Applies one decision: moves along (action[0], 0, action[1]) and, when action[2] is
    /// positive, runs the unit's attack logic. A hit is rewarded with 0.1 and a kill with 1.
    /// The episode ends when the unit is dead or has left the arena bounds.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        if (!EnsurePlayer())
            return;

        var vectorAction = actions.ContinuousActions;
        var controlSignal = new Vector3(vectorAction[0], 0f, vectorAction[1]);
        player.MoveByDir(controlSignal);

        if (vectorAction[2] > 0f)
        {
            player.AtkAjust(
                () => { AddReward(0.1f); },
                () => { AddReward(1f); });
        }

        // Restart on death. (A death penalty, AddReward(-3f), is intentionally disabled.)
        var isDead = player.curHp <= 0;

        // Restart when out of bounds. (A penalty, AddReward(-1f), is intentionally disabled.)
        var localPos = player.transform.localPosition;
        var isOutOfBounds = Mathf.Abs(localPos.y) >= MaxAbsY
                            || Mathf.Abs(localPos.x) >= MaxAbsX
                            || Mathf.Abs(localPos.z) >= MaxAbsZ;

        // Single call so one step can never end the episode twice.
        if (isDead || isOutOfBounds)
        {
            EndEpisode();
        }
    }

    /// <summary>Asks the unit to reset the battlefield at the start of each episode.</summary>
    public override void OnEpisodeBegin()
    {
        if (!EnsurePlayer())
            return;
        player.ResetBattleField();
    }

    /// <summary>Manual control for testing: Horizontal/Vertical axes move, Jump attacks.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
        continuousActionsOut[2] = Input.GetAxis("Jump");
    }
}

单位逻辑脚本:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Battlefield unit. Exposes stats from <see cref="UnitInfo"/>, tracks the current target,
/// and holds the move/attack state flags that <see cref="FixedUpdate"/> advances each tick.
/// Scripted units are driven through <see cref="OnAiAction"/>; external controllers can use
/// <see cref="MoveByDir"/> and <see cref="AtkAjust"/>.
/// </summary>
public class Unit:MonoBehaviour
{
    [HideInInspector]
    public int id;
    public UnitInfo info;
    public AIPolicy aiPolicy;

    private AIPath path;
    public Unit Target;
    private UnitAni unitAni;
    public float followDistance { get { return info.flowRange; } }
    public float warnDistance { get { return info.warRange; } }
    public float atkDistance { get { return info.atkRange; } }
    public float moveSpeed { get { return info.moveSpeed; } }
    public float curHp { get { return info.hp; }set { info.hp = value; } }
    public float MaxHp { get { return info.hpMax; } }
    private float atkCd { get { return info.atkCd; } }
    private float atk { get { return info.atk; } }
    public CampType campType { get { return info.campType; } }

    // Earliest time (Time.realtimeSinceStartup) at which the next attack may start.
    // NOTE(review): real time does not follow Time.timeScale; confirm this is intended
    // when the simulation is sped up.
    private float nextAtkTime = 0f;
    private bool IsNeedBackToHome = false;
    // Set by Attack(): FixedUpdate keeps attacking the current target.
    private bool isAttack = false;
    // True while an attack started by AtkCurTarget is playing; blocks movement.
    private bool isAttackTarget = false;
    private bool isMoveToTargetPos = false;
    // Position of this unit at the moment its target was last set via SetCurTarget.
    private Vector3 findTargetPos;
    private Vector3 targetPos;
    // Prevents scheduling more than one pending Reborn.
    private bool isCanReborn = true;
    private Vector3 bornPos;
    public List<Unit> AllUnits;
    public UnitManager manager;

    /// <summary>Forwards an episode reset request to the owning manager.</summary>
    public void ResetBattleField()
    {
        manager.OnEpisodeBegin();
    }

    /// <summary>Initializes identity, stats, spawn position, AI policy and patrol path.</summary>
    public void CreatUnit(int pId,UnitInfo pInfo,Vector3 pBornPos, AIPolicy pAIPolicy, AIPath pPath)
    {
        id = pId;
        info = pInfo;
        bornPos = pBornPos;
        aiPolicy = pAIPolicy;
        path = pPath.InitByGOList();
        path.Init();
        unitAni = this.gameObject.GetComponent<UnitAni>();
    }

    /// <summary>
    /// Tints the first renderer found on this object or its children. Does nothing when
    /// there is no renderer.
    /// </summary>
    public void SetColor(Color pColor)
    {
        // Material is not a Component, so it cannot be fetched with GetComponentInChildren;
        // go through the Renderer that owns the material instead.
        var unitRenderer = GetComponentInChildren<Renderer>();
        if (unitRenderer != null)
            unitRenderer.material.color = pColor;
    }

    /// <summary>Resets the unit for a new episode: idles it and respawns it immediately.</summary>
    public void OnEpisodeBegin()
    {
        Idle();
        isAttackTarget = false;
        // Drop any delayed respawn scheduled by OnDead before respawning right now.
        CancelInvoke(nameof(Reborn));
        Reborn();
    }

    /// <summary>Stores the unit list and, when alive, runs the AI policy (if any).</summary>
    public void OnAiAction(List<Unit> allUnits)
    {
        AllUnits = allUnits;
        if (IsDead())
            return;
        if(aiPolicy != null)
            aiPolicy.OnRun(allUnits);
    }

    /// <summary>Clears the target, stops attacking and moving, and plays the idle animation.</summary>
    public void Idle(){
        SetCurTarget(null);
        StopAttack();
        StopMove();
        unitAni.PlayAniByType(AniNameType.Idle, 1);
    }
    public void StopMove(){isMoveToTargetPos = false;}
    public void Attack(){isAttack = true;}
    public void StopAttack(){isAttack = false;}

    /// <summary>Sets the current target and records the current position as the chase anchor.</summary>
    public void SetCurTarget(Unit pCharacterInput){
        Target = pCharacterInput;
        findTargetPos = this.transform.position;
    }

    /// <summary>
    /// True when a return is already in progress, the target is dead, or the unit has moved
    /// at least <see cref="followDistance"/> away from the chase anchor.
    /// </summary>
    public bool NeedBackToHome(){
        if (IsNeedBackToHome)
            return true;
        if (Target == null)
            return false;
        if (Target.IsDead())
            return true;
        var distance = Vector3.Distance(this.transform.position, findTargetPos);
        return distance >= followDistance;
    }

    /// <summary>Targets the first selectable unit within <see cref="warnDistance"/> (or none).</summary>
    public void SelectAdjustTarget(List<Unit> pAllUnits){
        var unit = GetOneUnitInRange(pAllUnits,warnDistance);
        SetCurTarget(unit);
    }
    public bool IsDead(){return info.hp <= 0;}
    public bool IsHaveTarget(){return Target != null;}

    /// <summary>
    /// Walks back to the chase anchor. While farther than 1 unit away, the target is dropped
    /// and the unit keeps moving toward the anchor; once close enough the return ends.
    /// </summary>
    public void BackToHome(){
        var homePos = findTargetPos;
        var distance = Vector3.Distance(this.transform.position, homePos);
        if (distance >= 1){
            IsNeedBackToHome = true;
            // Clear the target directly: SetCurTarget(null) would overwrite findTargetPos
            // with the current position, making the unit "return" to where it already is.
            Target = null;
            MoveToTargetPos(homePos);
        }
        else{
            IsNeedBackToHome = false;
        }
    }

    /// <summary>Advances along the patrol path, or idles when the path has no next point.</summary>
    public void OnLoopMove(){
        if (path.IsReachNextPoint(this.transform.position))
        {
            path.OnReachNextPoint();
        }
        var nextPos = path.GetNextPoint();
        if (nextPos != null)
            MoveToTargetPos(nextPos.pos);
        else
            Idle();
    }

    /// <summary>Stops attacking and moves toward the current target, if there is one.</summary>
    public void MoToCurTarget(){
        StopAttack();
        if (Target != null)
            MoveToTargetPos(Target.transform.position);
    }

    /// <summary>Starts moving 2 units along the given direction from the current position.</summary>
    public void MoveByDir(Vector3 pDir)
    {
        var destination = this.transform.position + pDir.normalized*2f;
        MoveToTargetPos(destination);
    }

    /// <summary>Stops attacking and starts moving toward the given world position.</summary>
    public void MoveToTargetPos(Vector3 pTargetPos){
        StopAttack();
        targetPos = pTargetPos;
        isMoveToTargetPos = true;
    }
    public void FaceToCurTarget(){
        if(Target != null)
            this.transform.LookAt(Target.transform.position);
    }

    /// <summary>True when the given unit is within <paramref name="pRange"/> of this unit.</summary>
    private bool IsInRange(Unit pTarget, float pRange)
    {
        var otherPos = pTarget.transform.position;
        var selfPos = this.transform.position;
        return Vector3.Distance(otherPos, selfPos) <= pRange;
    }

    /// <summary>True when a living current target is within <paramref name="pRange"/>.</summary>
    public bool CurTargetInRange(float pRange){
        if (Target != null && !Target.IsDead())
        {
            return IsInRange(Target, pRange);
        }
        return false;
    }

    /// <summary>True when a unit of the given camp may be targeted (any camp other than our own).</summary>
    public bool IsCanCampSelect(CampType pTargetCampType)
    {
        if (campType == CampType.PlayerA)
        {
            return pTargetCampType == CampType.PlayerB || pTargetCampType == CampType.Monster;
        }
        else if (campType == CampType.PlayerB)
        {
            return pTargetCampType == CampType.PlayerA || pTargetCampType == CampType.Monster;
        }
        else if (campType == CampType.Monster)
        {
            return pTargetCampType == CampType.PlayerA || pTargetCampType == CampType.PlayerB;
        }
        return false;
    }

    /// <summary>
    /// Returns the first living, selectable unit within <paramref name="pRange"/>, or null
    /// when there is none (including when the list itself is null).
    /// </summary>
    public Unit GetOneUnitInRange(List<Unit> pAllUnits,float pRange)
    {
        // AllUnits stays null until OnAiAction has run at least once.
        if (pAllUnits == null)
            return null;
        for (var i = 0; i < pAllUnits.Count; i++){
            var unit_ = pAllUnits[i];
            if (unit_ == null)
                continue;
            if (!unit_.IsDead() && IsCanCampSelect(unit_.campType) && IsInRange(unit_, pRange))
                return unit_;
        }
        return null;
    }

    /// <summary>
    /// Whether a new chase target should be selected: false while a living target is still
    /// within follow range, otherwise true when a candidate exists within warn range.
    /// </summary>
    public bool NeedSelectTarget(List<Unit> pAllUnits)
    {
        if (Target != null && !Target.IsDead() && IsInRange(Target, followDistance))
            return false;
        var character = GetOneUnitInRange(pAllUnits,warnDistance);
        return character != null;
    }

    /// <summary>Clears combat state, plays the death animation and schedules a respawn in 2 seconds.</summary>
    private void OnDead()
    {
        isAttack = false;
        isAttackTarget = false;
        isMoveToTargetPos = false;
        Target = null;
        unitAni.PlayAniByType(AniNameType.Dead, 6);
        if (isCanReborn)
        {
            Invoke(nameof(Reborn), 2f);
            isCanReborn = false;
        }
    }

    /// <summary>Reacts to an HP change; triggers death handling when HP reaches zero.</summary>
    public void HpChange(float pOld, float hp)
    {
        //characterHp.SetHpInfo(hp, MaxHp);
        if (hp <= 0){
            OnDead();
        }
    }

    /// <summary>Restores full HP, clears combat state and moves the unit back to its spawn point.</summary>
    private void Reborn()
    {
        isAttack = false;
        isAttackTarget = false;
        isMoveToTargetPos = false;
        Target = null;
        curHp = MaxHp;
        //SetHpInfo();
        RpcReborn();
    }
    private void RpcReborn()
    {
        this.transform.position = bornPos;
        isCanReborn = true;
    }

    /// <summary>
    /// Applies damage (a negative value heals), clamping HP to [0, MaxHp].
    /// <paramref name="pKillCB"/> runs when HP ends at zero; <paramref name="pAtkCB"/> always runs.
    /// </summary>
    public void OnDamage(float pDamage, System.Action pAtkCB, System.Action pKillCB)
    {
        var oldHp = curHp;
        curHp -= pDamage;
        if (curHp > MaxHp)
            curHp = MaxHp;
        if (curHp < 0)
            curHp = 0;

        HpChange(oldHp, curHp);
        //SetHpInfo();
        if (curHp <= 0){
            pKillCB?.Invoke();
        }
        pAtkCB?.Invoke();
    }

    /// <summary>
    /// Attack decision: with a living target in follow range, attack it when within attack
    /// range, otherwise approach it; without one, pick a new target and approach it.
    /// </summary>
    public void AtkAjust(System.Action pAtkCB,System.Action pKillCB)
    {
        if (Target != null && !Target.IsDead()&& IsInRange(Target, followDistance))
        {
            if (IsInRange(Target, atkDistance))
                AtkCurTarget(pAtkCB, pKillCB);
            else {
                MoveToCurTarget();
            }
        }
        else
        {
            SelectAdjustTarget(AllUnits);
            MoveToCurTarget();
        }
    }
    public void MoveToCurTarget()
    {
        if(Target!= null)
            MoveToTargetPos(Target.transform.position);
    }

    /// <summary>
    /// Stops moving, faces the target and, when the cooldown has elapsed, plays the attack
    /// animation; damage is applied from the animation callback.
    /// </summary>
    public void AtkCurTarget(System.Action pAtkCB, System.Action pKillCB)
    {
        StopMove();
        if (Target == null)
            return;
        this.transform.LookAt(Target.transform);
        var curTime = Time.realtimeSinceStartup;
        if (curTime >= nextAtkTime)
        {
            isAttackTarget = true;
            nextAtkTime = curTime + atkCd;
            unitAni.PlayAniByType(AniNameType.Attack, 4, () => {
                DoNormalAttack(pAtkCB, pKillCB);
                isAttackTarget = false;
            });
        }
    }

    /// <summary>Deals <c>atk</c> damage to the current target if it still exists and is alive.</summary>
    public void DoNormalAttack(System.Action pAtkCB, System.Action pKillCB)
    {
        // The target may have died between the attack starting and this callback; hitting a
        // dead unit would re-run its death handling and fire the kill callback again.
        if (Target != null && !Target.IsDead())
            Target.OnDamage(atk, pAtkCB, pKillCB);
    }
    private Vector3 GetMoveDir(Vector3 pTargetPos){
        var dir = pTargetPos - this.transform.position;
        return dir.normalized * moveSpeed;
    }

    /// <summary>Per-tick state machine: continuous attack, movement toward targetPos, idle fallback.</summary>
    private void FixedUpdate()
    {
        if (IsDead())
            return;
        //curTime += Time.deltaTime;
        if (isAttack)
        {
            var curTime = Time.realtimeSinceStartup;
            if (curTime >= nextAtkTime)
            {
                nextAtkTime = curTime + atkCd;
                unitAni.PlayAniByType(AniNameType.Attack, 4, () =>
                {
                    isAttack = false;
                    DoNormalAttack(null, null);
                });
            }
        }
        if (isMoveToTargetPos && !isAttackTarget)
        {
            // Arrival tolerance of 1 unit.
            if (Vector3.Distance(targetPos, this.transform.position) <= 1)
            {
                isMoveToTargetPos = false;
            }
            else
            {
                this.transform.localPosition += GetMoveDir(targetPos) * Time.deltaTime;
                this.transform.LookAt(targetPos);
                unitAni.PlayAniByType(AniNameType.Move, 3);
            }
        }

        if (!isAttack && !isMoveToTargetPos&&!isAttackTarget)
        {
            unitAni.PlayAniByType(AniNameType.Idle, 1);
        }
    }
}

寻路路径

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

// How a route's waypoints are traversed once the current waypoint is reached.
public enum PathType
{
	// Advance through the waypoints via NextPoint; no wrap-around link is created.
	Once = 1,
	// The last waypoint links back to the first, so traversal repeats.
	Loop = 2,
	// Direction reverses when an end of the waypoint list is reached.
	PingPong = 3,

}
/// <summary>
/// One waypoint of a route: a world position plus links to its neighbours.
/// </summary>
[System.Serializable]
public class PathPoint
{
	/// <summary>Position of the waypoint.</summary>
	public Vector3 pos;

	// The neighbour links are runtime-only. They are excluded from Unity serialization
	// because a serializable class that references its own type makes the serializer
	// recurse until its depth limit.
	/// <summary>Following waypoint, or null when there is none.</summary>
	[System.NonSerialized]
	public PathPoint NextPoint;
	/// <summary>Preceding waypoint, or null when there is none.</summary>
	[System.NonSerialized]
	public PathPoint PrePoint;

	/// <summary>Creates an unlinked waypoint at <paramref name="pPos"/>.</summary>
	public PathPoint(Vector3 pPos)
	{
		pos = pPos;
	}
}

/// <summary>
/// A route made of linked <see cref="PathPoint"/>s with a cursor to the waypoint currently
/// being approached. Call <see cref="Init"/> after the waypoints are populated.
/// </summary>
[System.Serializable]
public class AIPath
{
	[HideInInspector]
	public List<PathPoint> Paths = new List<PathPoint>();
	public List<GameObject> PathGos = new List<GameObject>();
	public PathType pathType = PathType.Loop;
	// Waypoint currently being approached; null when the route is empty or finished.
	private PathPoint nextPoint;
	// PingPong only: true while walking from the first waypoint toward the last.
	private bool isForward = true;

	/// <summary>
	/// Replaces the waypoints with two points at <paramref name="pDistance"/> to the right
	/// and left of <paramref name="pMiddlePos"/> along the world X axis.
	/// </summary>
	public void SetAdjustTwoPointPath(Vector3 pMiddlePos, float pDistance)
	{
		Paths.Clear();
		Paths.Add(new PathPoint(pMiddlePos + Vector3.right * pDistance));
		Paths.Add(new PathPoint(pMiddlePos + Vector3.right * -pDistance));
	}

	/// <summary>
	/// Builds a new, independent path from the positions of <see cref="PathGos"/>, keeping
	/// this path's <see cref="pathType"/>. The result still needs <see cref="Init"/>.
	/// </summary>
	public AIPath InitByGOList()
	{
		var aiPath = new AIPath();
		// Carry the configured traversal mode over; otherwise the copy is always Loop.
		aiPath.pathType = pathType;
		foreach (var go in PathGos)
		{
			if (go == null)
				continue;
			aiPath.Paths.Add(new PathPoint(go.transform.position));
		}
		return aiPath;
	}

	/// <summary>
	/// Links every waypoint to its neighbours (wrapping around for Loop paths) and points the
	/// cursor at the first waypoint. Safe for empty and single-waypoint routes.
	/// </summary>
	public void Init()
	{
		var count = Paths.Count;
		var wraps = pathType == PathType.Loop;
		for (var i = 0; i < count; i++)
		{
			var point = Paths[i];
			if (i + 1 < count)
				point.NextPoint = Paths[i + 1];
			else
				point.NextPoint = wraps ? Paths[0] : null;

			if (i > 0)
				point.PrePoint = Paths[i - 1];
			else
				point.PrePoint = wraps ? Paths[count - 1] : null;
		}
		isForward = true;
		nextPoint = count > 0 ? Paths[0] : null;
	}

	/// <summary>Returns the waypoint currently being approached, or null when there is none.</summary>
	public PathPoint GetNextPoint()
	{
		return nextPoint;
	}

	/// <summary>
	/// True when <paramref name="pos"/> is within 1 unit of the current waypoint along X.
	/// Only the X coordinate is compared. Returns false when there is no current waypoint.
	/// </summary>
	public bool IsReachNextPoint(Vector3 pos)
	{
		if (nextPoint == null)
			return false;
		return Mathf.Abs(nextPoint.pos.x - pos.x) <= 1;
	}

	/// <summary>Moves the cursor to the following waypoint according to <see cref="pathType"/>.</summary>
	public void OnReachNextPoint()
	{
		// A finished Once path (or an empty route) has nothing to advance.
		if (nextPoint == null)
			return;

		if (pathType == PathType.Once || pathType == PathType.Loop)
		{
			// For Once this becomes null after the last waypoint; for Loop it wraps.
			nextPoint = nextPoint.NextPoint;
		}
		else if (pathType == PathType.PingPong)
		{
			if (isForward)
			{
				if (nextPoint.NextPoint != null)
					nextPoint = nextPoint.NextPoint;
				else
				{
					// Reached the last waypoint: turn around.
					nextPoint = nextPoint.PrePoint;
					isForward = false;
				}
			}
			else
			{
				if (nextPoint.PrePoint != null)
					nextPoint = nextPoint.PrePoint;
				else
				{
					// Reached the first waypoint: turn around.
					nextPoint = nextPoint.NextPoint;
					isForward = true;
				}
			}
		}
	}
}

ai策略代码:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
// Identifiers for the kinds of AI policy.
// NOTE(review): presumably Soldier/Tower map to SoldierPolicy/TowerPolicy; the mapping is
// not visible in this file - confirm against the code that reads this enum.
public enum AIPolicyType
{ 
    None = 0,
    Soldier = 1,
    Tower = 2,
    Hero = 3
}

/// <summary>
/// Policy whose action list is, in priority order: attack the target, select a target, idle.
/// </summary>
public class TowerPolicy : AIPolicy
{
    /// <summary>Action list, built on first access and cached in <c>allActions_</c>.</summary>
    protected override List<AIAction> allActions
    {
        get
        {
            if (allActions_ != null)
                return allActions_;

            var actions = new List<AIAction>();
            actions.Add(new AttackToTargetAction());
            actions.Add(new SelectTargetAction());
            actions.Add(new IdleAction());
            allActions_ = actions;
            return allActions_;
        }
    }

    /// <summary>Binds the policy to the unit it controls.</summary>
    public TowerPolicy(Unit pMonsterAi) : base(pMonsterAi)
    {
    }
}
/// <summary>
/// Policy whose action list is, in priority order: attack the target, move to the target,
/// select a target, patrol move, idle.
/// </summary>
public class SoldierPolicy : AIPolicy
{
    /// <summary>Action list, built on first access and cached in <c>allActions_</c>.</summary>
    protected override List<AIAction> allActions
    {
        get
        {
            if (allActions_ != null)
                return allActions_;

            var actions = new List<AIAction>();
            actions.Add(new AttackToTargetAction());
            actions.Add(new MoveToTargetAction());
            actions.Add(new SelectTargetAction());
            actions.Add(new MoveAction());
            actions.Add(new IdleAction());
            allActions_ = actions;
            return allActions_;
        }
    }

    /// <summary>Binds the policy to the unit it controls.</summary>
    public SoldierPolicy(Unit pMonsterAi) : base(pMonsterAi)
    {
    }
}
/// <summary>
/// Base AI policy: an ordered list of actions. Each run selects the first action whose
/// condition passes and executes it on the bound unit.
/// </summary>
public class AIPolicy
{
    protected List<AIAction> allActions_;
    /// <summary>Ordered action list; subclasses override this to supply their actions.</summary>
    protected virtual List<AIAction> allActions { get { return allActions_;} }
    // Most recently selected action; reused when no action passes on a later run.
    private AIAction curAiAction;
    protected Unit bindAi;

    public AIPolicy() { }
    public AIPolicy(Unit pMonsterAi) {
        bindAi = pMonsterAi;
    }

    /// <summary>
    /// Picks the first passing action (keeping the previous one when none passes) and runs it.
    /// Does nothing when the policy has no action list or no action has ever been selected.
    /// </summary>
    public void OnRun(List<Unit> pAllUnits)
    {
        var actions = allActions;
        // The base policy has no action list unless a subclass provides one.
        if (actions == null)
            return;

        for (var i = 0; i < actions.Count; i++)
        {
            if (actions[i].IsPass(bindAi, pAllUnits))
            {
                curAiAction = actions[i];
                break;
            }
        }

        // Still null when no action has passed since the policy was created.
        if (curAiAction != null)
            curAiAction.OnAcion(bindAi, pAllUnits);
    }
}
/// <summary>Action that calls <c>BackToHome</c> on the bound unit.</summary>
public class BackToHomeAction : AIAction
{
    /// <summary>Lazily created <see cref="BackToHomeConditioner"/> gating this action.</summary>
    public override AIConditioner Conditioner
    {
        get
        {
            if (conditioner != null)
                return conditioner;
            conditioner = new BackToHomeConditioner();
            return conditioner;
        }
    }

    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.BackToHome();
}
/// <summary>Action that stops the unit, turns it toward its target and starts attacking.</summary>
public class AttackToTargetAction : AIAction
{
    /// <summary>Lazily created <see cref="AttackToTargetConditioner"/> gating this action.</summary>
    public override AIConditioner Conditioner
    {
        get
        {
            if (conditioner != null)
                return conditioner;
            conditioner = new AttackToTargetConditioner();
            return conditioner;
        }
    }

    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits)
    {
        var unit = pBindAi;
        unit.StopMove();
        unit.FaceToCurTarget();
        unit.Attack();
    }
}
/// <summary>Action that calls <c>MoToCurTarget</c> on the bound unit.</summary>
public class MoveToTargetAction : AIAction
{
    /// <summary>Lazily created <see cref="MoveToTargetConditioner"/> gating this action.</summary>
    public override AIConditioner Conditioner
    {
        get
        {
            if (conditioner != null)
                return conditioner;
            conditioner = new MoveToTargetConditioner();
            return conditioner;
        }
    }

    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.MoToCurTarget();
}
/// <summary>Action that lets the bound unit pick a target from the supplied unit list.</summary>
public class SelectTargetAction : AIAction
{
    /// <summary>Lazily created <see cref="ForcusTargetConditioner"/> gating this action.</summary>
    public override AIConditioner Conditioner
    {
        get
        {
            if (conditioner != null)
                return conditioner;
            conditioner = new ForcusTargetConditioner();
            return conditioner;
        }
    }

    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.SelectAdjustTarget(pAllUnits);
}
/// <summary>Action that calls <c>OnLoopMove</c> on the bound unit.</summary>
public class MoveAction : AIAction
{
    /// <summary>Lazily created <see cref="MoveConditioer"/> gating this action.</summary>
    public override AIConditioner Conditioner
    {
        get
        {
            if (conditioner != null)
                return conditioner;
            conditioner = new MoveConditioer();
            return conditioner;
        }
    }

    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.OnLoopMove();
}
/// <summary>Action that calls <c>Idle</c> on the bound unit.</summary>
public class IdleAction : AIAction
{
    /// <summary>Lazily created <see cref="IdleConditoner"/> gating this action.</summary>
    public override AIConditioner Conditioner
    {
        get
        {
            if (conditioner != null)
                return conditioner;
            conditioner = new IdleConditoner();
            return conditioner;
        }
    }

    public override void OnAcion(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.Idle();
}
/// <summary>
/// Base class for AI actions: a condition deciding whether the action applies, plus the
/// behaviour to run. The base behaviour does nothing.
/// </summary>
public class AIAction
{
    protected AIConditioner conditioner;

    /// <summary>Condition for this action; created on first access and then cached.</summary>
    public virtual AIConditioner Conditioner
    {
        get
        {
            if (conditioner != null)
                return conditioner;
            conditioner = new AIConditioner();
            return conditioner;
        }
    }

    /// <summary>Runs the action on the bound unit. No-op in the base class.</summary>
    public virtual void OnAcion(Unit pBindAi, List<Unit> pAllUnits)
    {
    }

    /// <summary>Evaluates this action's condition for the given unit and unit list.</summary>
    public bool IsPass(Unit pBindAi, List<Unit> pAllUnits) => Conditioner.IsPass(pBindAi, pAllUnits);
}

/// <summary>Passes when the unit reports that it needs to go back (target out of follow range).</summary>
public class BackToHomeConditioner : AIConditioner
{
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.NeedBackToHome();
}
/// <summary>Attack condition: passes when the current target is within the unit's attack distance.</summary>
public class AttackToTargetConditioner : AIConditioner
{
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits)
    {
        var attackRange = pBindAi.atkDistance;
        return pBindAi.CurTargetInRange(attackRange);
    }
}
/// <summary>Chase condition: passes when the current target is within the unit's follow distance.</summary>
public class MoveToTargetConditioner : AIConditioner
{
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits)
    {
        var followRange = pBindAi.followDistance;
        return pBindAi.CurTargetInRange(followRange);
    }
}
/// <summary>Target-lock condition: passes when the unit reports that it should select a target.</summary>
public class ForcusTargetConditioner : AIConditioner
{
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits) => pBindAi.NeedSelectTarget(pAllUnits);
}
/// <summary>Patrol condition: passes when the unit has no target.</summary>
public class MoveConditioer : AIConditioner
{
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits)
    {
        var hasTarget = pBindAi.IsHaveTarget();
        return !hasTarget;
    }
}
/// <summary>Idle condition: passes when the unit is not dead.</summary>
public class IdleConditoner : AIConditioner
{
    public override bool IsPass(Unit pBindAi, List<Unit> pAllUnits)
    {
        var isDead = pBindAi.IsDead();
        return !isDead;
    }
}
/// <summary>Base condition for AI actions. The default implementation always passes.</summary>
public class AIConditioner
{
    /// <summary>Returns true regardless of the unit or unit list; subclasses override this.</summary>
    public virtual bool IsPass(Unit pBindAi, List<Unit> pAllUnits) => true;
}

大部分核心代码已在上文分享,感兴趣的同学可以直接下载 GitHub 上的工程进行测试和了解。
再次分享github 工程:https://github.com/IsaWinding/mlagent02.git

你可能感兴趣的:(Unity,游戏,经验分享,unity,深度学习,tensorflow)