Reinforcement learning (RL) is a branch of machine learning that studies how an agent should act in an environment so as to maximize its expected cumulative reward; it is widely applied in robotics. Q-Learning is a classic reinforcement-learning algorithm for solving Markov decision processes (MDPs).
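The core of Q-Learning is a table Q(s, a) that estimates the expected discounted return of taking action a in state s. After each step, the entry for the previous state-action pair is nudged towards the observed reward plus the discounted value of the best action available in the new state. With learning rate α and discount factor γ (the fields e and gamma of the QLTable class below), the update performed by learn() is:

Q(s, a) ← (1 − α) · Q(s, a) + α · (r + γ · max_a′ Q(s′, a′))

The two classes below implement this for a Robocode AdvancedRobot: QLRobot discretises the battle situation into states, executes one of seven movement actions per decision and converts battle events into rewards, while QLTable stores and updates the Q-values.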
package robot;

import java.awt.*;
import java.awt.geom.*;
import robocode.*;
import robocode.AdvancedRobot;

public class QLRobot extends AdvancedRobot {

    // Q-value table
    private QLTable table;
    // Reward accumulated since the last decision (credited to the previous state/action pair)
    private double reward = 0.0;
    // Fire power, in (0, 3]
    private double firePower;
    // Movement direction (1 = forward, -1 = backward)
    private int direction = 1;
    // Whether we hit a wall since the last decision (0/1)
    private int isHitWall = 0;
    // Whether we were hit by a bullet since the last decision (0/1)
    private int isHitByBullet = 0;
    // Loads and saves the Q-table between rounds
    private Data data = new Data();
    // Last known enemy state
    private double distance = 100000;
    private long ctime;
    private double targetX;
    private double targetY;
    private double bearing;
    // Turn angle and move distance used by each movement action
    private double turnDegree = 20.0;
    private double moveDistance = 100.0;
    public void run() {
        // Load a previously learned Q-table (or start fresh)
        table = data.readData();
        setAdjustGunForRobotTurn(true);
        setAdjustRadarForGunTurn(true);
        turnRadarRightRadians(2 * Math.PI);
        while (true) {
            robotMovement();
            // Fire power grows as the enemy gets closer, capped at the maximum of 3
            firePower = 50 / distance;
            if (firePower > 3)
                firePower = 3;
            //radarMovement();
            if (getGunHeat() == 0) {
                setFire(firePower);
            }
            execute();
        }
    }
    void doMovement() {
        // Simple fallback movement: reverse every 20 ticks and stay perpendicular to the enemy
        if (getTime() % 20 == 0) {
            direction *= -1;
            setAhead(direction * 300);
        }
        setTurnRightRadians(bearing + (Math.PI / 2));
    }
    private void robotMovement() {
        int state = getState();
        // Pick an action for the current state, then feed the reward accumulated
        // since the previous decision back into the Q-table
        int action = table.selectAction(state);
        table.learn(state, action, reward);
        reward = 0.0;
        isHitWall = 0;
        isHitByBullet = 0;
        switch (action) {
            case 0:
                setAhead(moveDistance);
                break;
            case 1:
                setBack(moveDistance);
                break;
            case 2:
                setAhead(moveDistance);
                setTurnLeft(turnDegree);
                break;
            case 3:
                setAhead(moveDistance);
                setTurnRight(turnDegree);
                break;
            case 4:
                setAhead(moveDistance);
                setTurnLeft(turnDegree);
                break;
            case 5:
                setBack(moveDistance);
                setTurnLeft(turnDegree);
                break;
            case 6:
                setBack(moveDistance);
                setTurnRight(turnDegree);
                break;
        }
    }

    // Discretise our heading, the enemy distance and the enemy bearing, then map the
    // tuple (together with the wall-hit and bullet-hit flags) to a single state index
    private int getState() {
        int heading = State.getHeading(getHeading());
        int targetDistance = State.getTargetDistance(distance);
        int targetBearing = State.getTargetBearing(bearing);
        return State.Mapping[heading][targetDistance][targetBearing][isHitWall][isHitByBullet];
    }
    private void radarMovement() {
        double radarOffset;
        if (getTime() - ctime > 4) {
            // Lost track of the target: sweep the radar all the way round
            radarOffset = 4 * Math.PI;
        } else {
            // Turn the radar towards the last known enemy position, overshooting slightly
            radarOffset = getRadarHeadingRadians() - (Math.PI / 2 - Math.atan2(targetY - getY(), targetX - getX()));
            radarOffset = NormaliseBearing(radarOffset);
            if (radarOffset < 0)
                radarOffset -= Math.PI / 10;
            else
                radarOffset += Math.PI / 10;
        }
        setTurnRadarLeftRadians(radarOffset);
    }

    // Wrap an angle into (-PI, PI]
    double NormaliseBearing(double ang) {
        if (ang > Math.PI)
            ang -= 2 * Math.PI;
        if (ang < -Math.PI)
            ang += 2 * Math.PI;
        return ang;
    }

    // Wrap an angle into [0, 2*PI)
    double NormaliseHeading(double ang) {
        if (ang > 2 * Math.PI)
            ang -= 2 * Math.PI;
        if (ang < 0)
            ang += 2 * Math.PI;
        return ang;
    }
    // Reward for hitting the enemy, proportional to bullet power
    public void onBulletHit(BulletHitEvent e) {
        reward += e.getBullet().getPower() * 9;
    }

    // Small penalty for a wasted bullet
    public void onBulletMissed(BulletMissedEvent e) {
        reward += -e.getBullet().getPower();
    }

    // Penalty proportional to the damage taken from an enemy bullet
    public void onHitByBullet(HitByBulletEvent e) {
        double power = e.getBullet().getPower();
        reward += -(4 * power + 2 * (power - 1));
        isHitByBullet = 1;
    }

    // Penalty for colliding with another robot
    public void onHitRobot(HitRobotEvent e) {
        reward += -6.0;
    }

    // Wall collision: the reward decreases with impact speed
    public void onHitWall(HitWallEvent e) {
        reward += -(Math.abs(getVelocity()) * 0.5 - 1);
        isHitWall = 1;
    }
    public void onScannedRobot(ScannedRobotEvent e) {
        if (e.getDistance() < distance) {
            // Absolute bearing to the enemy, in radians
            double absbearing = (getHeadingRadians() + e.getBearingRadians()) % (2 * Math.PI);
            // Turn towards the enemy
            setTurnRightRadians(e.getBearingRadians());
            targetX = getX() + Math.sin(absbearing) * e.getDistance();
            targetY = getY() + Math.cos(absbearing) * e.getDistance();
            bearing = e.getBearingRadians();
            ctime = getTime();
            distance = e.getDistance();
        }
    }
    // Persist the learned Q-table at the end of the round, win or lose
    public void onWin(WinEvent event) {
        data.saveData(table);
    }

    public void onDeath(DeathEvent event) {
        data.saveData(table);
    }
}
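The Data helper used above for persistence is not listed in the original source. The sketch below is an assumption rather than the author's implementation: it simply serialises the QLTable with standard Java serialisation to a placeholder file name (inside Robocode's sandbox one would normally write to the file returned by getDataFile() via RobocodeFileOutputStream instead of a plain FileOutputStream).

package robot;

import java.io.*;

public class Data {
    // Placeholder file name; a real robot should use its Robocode data directory
    private static final String FILE_NAME = "qltable.ser";

    // Load a previously saved Q-table, or start from an empty one if none exists
    public QLTable readData() {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(FILE_NAME))) {
            return (QLTable) in.readObject();
        } catch (IOException | ClassNotFoundException ex) {
            return new QLTable();
        }
    }

    // Serialise the Q-table so that learning carries over to the next round
    public void saveData(QLTable table) {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(FILE_NAME))) {
            out.writeObject(table);
        } catch (IOException ex) {
            ex.printStackTrace();
        }
    }
}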
package robot;

import java.io.Serializable;

public class QLTable implements Serializable {

    private double[][] table;
    // Number of states: heading (4) x distance (20) x bearing (16) x hit-wall (2) x hit-by-bullet (2)
    private final int states = 4 * 20 * 16 * 2 * 2;
    // Number of actions
    private final int actions = 7;
    // Whether this is the first decision of the round
    private boolean first = true;
    // Previous state
    private int lastState;
    // Previous action
    private int lastAction;
    // Discount factor
    private double gamma = 0.8;
    // Learning rate
    private double e = 0.1;
    // Exploration probability
    private double exploration = 0.2;

    public QLTable() {
        table = new double[states][actions];
        for (int i = 0; i < states; i++)
            for (int j = 0; j < actions; j++)
                table[i][j] = 0;
    }
    // Feed back the reward for the action taken in the previous state
    public void learn(int state, int action, double reward) {
        if (first)
            first = false;
        else {
            double oldValue = table[lastState][lastAction];
            double newValue = (1 - e) * oldValue + e * (reward + gamma * getMaxValue(state));
            // Q-learning update
            table[lastState][lastAction] = newValue;
        }
        lastState = state;
        lastAction = action;
    }

    // Maximum Q-value over all actions in the given state
    private double getMaxValue(int state) {
        double max = table[state][0];
        for (int i = 1; i < actions; i++) {
            if (table[state][i] > max)
                max = table[state][i];
        }
        return max;
    }
    // Choose an action for the current state
    public int selectAction(int state) {
        double qValue;
        double sum = 0.0;
        double[] value = new double[actions];
        for (int i = 0; i < value.length; i++) {
            qValue = table[state][i];
            // exp keeps every weight positive without changing the ordering of the Q-values
            value[i] = Math.exp(qValue);
            sum += value[i];
        }
        // Normalise into the softmax probabilities P(a|s)
        if (sum != 0) {
            for (int i = 0; i < value.length; i++) {
                value[i] /= sum;
            }
        } else
            return getBestAction(state);
        // With probability exploration, sample an action from the softmax
        // distribution (roulette-wheel selection), so every action can be explored
        if (Math.random() < exploration) {
            double roulette = Math.random();
            double cumProb = 0.0;
            for (int action = 0; action < value.length; action++) {
                cumProb += value[action];
                if (roulette < cumProb)
                    return action;
            }
            return value.length - 1;
        } else {
            // Otherwise act greedily
            return getBestAction(state);
        }
    }
    // Greedy choice: the action with the highest Q-value in the given state
    private int getBestAction(int state) {
        double max = table[state][0];
        int action = 0;
        for (int i = 1; i < actions; i++) {
            if (table[state][i] > max) {
                max = table[state][i];
                action = i;
            }
        }
        return action;
    }
}
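The State helper referenced by getState() is not listed either. Its role is to discretise the continuous heading, enemy distance and enemy bearing into the 4 x 20 x 16 x 2 x 2 grid implied by the states constant above. A minimal sketch follows; the bucket boundaries are assumptions chosen only to match those dimensions, not the author's original values.

package robot;

public class State {
    // One unique index for every (heading, distance, bearing, hitWall, hitByBullet) tuple
    public static final int[][][][][] Mapping = new int[4][20][16][2][2];

    static {
        int index = 0;
        for (int h = 0; h < 4; h++)
            for (int d = 0; d < 20; d++)
                for (int b = 0; b < 16; b++)
                    for (int w = 0; w < 2; w++)
                        for (int hit = 0; hit < 2; hit++)
                            Mapping[h][d][b][w][hit] = index++;
    }

    // Heading in degrees [0, 360) -> 4 buckets of 90 degrees
    public static int getHeading(double heading) {
        return Math.min(3, (int) (heading / 90.0));
    }

    // Distance -> 20 buckets of 50 pixels; anything beyond 950 falls into the last bucket
    public static int getTargetDistance(double distance) {
        return Math.min(19, (int) (distance / 50.0));
    }

    // Relative bearing in radians (-PI, PI] -> 16 buckets
    public static int getTargetBearing(double bearing) {
        int bucket = (int) ((bearing + Math.PI) / (2 * Math.PI) * 16);
        return Math.max(0, Math.min(15, bucket));
    }
}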