Pytorch训练问题:AssertionError: Invalid device id

在Linux中使用显卡训练网络时,一般会通过device id来确定使用的显卡。我们从GitHub上获取的源码中的device id和我们本地的device id肯定不一致,所以训练时一定要注意device id修改。

以下示例:

源码:

model = nn.DataParallel(
        model.cuda(), device_ids=[0,1]
    

  源码中使用了id为0和1 的显卡进行训练。

 本地训练报错:

       AssertionError: Invalid device id

 本地显卡指示:

      CUDA Device count:  1

       本地只有一个显卡,代码中带入了2个id,这时候肯定会报错。修改代码如下:

model = nn.DataParallel(
        model.cuda(), device_ids=[0]
    

注意:

      在使用多显卡进行训练时,一定要注意显卡id设置。如遇问题可以参考:

https://blog.csdn.net/qq_41563394/article/details/106555626

https://www.codeleading.com/article/23452065003/

你可能感兴趣的:(python)