DQN-FlappyBird学习之main.py解析之argparse函数 shell传参

今天根据周博磊老师的视频推荐,download了DNQ的代码,这篇博文主要是学习python在shell脚本中定义and怎么写出优美的代码(传参)

话不多说,直接上代码

DQN-FlappyBird学习之main.py解析之argparse函数 shell传参

    • import
    • argparse
    • if __name__ == '__main__':
    • train.sh

import

import sys
import argparse
from misc import *
from BrainDQN import *
import torch.cuda

第一个是系统相关的包,第二个就是今天的主角
第三、四个是自己编写的代码文件,import*,一个好处是可以不用看在文中引用了哪些,然后直接写函数就可以
第五个用于检测pytorch下是否可用GPU

argparse

parser = argparse.ArgumentParser(description='DQN demo for flappy bird')
parser.add_argument('--train', action='store_true', default=False,
        help='If set true, train the model; otherwise, play game with pretrained model')
parser.add_argument('--cuda', action='store_true', default=False,
        help='If set true, with cuda enabled; otherwise, with CPU only')
parser.add_argument('--lr', type=float, help='learning rate', default=0.0001)
parser.add_argument('--gamma', type=float,
        help='discount rate', default=0.99)
parser.add_argument('--batch_size', type=int,
        help='batch size', default=32)
parser.add_argument('--memory_size', type=int,
        help='memory size for experience replay', default=5000)
parser.add_argument('--init_e', type=float,
        help='initial epsilon for epsilon-greedy exploration',
        default=1.0)
parser.add_argument('--final_e', type=float,
        help='final epsilon for epsilon-greedy exploration',
        default=0.1)
parser.add_argument('--observation', type=int,
        help='random observation number in the beginning before training',
        default=100)
parser.add_argument('--exploration', type=int,
        help='number of exploration using epsilon-greedy policy',
        default=10000)
parser.add_argument('--max_episode', type=int,
        help='maximum episode of training',
        default=20000)
parser.add_argument('--weight', type=str,
        help='weight file name for finetunig(Optional)', default='')
parser.add_argument('--save_checkpoint_freq', type=int,
        help='episode interval to save checkpoint', default=2000)

这部分就是定义了参数的形式
1、创建一个解析器对象

parser = argparse.ArgumentParser(description='DQN demo for flappy bird')

2、添加参数
这里定义了两种方式
(1)、

parser.add_argument('--train', action='store_true', default=False,
        help='If set true, train the model; otherwise, play game with pretrained model')

这个代码的意义是,定义参数’–train’,action = 'store_true’表示如果在脚本中有这个参数,–train就为真(True),默认为False。
(2)、

parser.add_argument('--lr', type=float, help='learning rate', default=0.0001)

这个的意思是,定义参数’–lr’,默认值为0.0001。如果在命令中有给定值,’–lr’的值就是给定的值。

if name == ‘main’:

if __name__ == '__main__':
    args = parser.parse_args()
    if args.cuda and not torch.cuda.is_available():
        print 'CUDA is not availale, maybe you should not set --cuda'
        sys.exit(1)
    if not args.train and args.weight == '':
        print 'When test, a pretrained weight model file should be given'
        sys.exit(1)
    if args.cuda:
        print 'With GPU support!'
    if args.train:
        model = BrainDQN(epsilon=args.init_e, mem_size=args.memory_size, cuda=args.cuda)
        resume = not args.weight == ''
        train_dqn(model, args, resume)
    else:
        play_game(args.weight, args.cuda, True)

首先是相当于激活,或者赋值

    args = parser.parse_args()

然后是对于True和False的判断

    if args.cuda and not torch.cuda.is_available():
        print 'CUDA is not availale, maybe you should not set --cuda'
        sys.exit(1)
    if not args.train and args.weight == '':
        print 'When test, a pretrained weight model file should be given'
        sys.exit(1)
    if args.cuda:
        print 'With GPU support!'

这代码的意思是
(1)、如果命令需要使用cuda并且pytorch的cuda不可用,就会输出提示,然后执行sys.exit(1)
sys.exit(1)是有通用错误并退出
(2)、如果命令需要执行训练的代码并且没有给定训练好的weight,同样的输出提示,并且因为有错发生退出程序
(3)、如果命令需要使用cuda,并且pytorch的cuda可用,就会输出成功的提示。
(4)、注意,在使用变量的时候直接调用args.变量名,这里没有’–’

最后是在满足以上的条件时,利用,命令所给定的参数进行训练

    if args.train:
        model = BrainDQN(epsilon=args.init_e, mem_size=args.memory_size, cuda=args.cuda)
        resume = not args.weight == ''
        train_dqn(model, args, resume)
    else:
        play_game(args.weight, args.cuda, True)

其中 resume = not args.weight == ''的意思是如果weight有值,resume为真,否则为为假。
当不训练的时候,即当使用预训练模型的时候,就会执行play_game程序。

train.sh

lr=0.0001
gamma=0.99
batch_size=32
mem_size=5000
initial_epsilon=1.
final_epsilon=0.1
observation=100
exploration=50000
max_episode=100000
# for fine tuning, uncomment this
#weight=model_best.pth.tar

python main.py --train\
               --cuda\
               --lr=$lr\
               --gamma=$gamma\
               --batch_size=$batch_size\
               --memory_size=$mem_size\
               --init_e=$initial_epsilon\
               --final_e=$final_epsilon\
               --observation=$observation\
               --exploration=$exploration\
               --max_episode=$max_episode
               #--weight=$weight   # for fine tuning, uncomment this

通过这个脚本实现了参数的传递,通过shell脚本调用main文件,然后在main文件中选择执行相应的训练,运行以及错误退出操作。

你可能感兴趣的:(Reinforcement,Learning,python,Pytorch,python,DNQ,强化学习,reinforcement,learning,机器学习)