Pytorch学习笔记(四)各组件的Debug记录

torch.nn

DataParallel详解

定义:model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

代码流程

下面主要是详细介绍一下数据并行的流程,自己动手Debug一下更清晰!

  1. 在调用网络模型model(DataParallel类实例)时,实际是进入模型基类Module的特殊方法__call__(module.py)。
    注: 特殊方法__call__用于实现对类实例的调用。
  2. 再通过模型基类Module的self.forward进入DataParallel类的forward方法(data_parallel.py):
    1. 先调用self.scatter在第一个维度分配输入
    2. 调用self.replicate产生模型副本放置在多个GPU上,形成modules列表
    3. 调用parallel_apply执行并行操作(parallel_apply.py)
  3. parallel_apply中通过多线程模块threading,将不同的module(自定义的网络模型类实例),input,GPU_ID以及kwargs分配给不同的线程。通过for循环控制启动线程活动和等待至线程中止。
  4. 通过多线程的run()方法进入线程目标函数_worker中,调用module(自定义的网络模型类实例),返回此次线程的结果并存储在指定字典results中。
    注:调用module时,同样也是先进入模型基类Module的特殊方法__call__,再通过模型基类Module的self.forward进入到你自定义网络模型类的forward方法。
  5. 若子线程均执行完毕又回到parallel_apply,对结果字典results的值进行异常判断并以列表形式返回上一层,DataParallel类的forward方法。
  6. 在DataParallel类的forward方法中调用self.gather从指定设备上的不同GPU收集变量并返回上一层,模型基类Module的特殊方法__call__,最后返回最初网络模型调用之处。

遗留的问题

  1. 假设调用网络模型时输入变量的batch_size为4,在模型基类Module的特殊方法__call__的输入就变成了一个元组含两个变量,也就是把一个batch一分为二,输入数据分块操作是怎么完成的?具体在哪里完成?

你可能感兴趣的:(PyTorch)