今天根据周博磊老师的视频推荐,download了DNQ的代码,这篇博文主要是学习python在shell脚本中定义and怎么写出优美的代码(传参)
话不多说,直接上代码
import sys
import argparse
from misc import *
from BrainDQN import *
import torch.cuda
第一个是系统相关的包,第二个就是今天的主角
第三、四个是自己编写的代码文件,import*,一个好处是可以不用看在文中引用了哪些,然后直接写函数就可以
第五个用于检测pytorch下是否可用GPU
parser = argparse.ArgumentParser(description='DQN demo for flappy bird')
parser.add_argument('--train', action='store_true', default=False,
help='If set true, train the model; otherwise, play game with pretrained model')
parser.add_argument('--cuda', action='store_true', default=False,
help='If set true, with cuda enabled; otherwise, with CPU only')
parser.add_argument('--lr', type=float, help='learning rate', default=0.0001)
parser.add_argument('--gamma', type=float,
help='discount rate', default=0.99)
parser.add_argument('--batch_size', type=int,
help='batch size', default=32)
parser.add_argument('--memory_size', type=int,
help='memory size for experience replay', default=5000)
parser.add_argument('--init_e', type=float,
help='initial epsilon for epsilon-greedy exploration',
default=1.0)
parser.add_argument('--final_e', type=float,
help='final epsilon for epsilon-greedy exploration',
default=0.1)
parser.add_argument('--observation', type=int,
help='random observation number in the beginning before training',
default=100)
parser.add_argument('--exploration', type=int,
help='number of exploration using epsilon-greedy policy',
default=10000)
parser.add_argument('--max_episode', type=int,
help='maximum episode of training',
default=20000)
parser.add_argument('--weight', type=str,
help='weight file name for finetunig(Optional)', default='')
parser.add_argument('--save_checkpoint_freq', type=int,
help='episode interval to save checkpoint', default=2000)
这部分就是定义了参数的形式
1、创建一个解析器对象
parser = argparse.ArgumentParser(description='DQN demo for flappy bird')
2、添加参数
这里定义了两种方式
(1)、
parser.add_argument('--train', action='store_true', default=False,
help='If set true, train the model; otherwise, play game with pretrained model')
这个代码的意义是,定义参数’–train’,action = 'store_true’表示如果在脚本中有这个参数,–train就为真(True),默认为False。
(2)、
parser.add_argument('--lr', type=float, help='learning rate', default=0.0001)
这个的意思是,定义参数’–lr’,默认值为0.0001。如果在命令中有给定值,’–lr’的值就是给定的值。
if __name__ == '__main__':
args = parser.parse_args()
if args.cuda and not torch.cuda.is_available():
print 'CUDA is not availale, maybe you should not set --cuda'
sys.exit(1)
if not args.train and args.weight == '':
print 'When test, a pretrained weight model file should be given'
sys.exit(1)
if args.cuda:
print 'With GPU support!'
if args.train:
model = BrainDQN(epsilon=args.init_e, mem_size=args.memory_size, cuda=args.cuda)
resume = not args.weight == ''
train_dqn(model, args, resume)
else:
play_game(args.weight, args.cuda, True)
首先是相当于激活,或者赋值
args = parser.parse_args()
然后是对于True和False的判断
if args.cuda and not torch.cuda.is_available():
print 'CUDA is not availale, maybe you should not set --cuda'
sys.exit(1)
if not args.train and args.weight == '':
print 'When test, a pretrained weight model file should be given'
sys.exit(1)
if args.cuda:
print 'With GPU support!'
这代码的意思是
(1)、如果命令需要使用cuda并且pytorch的cuda不可用,就会输出提示,然后执行sys.exit(1)。
sys.exit(1)是有通用错误并退出
(2)、如果命令需要执行训练的代码并且没有给定训练好的weight,同样的输出提示,并且因为有错发生退出程序
(3)、如果命令需要使用cuda,并且pytorch的cuda可用,就会输出成功的提示。
(4)、注意,在使用变量的时候直接调用args.变量名,这里没有’–’
最后是在满足以上的条件时,利用,命令所给定的参数进行训练
if args.train:
model = BrainDQN(epsilon=args.init_e, mem_size=args.memory_size, cuda=args.cuda)
resume = not args.weight == ''
train_dqn(model, args, resume)
else:
play_game(args.weight, args.cuda, True)
其中 resume = not args.weight == ''的意思是如果weight有值,resume为真,否则为为假。
当不训练的时候,即当使用预训练模型的时候,就会执行play_game程序。
lr=0.0001
gamma=0.99
batch_size=32
mem_size=5000
initial_epsilon=1.
final_epsilon=0.1
observation=100
exploration=50000
max_episode=100000
# for fine tuning, uncomment this
#weight=model_best.pth.tar
python main.py --train\
--cuda\
--lr=$lr\
--gamma=$gamma\
--batch_size=$batch_size\
--memory_size=$mem_size\
--init_e=$initial_epsilon\
--final_e=$final_epsilon\
--observation=$observation\
--exploration=$exploration\
--max_episode=$max_episode
#--weight=$weight # for fine tuning, uncomment this
通过这个脚本实现了参数的传递,通过shell脚本调用main文件,然后在main文件中选择执行相应的训练,运行以及错误退出操作。