PyTorch实战(1):矩阵运算

常见的矩阵运算:基于pytorch和numpy实现

import torch
import numpy as np
import torch.nn as nn

X = torch.rand(4, 2, 3)
linear = nn.Linear(3, 2)

"""矩阵加法"""
# Z=W+X
W = torch.randn(4, 2, 3)

Z11 = W + X
print(Z11)
Z12 = W.add(X)
print(Z12)

"""矩阵乘积"""
# Z = WX -线性变换
Z2 = linear(X)
print(Z2)

# Z = W⊗X -矩阵外积 或 矩阵乘法[m,n]⊗[n,p]==[m,p]
W2 = torch.randn(3, 3, 4)
Z31 = torch.mm(X, W2)
print(Z31)
Z32 = np.outer(X, W2)
print(Z32)
Z33 = np.dot(X, W2)
print(Z33)
Z34 = torch.dot()  # 仅能用于向量运算[1,n]⊗[n,1]
print(Z34)

# Z = W⊙X -矩阵元素积:即对应位置元素相乘[m,n]⊙[m,n]==[m,n]
Z41 = torch.mul(X, W)
print(Z41)
Z43 = W * X  # 若W和X是torch.tensor,同样是元素积
print(Z43)
Z42 = np.multiply(X, W)
print(Z42)

补充:

numpy:数组和矩阵的乘法

(1)当为Array的时候:

默认X * W是元素积,multiply是元素积,dot(X,W)是矩阵乘法, dot点乘意味着相加,而multiply只是对应元素相乘,不相加

(2)当为Matrix的时候:

默认X * W是矩阵乘法(区别数组),multiply为元素积,dot(X,W)是矩阵乘法

(3)混合时候的情况:

【一般不要混合】若出现该情况,则默认X * W是矩阵乘法, multiply为元素积,dot(X,W)是矩阵乘法

喜欢的friend,可以关注我的【知乎专栏:Deep Learning for NLP

你可能感兴趣的:(PyTorch,神经网络,科学计算工具)