有的服务器是多用户使用,GPU的资源常常被占据着,很可能在夜间GPU空闲了,但来不及运行自己的脚本。如果没有和别人共享服务器的话,自己的多个程序想排队使用GPU,也可以用这个脚本。
1. 脚本保证每一轮train和metric的连贯性:执行metric时若被Ctrl+C中断(或以非零状态退出),脚本会重新等待空闲GPU并重跑metric,直到成功为止。
2. 当用户指定metric优先时,将追加一轮train和metric。
3. 中断train时会把当前的ckpt文件保存到./ckpt/bk/half_result文件夹中;若该文件夹不存在,脚本不会自动创建,请自行创建或修改脚本。
4. 按当前设定,每执行五轮后会询问是否继续。
GPU: 4 × NVIDIA 2080Ti (11GB)
OS: CentOS 7
python: 3.6.8
import os
import sys
import random
import time
alert = 1000  # a GPU whose used memory (as parsed from nvidia-smi) is below this value may be claimed by this script
def modify_train_py(gpu_id):
    """Point ``train.py`` at the given GPU and report it on stdout.

    The actual file rewrite is delegated to ``modify_file``.

    Args:
        gpu_id: Integer index of the GPU that ``train.py`` should use.
    """
    target_script = 'train.py'
    modify_file(target_script, gpu_id)
    print('modify "train.py" seccessfully')
def modify_metric_py(gpu_id):
    """Point ``metric.py`` at the given GPU and report it on stdout.

    The actual file rewrite is delegated to ``modify_file``.

    Args:
        gpu_id: Integer index of the GPU that ``metric.py`` should use.
    """
    modify_file('metric.py', gpu_id)
    # The message now names the file that was actually modified; it used to
    # say "train.py" (copy-paste slip from modify_train_py).
    print('modify "metric.py" seccessfully')
def modify_file(file_name, gpu_id):
    """Rewrite the second line of ``file_name`` to ``common.set_gpu(<gpu_id>)``.

    The file is read completely, its line at index 1 is replaced, and the
    whole content is written back. All other lines are left untouched.

    Args:
        file_name: Path of the script to patch.
        gpu_id: Integer GPU index, formatted with ``%d``.

    Raises:
        IndexError: If the file has fewer than two lines. The file is not
            reopened for writing in that case, so its content is preserved.
        OSError: If the file cannot be opened for reading or writing.
    """
    # Context managers guarantee the handles are closed even when reading,
    # patching, or writing raises.
    with open(file_name, 'r') as src:
        textlines = src.readlines()
    textlines[1] = 'common.set_gpu(%d)\n' % gpu_id
    with open(file_name, 'w') as dst:
        dst.writelines(textlines)
def retry_train_run():
    """Wait for a free GPU, retarget ``train.py`` to it, and run it again.

    Returns:
        The status value returned by ``os.system('python3 train.py')``;
        ``0`` means the training script exited successfully.
    """
    # The message now names train.py; it used to say "metric.py"
    # (copy-paste slip from retry_metric_run).
    print('retry run train.py...')
    gpu_id = gpu_monitor()
    modify_train_py(gpu_id)
    return os.system('python3 train.py')
def retry_metric_run():
    """Wait for a free GPU, retarget ``metric.py`` to it, and run it again.

    Returns:
        The status value returned by ``os.system('python3 metric.py')``;
        ``0`` means the metric script exited successfully.
    """
    print('retry run metric.py...')
    free_gpu = gpu_monitor()
    modify_metric_py(free_gpu)
    exit_status = os.system('python3 metric.py')
    return exit_status
