Pytorch | GPU | 将代码加载到GPU上运行

Pytorch | GPU | 将代码加载到GPU上运行


“I think if you do something and it turns out pretty good, then you should go do something else wonderful, not dwell on it for too long. Just figure out what’s next.”
——Steve Jobs


❤️流程

  • 声明用GPU(指定具体的卡)
  • 将模型(model)加载到GPU上
  • 把数据和标签放到GPU上

使用GPU的过程中离不开torch.device()

什么是torch.device()?

A torch.device is an object representing the device on which a torch.Tensor is or will be allocated. 就是装torch.Tensor的一个地方。

声明用GPU(指定具体的卡)

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

'cuda' 这里如果没有指定具体的卡号,那么系统默认cuda:0
或者:
如果使用pycharm,程序的运行路径填的就是服务器上的路径,默认torch.cuda.is_available() = Ture,那么就可以省略后面的判断语句。下面的例子是使用了2号卡(从0开始计数)。

device = torch.device('cuda:2')

将模型(model)加载到GPU上

model = resnet19()	#例子中,采用resnet模型
model.to(device)

把数据和标签放到GPU上

data, target = data.to(device), target.to(device)

❤️在多卡上并行计算

  • 方法1:torch.nn.DataParallel()
    torch.nn.DataParallel()具体的过程:大体就是将模型加载的每个卡上,数据平均分到每个卡上,原则上保证batch_size大于卡的数目就行。

      device = torch.device('cuda:2') #device = torch.device("cuda:1" if use_cuda else "cpu")  
      model = resnet19()  
      if torch.cuda.device_count() > 1: #10 
          print(torch.cuda.device_count())  
          model = nn.DataParallel(model, device_ids = [2,3,4])
      model.to(device)   
    

    这段代码运行之后占用的GPU是:0,2,3,4。为什么会占用0??我感到很神奇!!!原来:The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module. 就是说,即使我指定的卡没有0卡,他也会在0卡里面放参数和缓存。
    如何避免这种现象呢?
    改变默认的device_ids[0]指向的卡。默认device_ids[0]指向的就是0卡,只需要通过环境变量,让device_ids[0]指向其他卡片即可。

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3, 4"	#仅有2,3,4(物理上)卡对程序可见,对应的逻辑上的卡号就算0,1,2
    	    ......
    device = torch.device('cuda') #device = torch.device("cuda:1" if use_cuda else "cpu")  
    model = resnet19()  
    if torch.cuda.device_count() > 1: #10 
        print(torch.cuda.device_count())  
        model = nn.DataParallel(model, device_ids = [0,1,2])
    model.to(device)   
    
  • distributedDataparallel
    没有用过,之后学习一下。 参考

参考

参考1
参考2

你可能感兴趣的:(python,pytorch)