As everyone knows, the Reinforcement Learning Toolbox for MATLAB is very powerful. I took plenty of detours when I first started with it, including tracing the called functions layer by layer. Anyone who has looked at the underlying source knows it is built as an integration of classes; with a solid grounding in object-oriented programming you can probably work out what is going on, but digging all the way down nearly finished me off, and there is really no need for it. Here is how to get up to speed quickly writing reinforcement learning code, using DQN to train Cart-Pole as the example:
One of the core components of reinforcement learning is the environment. The Reinforcement Learning Toolbox supports three ways of writing one: a predefined environment (rlPredefinedEnv), a custom environment assembled from step and reset functions (rlFunctionEnv), and a custom environment written as a class derived from the toolbox's environment template.
The other part, the agent and the training loop, only requires calling the toolbox; follow the reference example (link) and write along with it. The line env = rlPredefinedEnv("CartPole-Discrete") in that example can be replaced with the environment you wrote yourself in the previous part. Then start training.
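Whichever route you take, it is worth smoke-testing the environment object before training. A minimal sketch; validateEnvironment is a toolbox function that runs reset plus a short step sequence and errors out on interface mismatches:

env = rlPredefinedEnv("CartPole-Discrete");  % built-in route, as in the linked example
validateEnvironment(env)                     % throws if step/reset are inconsistent with the specs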
A reminder: MATLAB versions keep updating. If you are on the 2019 release, quite a few functions may well be unavailable, but anything that works in R2019 will definitely work in R2020.
clc;
clear;
close all;
%% Observation space: 4-dimensional continuous state
ObservationInfo = rlNumericSpec([4 1]);
ObservationInfo.Name = 'CartPole States';
ObservationInfo.Description = 'x, dx, theta, dtheta';
%% Action space: two discrete forces, -10 and +10
ActionInfo = rlFiniteSetSpec([-10 10]);
ActionInfo.Name = 'CartPole Action';
%% Create the environment from custom step/reset functions
env = rlFunctionEnv(ObservationInfo,ActionInfo,'myStepFunction','myResetFunction');
%% Build the DQN critic network: Q(s,a) with separate state and action paths
statePath = [
    imageInputLayer([4 1 1],'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24,'Name','CriticStateFC2')];
actionPath = [
    imageInputLayer([1 1 1],'Normalization','none','Name','action')
    fullyConnectedLayer(24,'Name','CriticActionFC1')];
commonPath = [
    additionLayer(2,'Name','add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','output')];
% Assemble the graph and merge the two input paths at the addition layer
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork, actionPath);
criticNetwork = addLayers(criticNetwork, commonPath);
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
% Visualize the network to check the connections
figure
plot(criticNetwork)
%% Wrap the network into a critic representation
% Note: rlRepresentation is the R2019 API; newer releases replace it with
% rlQValueRepresentation, so check the documentation for your version.
criticOpts = rlRepresentationOptions('LearnRate',0.01,'GradientThreshold',1);
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
critic = rlRepresentation(criticNetwork,obsInfo,actInfo,...
    'Observation',{'state'},'Action',{'action'},criticOpts);
agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...
    'TargetUpdateMethod',"periodic", ...
    'TargetUpdateFrequency',4, ...
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);
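% (Optional addition, not in the original post.) DQN exploration is
% epsilon-greedy; rlDQNAgentOptions exposes the schedule through the
% EpsilonGreedyExploration property if you want to tune it, e.g.:
% agentOpts.EpsilonGreedyExploration.EpsilonDecay = 0.005;
% agentOpts.EpsilonGreedyExploration.EpsilonMin = 0.01;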
agent = rlDQNAgent(critic,agentOpts);
%% Training options: stop once the average reward reaches 480
trainOpts = rlTrainingOptions(...
    'MaxEpisodes', 1000, ...
    'MaxStepsPerEpisode', 500, ...
    'Verbose', false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',480);
doTraining = true;
if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load pretrained agent for the example.
    load('MATLABCartpoleDQN.mat','agent');
end
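After training finishes (or the pretrained agent is loaded), you can roll the policy out in the environment to see how it behaves. A minimal sketch using the toolbox's sim function; the 500-step cap mirrors the training option above:

simOpts = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOpts);          % one episode with the trained policy
totalReward = sum(experience.Reward.Data)     % cart-pole gives +1 per surviving step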
The myStepFunction and myResetFunction files are covered in link 1; a skeleton of the interface they must satisfy is sketched below. Once you have digested the contents of the two links above, you are not far from writing your own reinforcement learning program.
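For reference, rlFunctionEnv expects the step and reset functions to have the signatures below. The bodies here are only placeholders; the actual cart-pole dynamics are in the linked example:

function [NextObs,Reward,IsDone,LoggedSignals] = myStepFunction(Action,LoggedSignals)
% Advance the environment one step under Action; state is carried in LoggedSignals.
NextObs = LoggedSignals.State;          % placeholder: compute the next state here
Reward = 1;                             % placeholder: +1 for every surviving step
IsDone = false;                         % placeholder: true once the pole falls over
LoggedSignals.State = NextObs;
end

function [InitialObservation,LoggedSignals] = myResetFunction()
% Return the initial observation and initialize the logged state.
LoggedSignals.State = [0; 0; 0; 0];     % x, dx, theta, dtheta at the start of an episode
InitialObservation = LoggedSignals.State;
end

A follow-up post will use deep reinforcement learning to solve a maze problem.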