def run(train_or_metric='train'):
    """Run one round of train followed by metric, optionally metric first.

    When ``train_or_metric`` is ``'metric'``, ``metric.py`` is run (and
    retried until it exits with status 0) before the regular round.

    The regular round backs up ``./ckpt/c*`` into ``./ckpt/bk/`` and runs
    ``train.py``. On success ``metric.py`` is run until it succeeds. On
    failure the checkpoint is copied to ``./ckpt/bk/half_result/``, the
    backup is restored, and training is retried up to five times; the first
    successful retry is followed by metric and ends the retries.

    NOTE(review): the nesting was reconstructed from a listing that had lost
    its indentation - confirm against the author's original layout.

    Args:
        train_or_metric: ``'metric'`` to run metric first; any other value
            goes straight to the train/metric round.
    """
    def metric_until_success():
        # Keep re-running metric.py (on whichever GPU frees up) until it
        # exits with status 0.
        status = os.system('python3 metric.py')
        while status != 0:
            status = retry_metric_run()

    if train_or_metric == 'metric':
        metric_until_success()

    train_status = os.system('cp -f ./ckpt/c* ./ckpt/bk/ && python3 train.py')
    if train_status == 0:
        metric_until_success()
        return

    print('recovering the ckpt file...')
    # Save the partial checkpoint, then roll ./ckpt/ back to the backup.
    os.system('cp -f ./ckpt/c* ./ckpt/bk/half_result/ && python3 train.py')
    os.system('cp -f ./ckpt/bk/c* ./ckpt/')
    for _ in range(5):
        if retry_train_run() != 0:
            continue
        metric_until_success()
        break
def gpu_monitor():
    """Block until a GPU's used memory is below ``alert`` and return its index.

    Each polling round sleeps a random 10-60 seconds, then reads the lines of
    ``nvidia-smi`` output that contain ``%``. The used-memory figure is taken
    from the third ``|``-separated field of each line (the number in front of
    the first ``M``). The index of the first line whose figure is below the
    module-level ``alert`` threshold is returned.

    Every line returned by the command is examined, so the function works
    with any number of GPUs rather than exactly four.

    Returns:
        Zero-based position of the matching line in the filtered
        ``nvidia-smi`` output, used by callers as the GPU id.
    """
    print('monitoring the gpu infomation...')
    while True:
        # Randomised delay between polls; presumably keeps several waiting
        # instances from probing in lockstep - TODO confirm intent.
        time.sleep(float(random.randint(10, 60)))
        # Close the pipe deterministically instead of leaking it per poll.
        with os.popen('nvidia-smi | grep %') as pipe:
            gpu_rows = pipe.readlines()
        if not gpu_rows:
            print('no GPU rows found in nvidia-smi output, retrying...')
            continue
        for gpu_index, row in enumerate(gpu_rows):
            try:
                memory_used = int(row.split('|')[2].split('M')[0].strip())
            except (IndexError, ValueError):
                # Not a parsable GPU status row; keep the index position and
                # move on to the next row.
                continue
            if memory_used < alert:
                print('Find a GPU! (GPU id: %d)' % gpu_index)
                return gpu_index
if __name__ == '__main__':
    # Script entry point: run train/metric rounds forever, each on whichever
    # GPU gpu_monitor() reports as free, asking the user whether to continue
    # every fifth round.

    # An optional first command-line argument is forwarded to run(); passing
    # 'metric' makes the first round start with metric.py.
    first_job = sys.argv[1] if len(sys.argv) != 1 else None
    if first_job is not None:
        print('\n**********user specify: %s first!**********\n' % first_job)

    gpu_id = gpu_monitor()
    modify_train_py(gpu_id)
    modify_metric_py(gpu_id)
    if first_job is not None:
        run(first_job)
    else:
        run()

    epoch = 1  # number of rounds completed so far
    while True:
        if epoch % 5 == 0:
            flag_continue = input('Continue? (y/n)')
            if flag_continue in ('n', 'N'):
                # sys.exit is the programmatic way to stop; the bare exit()
                # helper is only installed by the site module for
                # interactive use.
                sys.exit(0)
            elif flag_continue not in ('y', 'Y'):
                print("illegal value: %s" % flag_continue)
                # epoch is unchanged, so the question is asked again.
                continue
        gpu_id = gpu_monitor()
        modify_train_py(gpu_id)
        modify_metric_py(gpu_id)
        run()
        epoch += 1
Output:
********** user specify: metric first! **********
monitoring the gpu infomation...