[EDSR code issue][PyTorch 1.1] ImportError: cannot import name '_update_worker_pids' from 'torch._C'

Starting on May 3rd, when I resumed training the EDSR model on Colab, dataloader.py began throwing:

ImportError: cannot import name '_update_worker_pids'

On closer inspection, it turned out that Colab's default PyTorch version had been bumped to 1.1. Since I had never tried rolling PyTorch back on Colab, and doing so would foreseeably be a hassle (the version would have to be changed in every session), I decided to fix the problem by modifying the EDSR code instead.

After reading the PyTorch 1.1 source on GitHub, here are the changes needed in EDSR's dataloader.py:

1

from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
    _remove_worker_pids, _error_if_any_worker_fails
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data.dataloader import ManagerWatchdog
from torch.utils.data.dataloader import _pin_memory_loop
from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL

from torch.utils.data.dataloader import ExceptionWrapper
from torch.utils.data.dataloader import _use_shared_memory
from torch.utils.data.dataloader import numpy_type_map
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataloader import pin_memory_batch
from torch.utils.data.dataloader import _SIGCHLD_handler_set
from torch.utils.data.dataloader import _set_SIGCHLD_handler

change to

from torch._C import _set_worker_signal_handlers
from torch.utils.data import _utils
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter

_use_shared_memory = False

Be sure to add the last line, _use_shared_memory = False.


2

        watchdog = ManagerWatchdog()

change to

        watchdog = _utils.worker.ManagerWatchdog()

3

            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)

change to

            try:
                r = index_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)

4

            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))

change to

            except Exception:
                data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info())))

5

                    target=_pin_memory_loop,

change to

                    target=_utils.pin_memory._pin_memory_loop,

6

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()

change to

            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()

7

        collate_fn=default_collate

change to

        collate_fn=_utils.collate.default_collate

To sum up: the PyTorch 1.1 maintainers consolidated many functions that previously had to be imported from torch._C, torch.utils.data.dataloader, and so on into _utils under torch.utils.data, so from PyTorch 1.1 onward a single from torch.utils.data import _utils suffices. RCAN's dataloader.py differs slightly from EDSR's, but the same approach applies.
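If you need the same dataloader.py to run on both sides of the 1.1 reorganization, one option is a try/except import fallback. Below is a hedged sketch (not from the EDSR repo) of that pattern, written as a generic helper so it runs even without torch installed; the commented torch paths follow the old and new locations listed above.

```python
# Generic "try the new location first, fall back to the old one" import helper.
import importlib


def import_from_either(new_path, old_path, name):
    """Resolve `name` from new_path first (PyTorch >= 1.1 layout),
    falling back to old_path (pre-1.1 layout)."""
    for path in (new_path, old_path):
        try:
            return getattr(importlib.import_module(path), name)
        except (ImportError, AttributeError):
            continue
    raise ImportError("cannot import %s from %s or %s" % (name, new_path, old_path))


# Against torch, usage would look like (paths per the changes above):
# default_collate = import_from_either(
#     "torch.utils.data._utils.collate",  # PyTorch >= 1.1
#     "torch.utils.data.dataloader",      # PyTorch < 1.1
#     "default_collate")
```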

Update 2019-05-06:

After fixing the bug above, another problem appeared:

ValueError: x and y must have same first dimension, but have shapes (59,) and (58,)

Solution:

Change every occurrence in trainer.py of

self.optimizer.get_last_epoch() + 1

to

self.optimizer.get_last_epoch()
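A minimal sketch of the off-by-one behind that ValueError, with illustrative numbers rather than EDSR's actual trainer code: the plot's x axis is built from the epoch counter, while y holds one logged loss per completed epoch, so the extra + 1 produces 59 x values against 58 y values.

```python
# 58 logged loss values, one per completed epoch (illustrative data)
losses = [1.0 / (i + 1) for i in range(58)]

last_epoch = 58
x_with_plus_one = list(range(1, last_epoch + 1 + 1))  # get_last_epoch() + 1 -> 59 points
x_fixed = list(range(1, last_epoch + 1))              # get_last_epoch()     -> 58 points

print(len(x_with_plus_one), len(losses))  # 59 vs 58: the shape mismatch
print(len(x_fixed), len(losses))          # 58 vs 58: matches after the fix
```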
