pytorch实现 tensorflow.keras.DO函数

pytorch实现 tensorflow DOT

import tensorflow as tf
import numpy as np

x1 = np.arange(4 * 4 * 4).reshape(4, 4, 4)
x2 = np.flip(np.arange(4 * 4).reshape(4, 4), 1).copy()
print(x1.shape, x2.shape)

(4, 4, 4) (4, 4)

#可以发现数据维度对比下面的torch发生了变化,少了维度
dotted = tf.keras.layers.Dot(axes=(1, 1))([x2, x1])
print(dotted)

pytorch实现 tensorflow.keras.DO函数_第1张图片

torch

import torch
x1 = torch.from_numpy(np.arange(4 * 4 * 4).reshape(4, 4, 4))
x2 = torch.from_numpy(np.flip(np.arange(4 * 4).reshape(4, 4), 1).copy())
dotted = torch.tensordot(x2, x1, dims=([1], [1]))
dotted

pytorch实现 tensorflow.keras.DO函数_第2张图片

解决方法

coo =[]
for i in range(len(dotted)):
    coo.append(dotted[i][i])
c00 = torch.stack(coo)
c00

pytorch实现 tensorflow.keras.DO函数_第3张图片

你可能感兴趣的:(pytorch)