b = torch.tensor([1,2,5])
b += torch.tensor(5.0) / torch.add(b,torch.tensor(5.0))
>>> tensor([ 6., 7., 10.])
# 矩阵相乘(最后两位)
torch.mm(a,b) / matmul(a,b) / a@b
# 转置
a.t()
# 幂
a = torch.full([2,2],3)
a.pow(2) / a**2
# 开方
a.sqrt() / a**0.5
# 倒数
a.rsqrt()
# exp + log
a = torch.exp(torch.ones(2,2))
torch.log(a)
# 向下、向上、裁剪、小数
a = torch.tensor(3.14)
a.floor(),a.ceil(),a.trunc(),a.frac()
>>> tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
# 四舍五入
a.round()
# 针对梯度
a.max() / a.median()
# 限制在>=10
a.clamp(10)
# 限制在(0,10)
a.clamp(0,10)
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
# dim叠加维度
# 除了dim,其他通道必须保持一致
torch.cat([a,b],dim=0).shape
>>> torch.Size([9,32,8])
# stack创建新维度
a = torch.rand(4,32,8)
b = torch.rand(4,32,8)
c = torch.cat([a,b],dim=0).shape
>>> torch.Size([2,4,32,8])
# split长度拆分维度
aa,bb = c.split([1,1],dim=0) / c.split(1,dim=0) # 固定长度
aa.shape/bb.shape
>>> torch.Size([1,4,32,8])
# chunk数量拆分维度
aa,bb = c.chunk(2,dim=0)
aa.shape/bb.shape
>>> torch.Size([1,4,32,8])
# 计算范数
a = torch.arange(8).view(2,4).float()
a.norm(2)
# prod乘积
a.min() / max() / mean() / prod()
a.sum() / a.argmax() / a.argmin()
# 不删减维度
a.max(dim=1, keepdim=True)
# 前几个大的,★values&indices分别对应index[0]&index[1]
pred = torch.randn((4, 5))
index = pred.topk(2, dim=1, largest=True, sorted=True)
>>> tensor([[-2.8971, 1.0144, 1.6376],
[ 0.3696, -0.2189, -0.2812],
[-1.5483, 1.1111, -0.3015],
[-1.3930, -0.8354, 0.6924]])
>>> torch.return_types.topk(
values=tensor([[ 1.6376, 1.0144],
[ 0.3696, -0.2189],
[ 1.1111, -0.3015],
[ 0.6924, -0.8354]]),
indices=tensor([[2, 1],
[0, 1],
[1, 2],
[2, 1]]))
# 第一小的
pred = torch.randn((4, 3))
print(pred)
print(pred.kthvalue(1))
>>> tensor([[-0.6578, -1.3338, -2.4590],
[ 0.5629, -2.0589, -1.7233],
[-0.8854, -0.2503, -1.4455],
[ 0.1539, 0.6368, 2.4213]])
>>> torch.return_types.kthvalue(
values=tensor([-2.4590, -2.0589, -1.4455, 0.1539]),
indices=tensor([2, 1, 2, 0]))
torch.gt(0)
# 返回1、0的列表
torch.eq(a,a)
# 返回一个True
torch.equal(a,a)
# 条件取值
a = torch.zeros([2,2])
b = torch.ones([2,2])
torch.where(data,a,b)
# 按照维度和索引取值
torch.gather()
图解PyTorch中的torch.gather函数 - 知乎