How to run mmrotate on CPU

Modify the following code:

mmrotate/apis/train.py Line 60

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        # model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
        model = MMDataParallel(model.cpu(), device_ids=cfg.gpu_ids)

In the last two lines (the else branch), replace model.cuda(cfg.gpu_ids[0]) with model.cpu().
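
If you want the file to keep working on a machine that does have a GPU, a minimal sketch of a device-aware else branch is shown below. It is an assumption on top of the change above, not the upstream mmrotate code: it reuses the names already in scope in train.py (model, cfg, MMDataParallel) and only adds a torch.cuda.is_available() check.

    else:
        if torch.cuda.is_available():
            # GPU present: keep the original single-GPU path
            model = MMDataParallel(
                model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
        else:
            # CPU-only machine: DataParallel sets device_ids to an empty
            # list when CUDA is unavailable, so cfg.gpu_ids is effectively
            # ignored and the model stays on CPU
            model = MMDataParallel(model.cpu(), device_ids=cfg.gpu_ids)

After the change, start training as usual with tools/train.py and your config; on a GPU machine you can also hide the GPUs (for example with CUDA_VISIBLE_DEVICES set to empty) to force the CPU branch, though other parts of the config may still assume CUDA.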
