帮助深度学习模型抢占显卡的python3简单脚本

帮助深度学习模型抢占显卡的python3简单脚本

  • 简介
  • 环境
  • 代码

简介

有的服务器是多用户使用,GPU的资源常常被占据着,很可能在夜间GPU空闲了,但来不及运行自己的脚本。如果没有和别人共享服务器的话,自己的多个程序想排队使用GPU,也可以用这个脚本。

1. 脚本设定在执行metric时不可被ctrl+c中断,由此确保每一轮train和metric的连贯性。
2. 当用户指定metric优先时,将追加一轮train和metric。
3. 中断train时会保存当前的ckpt文件到half_result文件夹中,若文件夹不存在,也不会创建,请自行修改。
4. 设定上执行五轮后会问询是否继续。

环境

GPU: 4 ✖ Nvidia 2080Ti (11GB)
OS: CentOS 7
python: 3.6.8

代码

import os
import sys
import random
import time
alert = 1000	#一旦显存低于alert则可被程序占用


def modify_train_py(gpu_id):
    modify_file('train.py', gpu_id)
    print('modify "train.py" seccessfully')


def modify_metric_py(gpu_id):
    modify_file('metric.py', gpu_id)
    print('modify "train.py" seccessfully')


def modify_file(file_name, gpu_id):
    file = open(file_name, 'r')
    textlines = file.readlines()
    file.close()
    textlines[1] = str('common.set_gpu(%d)\n' % gpu_id)
    file = open(file_name, 'w')
    file.writelines(textlines)
    file.close()


def retry_train_run():
    print('retry run metric.py...')
    gpu_id = gpu_monitor()
    modify_train_py(gpu_id)
    flag = os.system('python3 train.py')
    return flag


def retry_metric_run():
    print('retry run metric.py...')
    gpu_id = gpu_monitor()
    modify_metric_py(gpu_id)
    flag = os.system('python3 metric.py')
    return flag


def run(train_or_metric='train'):
    if(train_or_metric == 'metric'):
        flag = os.system('python3 metric.py')
        while (flag != 0):
            flag = retry_metric_run()
    flag = os.system('cp -f ./ckpt/c* ./ckpt/bk/ && python3 train.py')
    if (flag == 0):
        flag = os.system('python3 metric.py')
        while (flag != 0):
            flag = retry_metric_run()
    else:
        print('recovering the ckpt file...')
        os.system('cp -f ./ckpt/c* ./ckpt/bk/half_result/ && python3 train.py')
        os.system('cp -f ./ckpt/bk/c* ./ckpt/')
        for i in range(5):
            flag = retry_train_run()
            if (flag == 0):
                flag = os.system('python3 metric.py')
                while (flag != 0):
                    flag = retry_metric_run()
                break    
            

def gpu_monitor():
    run_flag = False
    print('monitoring the gpu infomation...')
    while(not run_flag):
        sleep_time = random.randint(10,60)
        time.sleep(float(sleep_time))
        nvidia_smi = os.popen('nvidia-smi | grep %').readlines()
        for i in range(4):
            memory_used = int(nvidia_smi[i].split('|')[2].split('M')[0].strip())
            if memory_used < alert:
                run_flag = True
                print('Find a GPU! (GPU id: %d)' % i)
                return i
                

if __name__ == '__main__':
    if len(sys.argv) != 1:
        print('\n**********user specify: %s first!**********\n' % sys.argv[1])
    gpu_id = gpu_monitor()
    modify_train_py(gpu_id)
    modify_metric_py(gpu_id)
    if len(sys.argv) != 1:
        run(sys.argv[1])
    else:
        run()
    epoch = 1
    while(True):
        if epoch%5 == 0:
            flag_continue = input('Continue? (y/n)')
            if (flag_continue == 'n' or flag_continue == 'N'):
                exit(0)
            elif (flag_continue != 'y' and flag_continue != 'Y'):
                print("illegal value: %s" % flag_continue)
                continue
        gpu_id = gpu_monitor()
        modify_train_py(gpu_id)
        modify_metric_py(gpu_id)
        run()
        epoch += 1

Output:

********** user specify: metric first! ********** 
monitoring the gpu infomation... 

你可能感兴趣的:(python)