PyTorch & Paddle tensordot 函数应用

tensordot 参数

torch.tensordot (a, b, dims = 2)  

paddle.tensordot (x, y, axes = 2, name = None) 

  • a / x : 类型为float32或者float64的tensor
  • b / y: 和a有相同的type,即张量同类型,但不要求同维度
  • dims / axes: 可以为int32,也可以是list。为int32时,表示取a的最后几个维度,与b的前面几个维度相乘,再累加求和,消去(收缩)相乘的维度;为list时,则是指定a的哪几个维度与b的哪几个维度相乘,消去(收缩)这些相乘的维度
  • name : 操作名

环境

  • python 3.6.8
  • pytorch 1.7
  • paddle 2.3.0

dims / axes = 0

torch

import torch

# dims=0 contracts no axes at all, so the result is the outer (tensor)
# product of the two operands.
# Output rank  = m1.dim() + m2.dim() = 3 + 3 = 6
# Output shape = [3, 2, 4] + [2, 4, 3] = [3, 2, 4, 2, 4, 3]
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(2, 4, 3)
m3 = torch.tensordot(m1, m2, dims=0)
print(m3.shape)

paddle

import paddle

# axes=0 contracts no axes at all, so the result is the outer (tensor)
# product of the two operands.
# Output rank  = rank(m1) + rank(m2) = 3 + 3 = 6
# Output shape = [3, 2, 4] + [2, 4, 3] = [3, 2, 4, 2, 4, 3]
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([2, 4, 3])
m3 = paddle.tensordot(m1, m2, axes=0)
print(m3.shape)

dims / axes = 1

torch

import torch

# An integer dims=n pairs the last n axes of the first operand with the
# first n axes of the second one:
#   axes of m1: range(-1, 0) -> -1
#   axes of m2: range(1)     -> 0
# m1 axis -1 (size 4) is multiplied with m2 axis 0 (size 4) and summed out.
# Output shape = [3, 2] + [5, 3] = [3, 2, 5, 3]
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(4, 5, 3)
m3 = torch.tensordot(m1, m2, dims=1)
print(m3.shape)

# The operands do not need the same shape or rank; only the contracted
# axes must agree in size. Here [3, 2] is contracted with [2].
# Output shape = [3]
m1 = torch.ones(3, 2)
m2 = torch.ones(2)
m3 = torch.tensordot(m1, m2, dims=1)
print(m3.shape)

paddle

import paddle

# An integer axes=n pairs the last n axes of the first operand with the
# first n axes of the second one:
#   axes of m1: range(-1, 0) -> -1
#   axes of m2: range(1)     -> 0
# m1 axis -1 (size 4) is multiplied with m2 axis 0 (size 4) and summed out.
# Output shape = [3, 2] + [5, 3] = [3, 2, 5, 3]
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([4, 5, 3])
m3 = paddle.tensordot(m1, m2, axes=1)
print(m3.shape)

# The operands do not need the same shape or rank; only the contracted
# axes must agree in size. Here [3, 2] is contracted with [2].
# Output shape = [3]
m1 = paddle.ones([3, 2])
m2 = paddle.ones([2])
m3 = paddle.tensordot(m1, m2, axes=1)
print(m3.shape)

dims / axes = 2

torch

import torch

# dims=2 pairs the last two axes of m1 with the first two axes of m2:
#   axes of m1: range(-2, 0) -> [-2, -1]
#   axes of m2: range(2)     -> [0, 1]
# m1 axis -2 (size 2) <-> m2 axis 0 (size 2)
# m1 axis -1 (size 4) <-> m2 axis 1 (size 4)
# Output shape = [3] + [3] = [3, 3]
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(2, 4, 3)
m3 = torch.tensordot(m1, m2, dims=2)
print(m3.shape)

paddle

import paddle

# axes=2 pairs the last two axes of m1 with the first two axes of m2:
#   axes of m1: range(-2, 0) -> [-2, -1]
#   axes of m2: range(2)     -> [0, 1]
# m1 axis -2 (size 2) <-> m2 axis 0 (size 2)
# m1 axis -1 (size 4) <-> m2 axis 1 (size 4)
# Output shape = [3] + [3] = [3, 3]
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([2, 4, 3])
m3 = paddle.tensordot(m1, m2, axes=2)
print(m3.shape)

dims / axes = tuple()

torch(不支持直接传入单个 tuple,取巧实现:把同一组维度写两遍,即 dims=[(0,2),(0,2)])

import torch

# The same axes (0, 2) are contracted on both operands. The axis group is
# written once per operand in this call: dims=[axes of m1, axes of m2].
#   m1 axis 0 (size 3) <-> m2 axis 0 (size 3)
#   m1 axis 2 (size 4) <-> m2 axis 2 (size 4)
# The remaining axes give the output shape [2, 5].
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(3, 5, 4)
shared_axes = (0, 2)
m3 = torch.tensordot(m1, m2, dims=[shared_axes, shared_axes])
print(m3.shape)

paddle

import paddle

# A flat tuple of ints is passed as axes; per the shapes below it contracts
# the same axes (0, 2) on both operands:
#   m1 axis 0 (size 3) <-> m2 axis 0 (size 3)
#   m1 axis 2 (size 4) <-> m2 axis 2 (size 4)
# The remaining axes give the output shape [2, 5].
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([3, 5, 4])
m3 = paddle.tensordot(m1, m2, axes=(0, 2))
print(m3.shape)

dims / axes = List [ tuple(),tuple() ]

torch

import torch

# dims=[axes of m1, axes of m2] pairs the axes position by position:
#   m1 axis 1 (size 2) <-> m2 axis 1 (size 2)
#   m1 axis 2 (size 4) <-> m2 axis 0 (size 4)
# The remaining axes give the output shape [3, 5].
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(4, 2, 5)
m3 = torch.tensordot(m1, m2, dims=[(1, 2), (1, 0)])
print(m3.shape)

paddle

import paddle

# axes=[axes of m1, axes of m2] pairs the axes position by position:
#   m1 axis 1 (size 2) <-> m2 axis 1 (size 2)
#   m1 axis 2 (size 4) <-> m2 axis 0 (size 4)
# The remaining axes give the output shape [3, 5].
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([4, 2, 5])
m3 = paddle.tensordot(m1, m2, axes=[(1, 2), (1, 0)])
print(m3.shape)

你可能感兴趣的:(pytorch,paddle,pytorch,paddle)