【PyTorch】argparse + os.environ 设置pytorch网络使用的显卡

指定使用的显卡编号

os.environ("CUDA_VISIBLE_DEVICES")='2,3,4'

设置环境变量CUDA_VISIBLE_DEVICES为’2,3,4’,这个时候对于系统来说只有编号2,3,4的显卡是可见的(从0开始)
通过torch.cuda.device_count()获取显卡数量的时候显示的是3,即只能看见这三张显卡

在使用pytorch时,如果需要在gpu上对某些数据进行操作,一般的流程是:

# 获取设备
device = torch.device('cuda:0') # 使用可见gpu中的第0个
# device = torch.device('cpu') 如果使用cpu设备

# 将数据送到设备中
data.to(device)

在上面获取设备的时候,第0个设备获取的其实是编号为2的显卡,同理 torch.device('cuda:2')获取的就是编号为4的显卡。

指定显卡对数据进行操作的方法大概就是这样,也有一些其他的方法,比如

  • 在命令行前加:CUDA_VISIBLE_DEVICES=2 python ...

这里不做赘述

指定显卡不生效

通过 os.environ 指定显卡的时候,经常遇见指定了3号卡训练,但最后还是在0号卡训练的,这个问题最可能的原因是指定显卡和使用torch的顺序问题。

可能的一种错误方式为:

import torch
print(torch,cuda.device_count())

import os
os.environ["CUDA_VISIBLE_DEVICES"]='2'

device = torch.device('cuda:0')
data = torch.tensor([1,2,3])
data.to(device)

在第一次使用torch的时候,系统就会去获取当前可用的设备了,这个时候还没有指定可见的显卡,代码就会加载所有可用的显卡,程序会输出设备数为8,且将data放在0号卡

如果在第一次使用torch前指定显卡,就可以达到想要的效果

import torch

import os
os.environ["CUDA_VISIBLE_DEVICES"]='2'

print(torch,cuda.device_count())

device = torch.device('cuda:0')
data = torch.tensor([1,2,3])
data.to(device)

这个时候程序输出设备数为1,即2号显卡,且data也会放在2号显卡上

所以如果你通过environ指定显卡无法生效,可以尝试看一下自己的代码,是不是在指定显卡之前已经使用了torch的某些函数或功能。

使用argparse在命令行指定显卡

正如上面提到的,可以直接在python命令之前加 CUDA_VISIBLE_DEVICES=1,2 来指定显卡
当然可以使用获取参数时最经常使用的argparse来获取指定的设备号

实现也很简单

import argparse
import os
import torch

parser = argparse.ArgumentParser(description='...')
parser.add_argument('-d', '--device', type=str)
args = parser.parse_args()
if args.device:
	os.environ["CUDA_VISIBLE_DEVICES"] = args.device
# code about torch: 将torch相关的代码写在指定显卡的代码后面

只需要保证torch的使用在指定显卡之后即可,如果你为了保险,将 import torch 也放在指定显卡的代码后面也是可以的。

在使用的时候直接在命令行指定设备即可:

python train.py --device 2

你可能感兴趣的:(深度学习,PyTorch,pytorch,深度学习,python)