强化学习matlab工具箱应用

1. 如何使用强化学习强大的工具箱编写自己的工程

众所周知reinforcement learning Toolbax for matlab是非常强大的,小编刚开始使用时走了很多弯路,有试过一层一层的去找调用的函数等等,看过底层的同学就知道用类做的集成,如果你的面向对象基础知识很牢固大概能看懂这其中的奥秘。小编研究下去的结果就是快吐了,其实没有必要这样。接下来想说下如何快速上手编写强化学习的代码。小编以用DQN训练Cart-Pole为例:

2.了解和编程步骤

2.1环境的编写

强化学习的主体之一是环境,MATLAB的reinforcement learning toolbax中环境有三种编写方式:

  1. 预定义的模型(是MATLAB内部集成好的模型,可以拿来学习下强化学习)
  2. M函数编写
  3. simlink仿真模型
    其中后两种主要是为用户自定义的环境服务的,小编主要分享第二种,请详细阅读link

2.2 算法部分

这个部分只需要调用工具箱即可,参考案例为link按照此链接的内容进行编写,此时文中的env = rlPredefinedEnv("CartPole-Discrete")可以用上部分自己编写的环境代替。然后开始训练即可。
提醒大家一下现在matlab版本更新,如果你用的2019版本的很有可能好多函数不可用,但是19版本的函数2020一定可以用。

代码

clc;
clear;
close all;
ObservationInfo = rlNumericSpec([4 1]);
ObservationInfo.Name = 'CartPole States';
ObservationInfo.Description = 'x, dx, theta, dtheta';
%%%%动作状态空间
ActionInfo = rlFiniteSetSpec([-10 10]);
ActionInfo.Name = 'CartPole Action';
%%创建环境
env = rlFunctionEnv(ObservationInfo,ActionInfo,'myStepFunction','myResetFunction');
%%%%创建DQN
statePath = [
    imageInputLayer([4 1 1],'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24,'Name','CriticStateFC2')];
actionPath = [
    imageInputLayer([1 1 1],'Normalization','none','Name','action')
    fullyConnectedLayer(24,'Name','CriticActionFC1')];
commonPath = [
    additionLayer(2,'Name','add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','output')];
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork, actionPath);
criticNetwork = addLayers(criticNetwork, commonPath);    
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
figure
plot(criticNetwork)
%%%
criticOpts = rlRepresentationOptions('LearnRate',0.01,'GradientThreshold',1);
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
critic = rlRepresentation(criticNetwork,obsInfo,actInfo,'Observation',{'state'},'Action',{'action'},criticOpts);
agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...    
    'TargetUpdateMethod',"periodic", ...
    'TargetUpdateFrequency',4, ...   
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);
agent = rlDQNAgent(critic,agentOpts);
trainOpts = rlTrainingOptions(...
    'MaxEpisodes', 1000, ...
    'MaxStepsPerEpisode', 500, ...
    'Verbose', false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',480); 
doTraining = true;
if doTraining    
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load pretrained agent for the example.
    load('MATLABCartpoleDQN.mat','agent');
end

其中myStepFunction、myResetFunction函数参考链接一。弄懂上文中的两个链接的内容离编写一个自己的强化学习程序不远了。用深度强化学习解决迷宫问题后续会发布。

你可能感兴趣的:(强化学习)