# API signatures compared in this script (reference only -- kept as comments
# because, executed as statements, they would reference names that are not
# defined at this point):
#   torch.tensordot(a, b, dims=2)
#   paddle.tensordot(x, y, axes=2, name=None)
# Parameter mapping: torch `dims` <-> paddle `axes` (both default to 2).
# Paddle additionally takes an optional `name`, which torch does not have.
import torch

# PyTorch, dims=0: no axes are contracted, so the output keeps every axis of
# m1 followed by every axis of m2 (rank = rank(m1) + rank(m2)).
# [3, 2, 4] x [2, 4, 3] -> expected shape [3, 2, 4, 2, 4, 3].
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(2, 4, 3)
m3 = torch.tensordot(m1, m2, dims=0)
print(m3.size())
import paddle

# Paddle, axes=0: no axes are contracted, so the output keeps every axis of
# m1 followed by every axis of m2 (rank = rank(m1) + rank(m2)).
# [3, 2, 4] x [2, 4, 3] -> expected shape [3, 2, 4, 2, 4, 3].
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([2, 4, 3])
m3 = paddle.tensordot(m1, m2, axes=0)
print(m3.shape)
import torch

# PyTorch, dims=1: contract the last axis of m1 (index -1, i.e. range(-1, 0))
# with the first axis of m2 (index 0, i.e. range(1)).
# [3, 2, 4] x [4, 5, 3] -> expected shape [3, 2, 5, 3].
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(4, 5, 3)
m3 = torch.tensordot(m1, m2, dims=1)
print(m3.size())

# The operands do not need the same shape or rank -- only the contracted
# axes must match in size: [3, 2] x [2] -> expected shape [3].
m1 = torch.ones(3, 2)
m2 = torch.ones(2)
m3 = torch.tensordot(m1, m2, dims=1)
print(m3.size())
import paddle

# Paddle, axes=1: contract the last axis of m1 (index -1, i.e. range(-1, 0))
# with the first axis of m2 (index 0, i.e. range(1)).
# [3, 2, 4] x [4, 5, 3] -> expected shape [3, 2, 5, 3].
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([4, 5, 3])
m3 = paddle.tensordot(m1, m2, axes=1)
print(m3.shape)

# The operands do not need the same shape or rank -- only the contracted
# axes must match in size: [3, 2] x [2] -> expected shape [3].
m1 = paddle.ones([3, 2])
m2 = paddle.ones([2])
m3 = paddle.tensordot(m1, m2, axes=1)
print(m3.shape)
import torch

# PyTorch, dims=2: contract the last two axes of m1 (indices -2, -1, i.e.
# range(-2, 0)) with the first two axes of m2 (indices 0, 1, i.e. range(2)),
# pairing m1 axis -2 with m2 axis 0 and m1 axis -1 with m2 axis 1.
# [3, 2, 4] x [2, 4, 3] -> expected shape [3, 3].
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(2, 4, 3)
m3 = torch.tensordot(m1, m2, dims=2)
print(m3.size())
import paddle

# Paddle, axes=2: contract the last two axes of m1 (indices -2, -1, i.e.
# range(-2, 0)) with the first two axes of m2 (indices 0, 1, i.e. range(2)),
# pairing m1 axis -2 with m2 axis 0 and m1 axis -1 with m2 axis 1.
# [3, 2, 4] x [2, 4, 3] -> expected shape [3, 3].
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([2, 4, 3])
m3 = paddle.tensordot(m1, m2, axes=2)
print(m3.shape)
import torch

# PyTorch, explicit axis lists: dims=[(0, 2), (0, 2)] gives one axis list per
# operand, pairing m1 axis 0 with m2 axis 0 and m1 axis 2 with m2 axis 2.
# Note: torch documents `dims` as an int or as two axis lists (one for each
# operand). Paddle's flat shorthand axes=(0, 2) ("same axes on both
# operands") is not a documented torch form, so it must be spelled out here.
# [3, 2, 4] x [3, 5, 4] -> expected shape [2, 5].
m1 = torch.ones(size=[3, 2, 4])
m2 = torch.ones(size=[3, 5, 4])
m3 = torch.tensordot(m1, m2, dims=[(0, 2), (0, 2)])
print(m3.shape)
import paddle

# Paddle, axes=(0, 2): a flat sequence of ints names the same axes on both
# operands, pairing m1 axis 0 with m2 axis 0 and m1 axis 2 with m2 axis 2.
# [3, 2, 4] x [3, 5, 4] -> expected shape [2, 5].
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([3, 5, 4])
m3 = paddle.tensordot(m1, m2, axes=(0, 2))
print(m3.shape)
import torch

# PyTorch, dims=[(1, 2), (1, 0)]: the first list holds m1's axes and the
# second holds m2's, matched position by position -- m1 axis 1 with m2 axis 1
# and m1 axis 2 with m2 axis 0.
# [3, 2, 4] x [4, 2, 5] -> expected shape [3, 5].
m1 = torch.ones(3, 2, 4)
m2 = torch.ones(4, 2, 5)
m3 = torch.tensordot(m1, m2, dims=[(1, 2), (1, 0)])
print(m3.size())
import paddle

# Paddle, axes=[(1, 2), (1, 0)]: the first list holds m1's axes and the
# second holds m2's, matched position by position -- m1 axis 1 with m2 axis 1
# and m1 axis 2 with m2 axis 0.
# [3, 2, 4] x [4, 2, 5] -> expected shape [3, 5].
m1 = paddle.ones([3, 2, 4])
m2 = paddle.ones([4, 2, 5])
m3 = paddle.tensordot(m1, m2, axes=[(1, 2), (1, 0)])
print(m3.shape)