PyTorch 19. PyTorch中相似操作的区别与联系

PyTorch 19. PyTorch中相似操作的区别与联系

  • view() 和 reshape()
    • 总结
  • expand()和repeat()
    • expand()
    • repeat()
  • 乘法操作
    • 二维矩阵乘法 torch.mm()
    • 三维带batch的矩阵乘法torch.bmm()
    • 多维矩阵乘法 torch.matmul()
    • 矩阵逐元素(Element-wise)乘法torch.mul()
    • 两个运算符@和*

view() 和 reshape()

写在开头:
有一篇大佬的总结非常到位:博客

总结

  1. view() 在操作tensor时,需要tensor是内存连续的,而且在进行尺寸变换时,view()操作不会新开辟内存空间。但是要保证tensor连续,对tensor进行tensor.contiguous()时,会开辟新的内存空间,存放内存连续的数据。
  2. reshape()操作,与view()的作用一模一样,但是它比view()更高级,被操作的tensor是内存连续时,直接采用reshape不会开辟新的内存;被操作的tensor不是内存连续时,reshape操作会开辟新的内存,再对tensor进行reshape。
  3. 最后,用reshape操作就完事了

expand()和repeat()

expand()

返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小等于1的维度扩展到更大的尺寸
例子:

import torch
x = torch.tensor([1, 2, 3])
x.expand(2,3)
tensor([[1, 2, 3],
			[1, 2, 3]])

注意 expand()只能扩展维度为1的维数,维数不为1的部分要保持一致

repeat()

沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据
例子

import torch

x = torch.tensor([1, 2, 3])
x.repeat(3,2)
tensor([[1, 2, 3, 1, 2, 3],
		[1, 2, 3, 1, 2, 3],
		[1, 2, 3, 1, 2, 3]])
x2 = torch.randn(2, 3, 4)
x2.repeat(2, 1, 3).shape

torch.Tensor([4, 3, 12])

乘法操作

pytorch中的乘法操作有:torch.mm(), torch.bmm(), torch.matmul(), torch.mul(), 运算符,以及torch.einsum()

二维矩阵乘法 torch.mm()

该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
torch.mm(mat1, mat2, out=None), 其中mat1为(nxm),mat2为(mxd),输出维度是(nxd)

三维带batch的矩阵乘法torch.bmm()

该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。

由于神经网络训练一般采用mini-batch,经常输入的三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None),其中bmat1为(bxnxm),bmat2为(bxmxd),输出out的维度是(bxnxd)

多维矩阵乘法 torch.matmul()

torch,matmul(input, other, out=None)支持broadcast操作
针对多维数据matmul()乘法,可以认为该matmul()乘法使用两个参数的后两个维度计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别为input->(100x500x99x11)other->(500x11x99)那么我们可以认为该乘法首先进行后两位矩阵乘法得到(99x11)x(11x99)->(99,99),然后分析两个参数的batch size分别为(1000x500)500,可以广播为(1000x500),因此最终输出的维度是(1000x500x99x99)

矩阵逐元素(Element-wise)乘法torch.mul()

函数torch.mul(mat1, other, out=None),其中other乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast即可。

两个运算符@和*

  1. @:矩阵乘法,自动执行合适的矩阵乘法函数
  2. *:elemnet-wise乘法

你可能感兴趣的:(Pytorch复习,pytorch,深度学习,python)