pytorch指定GPU

    
    os.environ['CUDA_VISIBLE_DEVICES'] =34//设置用哪个GPU
    # Create model
    model = MODEL(opt) //加载模型
    model = model.cuda()
    model = nn.DataParallel(model, device_ids=None)

nn.DataParallel 函数

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

[1] - module :待进行并行的模块.
[2] - device_ids : GPU 列表,其值可以是 torch.device 类型,也可以是 int list. 默认使用全部 GPUs.
[3] - output_device : GPUID 或 torch.device. 指定输出的 GPU,默认为第一个,即 device_ids[0].

DataParallel 原理

假设读入一个 batch 的数据,其大小为 [30, 5, 2],假设采用三张 GPUs,其运行过程大致为:
[1] - 将模型放到主 GPU 上,一般为 cuda:0;
[2] - 把模型同步到 3 张 GPUs 上;
[3] - 将总输入 batch 的数据平分为 3 份,这里每一份大小为 [10, 5, 2];
[4] - 依次分别作为每个副本模型的输入;
[5] - 每个副本模型分别独立进行前向计算,假设为 [4, 5, 2];
[6] - 从 3 个 GPUs 中收集分别计算后的结果,并按照次序拼接,即 [12, 5, 2],计算 loss;
[7] - 更新梯度,后向计算.

你可能感兴趣的:(pytorch,深度学习,机器学习)