TypeError: forward() missing 1 required positional argument: ‘input_ids‘

用的 pytorch 多GPU的数据并行方法 DataParallel ,这老出错

原 batch_size 我设的 8,用的3块GPU,谷歌到该 github issue

https://github.com/Eromera/erfnet_pytorch/issues/2

然后 batch_size 设为 9 目前能跑通

你可能感兴趣的:(TypeError: forward() missing 1 required positional argument: ‘input_ids‘)