Pytorch 使用经验分享&知识点总结

在科研过程中总结的一些琐碎的pytorch相关知识点。

目录

  • 1. 数据加载
  • 2. 数据操作
  • 3. 模型操作
    • 3.1 模式切换
    • 3.2 梯度更新
    • 3.3 模型保存与加载


1. 数据加载

  • 锁页内存(pin_memory)是决定数据放在锁业内存还是硬盘的虚拟内存中,默认值为 False。如果设置为True,则表示数据放在锁业内存中。注意:显卡中的内存全部是锁页内存,所以放在锁页内存中可以加快读取速度。当计算机内存充足时,可将该值设置为 True。这一参数一般在 data_loader() 函数中设置。

  • num_worker 取值最好是 2的幂次方-1, 如 0,1,3,7 等,因为会自动加 1. 默认值为 1. 这一参数一般在 data_loader() 函数中设置。

  • GPU 利用率为低是因为显卡在等数据,解决办法(1)优化 data_loader() 函数;(2)增大batch size 等


2. 数据操作

  • pytorch两个基本对象:Tensor(张量)和 Variable(变量)。
  • torch.Tensortorch.tensor 的区别:
    torch.Tensor(data):将数据转化torch.FloatTensor类型。
    torch.tensor(data):根据数据类型或者dtype参数值将数据转化为torch.FloatTensor、torch.LongTensor、torch.DoubleTensor等类型。
  • torch.contiguous() :类似于 C++ 中的深拷贝。详解见此篇博客。
  • torch.stack() 作用:用于连接大小相同的张量,并扩展维度,类比 torch.cat(). 注意:在哪个维度上操作,就将 dim 设置为哪个维度。 详解见此篇博客。
  • 使用 torch.zeros() 创建的张量默认在 CPU 上,如要在 GPU 上使用记得进行数据转移。
  • 解决 torch 对象打印时有省略号的问题:torch.set_printoptions(threshold=np.inf),该命令多用于打印完整日志。
  • numpy 类型数据只能在 CPU 上运行。注意数据在torch类型与numpy类型间相互转换时数据的存放位置(如:不能将GPU上的张量数据直接转化为numpy类型数据)。

3. 模型操作

3.1 模式切换

  • model.eval()model.train() 区别在于是否启用 归一化层 + dropout,前者不启用,后者启用。

3.2 梯度更新

  • Module中的层在定义时,相关Variablerequires_grad参数默认是True。而用户手动定义Variable时,参数requires_grad默认值是False,volatile值也默认为False。volatile的优先级比requires_grad高,volatile属性为True的节点不会求导(所以可以在测试阶段设置为 True)。 如果要修改可使用 variable_name.require_grad_(True) 实现。

  • 反向传播中梯度回传与更新的实现三步走: (1)optimizer.zero_grad()(梯度清零)(2)loss.backward()(梯度回传)(3)optimizer.step()(梯度更新)

  • model.zero_grad ()optimizer.zero_grad () 使用区别:当 optimizer = optim.Optimizer (net.parameters ()),即网络中参数均未冻结,全部需要更新时,二者等效,其中Optimizer可以是Adam、SGD等优化器;若网络中部分参数被冻结或多个网络共用同一个优化器,则二者不等价。详解见此篇博客。

  • with torch.no_grad() 作用:停止autograd模块的工作,以起到加速和节省显存的作用。一般用在验证和测试阶段。注意:新版本Pytorch中,volatile 已被弃用,需替换为:with torch.no_grad().

3.3 模型保存与加载

  • torch.save(model, path) :将训练好的模型 model 保存至 path 路径下。
  • torch.load(model_path, map_location):将给定路径的预训练模型加载至指定设备上,详解见此篇博客。

参考资料

  • Pytorch中contiguous()函数理解_.contiguous()_清晨的光明的博客-CSDN博客
  • pytorch拼接函数:torch.stack()和torch.cat()–详解及例子_python torch拼接_紫芝的博客-CSDN博客
  • pytorch之model.zero_grad() 与 optimizer.zero_grad()_models.zero_grad()_旺旺棒棒冰的博客-CSDN博客
  • Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()_宁静致远*的博客-CSDN博客

你可能感兴趣的:(碎片笔记,问题清除指南,深度学习,python,pytorch)