[EDSR code issue] [PyTorch 1.1] ImportError: cannot import name '_update_worker_pids' from 'torch._C'

Starting May 3, when I resumed EDSR model training on Colab, dataloader.py began throwing:

ImportError: cannot import name '_update_worker_pids'

A closer look showed that Colab's default PyTorch version had changed to 1.1. I had never tried rolling back the PyTorch version on Colab, and switching versions was foreseeably a hassle (it would have to be redone every session), so I tried to fix the EDSR code instead.
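Since the breakage depends on which PyTorch version is installed, one option is to gate the imports on the version at runtime. The helper below is a hypothetical sketch: it parses a version string of the form torch.__version__ returns (e.g. "1.1.0"), and deliberately avoids importing torch so it can be shown standalone.

```python
# Sketch: decide whether the new torch.utils.data._utils layout applies.
# The "major.minor.patch" string format matches what torch.__version__
# returns; parsing a plain string keeps the example runnable without
# PyTorch installed.
def uses_utils_layout(version_string):
    """True when the version is 1.1 or newer (new _utils module layout)."""
    major, minor = (int(part) for part in version_string.split(".")[:2])
    return (major, minor) >= (1, 1)

print(uses_utils_layout("1.0.1"))  # False: old dataloader-level imports
print(uses_utils_layout("1.1.0"))  # True: import from torch.utils.data._utils
```

A real guard would call this with torch.__version__ and pick one of the two import blocks below accordingly.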

After reading the PyTorch 1.1 source code on GitHub, I summarized the required changes to EDSR's dataloader.py as follows:

1.
        from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
            _remove_worker_pids, _error_if_any_worker_fails
        from torch.utils.data.dataloader import DataLoader
        from torch.utils.data.dataloader import _DataLoaderIter
        from torch.utils.data.dataloader import ManagerWatchdog
        from torch.utils.data.dataloader import _pin_memory_loop
        from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL
        from torch.utils.data.dataloader import ExceptionWrapper
        from torch.utils.data.dataloader import _use_shared_memory
        from torch.utils.data.dataloader import numpy_type_map
        from torch.utils.data.dataloader import default_collate
        from torch.utils.data.dataloader import pin_memory_batch
        from torch.utils.data.dataloader import _SIGCHLD_handler_set
        from torch.utils.data.dataloader import _set_SIGCHLD_handler

change to
        from torch._C import _set_worker_signal_handlers
        from torch.utils.data import _utils
        from torch.utils.data.dataloader import DataLoader
        from torch.utils.data.dataloader import _DataLoaderIter
        _use_shared_memory = False

Be sure to add the last line, _use_shared_memory = False.


2.

        watchdog = ManagerWatchdog()

change to

        watchdog = _utils.worker.ManagerWatchdog()

3.

        try:
            r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)

change to

        try:
            r = index_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)

4.

        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))

change to

        except Exception:
            data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info())))

5.

                    target=_pin_memory_loop,

change to

                    target=_utils.pin_memory._pin_memory_loop,

6.

        _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
        _set_SIGCHLD_handler()

change to

        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
        _utils.signal_handling._set_SIGCHLD_handler()

7.

        collate_fn=default_collate

change to

        collate_fn=_utils.collate.default_collate

To summarize: the PyTorch 1.1 maintainers consolidated many functions that previously had to be imported from torch._C, torch.utils.data.dataloader, and so on into the _utils package under torch.utils.data, so from PyTorch 1.1 onward a single "from torch.utils.data import _utils" is enough. RCAN's dataloader.py differs slightly from EDSR's, but it can be fixed with the same approach.
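An alternative to renaming every call site is a small fallback helper that tries the new module path first and falls back to the old one. This is a generic sketch, not EDSR's actual code; it is demonstrated with a stdlib name (Mapping, which moved from collections to collections.abc in Python 3.3) so it runs without PyTorch, but the same pattern covers the torch.utils.data moves.

```python
import importlib

# Generic fallback import: try the new module path, then the old one.
# (Sketch only; the fix above simply rewrites the imports in place.)
def import_from_either(new_path, old_path, name):
    for path in (new_path, old_path):
        try:
            return getattr(importlib.import_module(path), name)
        except (ImportError, AttributeError):
            continue
    raise ImportError("cannot import name %r" % name)

# Demonstrated with a stdlib move so the sketch runs without PyTorch:
# Mapping lived in collections before Python 3.3, collections.abc after.
Mapping = import_from_either("collections.abc", "collections", "Mapping")
```

For the case in this post, the same call shape would be import_from_either("torch.utils.data._utils.signal_handling", "torch.utils.data.dataloader", ...) for each moved name.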


Addendum, 2019-05-06:

After the bug above was fixed, I ran into another problem:

ValueError: x and y must have same first dimension, but have shapes (59,) and (58,)

Solution:

Replace every occurrence in trainer.py of

self.optimizer.get_last_epoch() + 1

with

self.optimizer.get_last_epoch()
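The error is matplotlib complaining that the plotted x and y arrays have different lengths. A minimal sketch of the assumed cause, with illustrative names and values (the real ones live in EDSR's trainer and its loss log): the x-axis is built from the epoch counter while y is the list of logged losses, so the extra "+ 1" yields one x value too many.

```python
# Minimal sketch (assumed cause) of the (59,) vs (58,) shape mismatch.
logged_losses = [0.1] * 58   # 58 loss values recorded so far (illustrative)
last_epoch = 58              # hypothetical return of get_last_epoch()

x_wrong = list(range(1, (last_epoch + 1) + 1))  # 59 points -> ValueError in plot
x_fixed = list(range(1, last_epoch + 1))        # 58 points -> matches y

assert len(x_wrong) != len(logged_losses)
assert len(x_fixed) == len(logged_losses)
```

Dropping the "+ 1" makes the x-axis length agree with the number of logged values, which is exactly what the one-line fix above does.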
   
   
   
   

 
