这是一个我已经断断续续地研究了很长一段时间的项目。在此项目之前我从未尝试过修改游戏,也从未成功训练过“真正的”强化学习代理(智能体)。所以这个项目挑战是:解决钓鱼这个问题的“状态空间”是什么。当使用一些简单的 RL 框架进行编码时,框架本身可以为我们提供代理、环境和奖励,我们不必考虑问题的建模部分。但是在游戏中,必须考虑模型将读取每一帧的状态以及模型将提供给游戏的输入,然后相应地收集合适的奖励,此外还必须确保模型在游戏中具有正确的视角(它只能看到玩家看到的东西),否则它可能只是学会利用错误或者根本不收敛。
我的目标是编写一个能读取钓鱼小游戏状态并完美玩游戏的代理。目标的结果是使用官方 Stardew Valley 的 modding API 用 C# 编写一个自动钓鱼的mod。该模块加载了一个用 Python 训练的序列化 DQN 模型。所以首先要从游戏中收集数据,然后用这些数据用 Pytorch 训练一个简单的 DQN。经过一些迭代后,可以使用 ONNX 生成一个序列化模型,然后从 C# 端加载模型,并在每一帧中接收钓鱼小游戏的状态作为输入,并(希望)在每一帧上输出正确的动作。
钓鱼迷你游戏
这个代理是在SMAPI的帮助下编写的,SMAPI是Stardew Valley官方的mod API。API允许我在运行时访问游戏内存,并提供我所需要的一切去创造一个与游戏状态进行交互并实时向游戏提供输入的代理。
在钓鱼小游戏中,我们必须通过点击鼠标左键让“鱼钩”(一个绿色条)与移动的鱼对齐。鱼在这条竖线上无规律地移动,鱼钩条与鱼对齐时,绿色条就会填满一些,如果鱼成功逃离绿色条就会开始变空。当你填满绿色的条形图时,你会钓到鱼,当它绿条没有时鱼就跑了。
强化学习问题定义
所以这里只需要每帧从游戏内存中读取这些特定属性并将它们保存为在第 t 帧的状态。通过API我们可以查看并从游戏内存中读取特定属性的代码,对于自动钓鱼,需要在钓鱼小游戏期间跟踪的 4 个变量。 “钩子”中心的位置、鱼的位置、钩子的速度和绿色条的填充量(这是奖励!)。 游戏内部使用的名称有点奇怪,以下是读取它们的代码。
/ Update State
// hook position
bobberBarPos = Helper.Reflection.GetField(bar, "bobberBarPos").GetValue();
// fish position
bobberPosition = Helper.Reflection.GetField(bar, "bobberPosition").GetValue();
// hook speed
bobberBarSpeed = Helper.Reflection.GetField(bar, "bobberBarSpeed").GetValue();
// amount of green bar filled
distanceFromCatching = Helper.Reflection.GetField(bar, "distanceFromCatching").GetValue();
前三个定义了我们的状态:
这是模型可以在每一帧上可以获取的状态,要将其设置为强化学习问题还需要使用奖励来指导训练。 奖励将是绿色条的填充量,这是里的变量名称为 distanceFromCatching。 这个值的范围从 0 到 1,正好非常适合作为奖励。
Replay Memory
Replay Memory是 Q-learning 中使用的一种技术,用于将训练与特定的“时间”去关联。 所以需要将状态转换存储在缓存中并通过缓存中随机抽取批次来训练模型而不是直接使用最新数据进行训练。 为了训练模型,我们需要 4 个数据,分别是当前状态、下一个状态、采取的行动和奖励:
Q-learning 中关键问题是要获取曾经处于哪个状态和采取了哪些行动、到达哪个新的状态,以及执行这个行动中得到的奖励。有了这些数据,我们可以使用像价值迭代 (Value Iteration 一种动态规划算法)这样的简单算法将奖励从最终状态(获胜状态)开始分析,逐渐往回推直至推至所有状态。因此对于每个可能的状态,模型都会知道最大化其未来回报的方向。 但是我不会使用价值迭代来训练模型,因为真正的问题往往有太多的状态并且动态规划需要很长时间。
上面的价值迭代只是为了说明在 C# 中保存每个条目的方式。 这里使用缓存从最后一帧获取状态和动作,并将所有这些与当前帧的状态和奖励一起存储。
replayMemory[updateCounter,0] = OldState[0];
replayMemory[updateCounter,1] = OldState[1];
replayMemory[updateCounter,2] = OldState[2];
replayMemory[updateCounter,3] = NewState[0];
replayMemory[updateCounter,4] = NewState[1];
replayMemory[updateCounter,5] = NewState[2];
replayMemory[updateCounter,6] = reward;
replayMemory[updateCounter,7] = actionBuffer? 1 : 0;
所有这些数据都变成了一个巨大的 csv 文件,这样可以通过 Python 加载并用于训练 DQN 模型。
DQN 模型
使用神经网络估计 Q-table的 Q-Learning称为Deep Q-Learning。这个方法在很多个 Pytorch 教程中都有很好的解释,我从里面复制了很多代码并为我们的问题对其进行了一些修改。主要思想是使用两个神经网络。一个将估计 Q(s,a) 的值(Policy Net),另一个将估计未来 Q-values的值(Target Net)。然后我们对这两个网络的差异进行反向传播。
这是 Q-Learning算法的基本方程。我们将使用一个网络来估计当前状态 Q(s,a) 的正确值,另一个将估计下一个状态的最大可能值。两个网络都使用随机值进行初始化,并且每隔几次迭代将Policy Net权重复制到Target Net。Policy Net则通过反向传播更新权重 ,通过反向传播这种,Policy Net 最终将学会估计这两个值。
α 是学习率, 是用于选择为 Q 的未来值给出的重要值的折扣因子(discount factor)。强化学习是比较难易理解的所以最后会整理一堆链接,它们会做更好的细节解释。
训练
训练过程是“自我驱动的”,首先要自己玩游戏收集状态和奖励数据,然后训练一个初始化的效果很差的模型让它自动玩游戏,并为我们收集新的数据。然后使用这些数据在 Python 端训练新模型,生成一个新的 ONNX格式模型,该模型将每 1000 帧左右重新加载一次,然后使用新模型继续玩游戏并生成数据来训练新模型。 因为C̶# 必须编译 mod 并将其打包到与游戏可执行文件兼容的 Windows DLL 中,我没有找到一个可以生成正确的 .NET 机器学习框架二进制文件(Stardew Valley 是在 .NET 5 中编译的),所以我放弃了,这里直接用 Python 编写了这部分。
另外一个重要决定是该模型不需要在线训练。 Q-Learning就是要找到函数 Q(s,a) 的良好近似值,即估计在特定状态 s 下执行特定动作 a 的值的函数。所以模型的目的是数据彻底探索这个状态空间,无论是你(人肉)还是模型玩游戏都没有关系,当然如果能够全部自动化拿看起来肯定更加的高大上。
从 C # 中读取 ONNX 模型
C# 端唯一真正的 ML 代码是 ONNX 进行推理(预测),它定义了张量类型和会话的对象,可以发送张量输入并从序列化的 ONNX 模型获取张量输出。 下面的代码非常简单明了。 更新函数在每一帧都运行,并以当前状态作为输入查询训练模型的动作,最后几行只是用于获取模型输出的 argMax一些代码,这是与产生的动作对应的索引。序列化模型的重量只有 120kb 左右,所以运行起来非常轻巧。
public int Update(double[] currentState)
{
Tensor input = new DenseTensor(new[] {3});
input[0] = currentState[0];
input[1] = currentState[1];
input[2] = currentState[2];
// Setup inputs and outputs
var inputs = new List()
{
// the model has only one input, the state tuple
NamedOnnxValue.CreateFromTensor("0", input)
};
using (var results = session.Run(inputs))
{
Tensor outputs = results.First().AsTensor();
var maxValue = outputs.Max();
var maxIndex = outputs.ToList().IndexOf(maxValue);
return maxIndex;
}
}
使用 Harmony 进行输入
SMAPI 缺少的API是能够在游戏中提供输入,因为 99.999% 的mod不需要这样的东西。 为了进行输入我找到了一个名为 Harmony 的 C# 库在可以在运行时更改游戏的内部函数,这样我就可以让游戏以为它收到了鼠标输入。 这就是上面让mode自己玩游戏的方法。 非常感谢 Drynwynn,Mod FishingAutomaton 的作者,我使用了很多代码来设置我的 mod。
[HarmonyPatch(typeof(Game1), "isOneOfTheseKeysDown")]
class IsButtonDownHack
{
// ...
// some important stuff
// ...
// change function return value to true
// makes the game think a mouse left button click ocurred
__result = true;
return;
}
最终结果
目前,该模型可以捕获所有“简单”和“中级”的鱼。 还不能训练它捕捉传说中的鱼。
这个gif是未训练完成的演示
下面是我们训练的结果,效果还不错
资源和引用
非常感谢 Stardew Valley 的mod社区帮助并让我更好地理解游戏:)
C# mod 和 Python 训练的所有代码都可以在这里找到!
https://github.com/ThiagoLira...
下面是一些其他的DQN的相关资源,供参考:
https://www.overfit.cn/post/8f82fe918b644ce58a8e525243db87a4
作者:Thiago Lira