Dim1 的标量
Dim是size的长度,size\shape是tensor的形状,tensor指的是矩阵中具体的数值
a.numel()
.numel()返回tensor的内存大小
a.dim()返回长度
import torch
a=torch.randn(2,3)
a.type()
type(a)
##合法化检验
isinstance(a,torch.FloatTensor)
##在CPU上部署,不属于GPU(GPU需要加cuda)
isinstance(a,torch.cuda.DoubleTensor)
##在GPU上部署
data=a.cuda()
isinstance(a,torch.cuda.DoubleTensor)
##Dim1
#标量,常用来计算误差Loss
torch.tensor(1.)
torch.tensor(1.3)
#Dim0
a=torch.tensor(2.2)
a.shape
len(a.shape)
a.size()
Dim1的向量的size\shape
#Dim1的shape或size
e=torch.ones(2)
f=print(e.shape)
Dim2的tensor
##Dim2
g=torch.randn(2,3)
print(g)
print(g.shape)
##第一个元素
print(g.size(0))
##第二个元素
print(g.size(1))
Dim3
##Dim3 随机均匀分布
h=torch.rand(1,2,3)
print(h)
print(h[0])
print(h.shape)
list(h.shape)
使用场景:RNN Input Batch
Dim4 [b,c,h,w] 使用场景:图片 CNN
##Dim4
i=torch.rand(2,3,28,28)
##彩色图片1张照片,3个通道,size28*28
print(i)
print(i.shape)
方法一Import from numpy
从numpy导入的float其实是double类型
import numpy as np
import torch
a=np.array([2,3.3])
a=torch.from_numpy(a)
print(a)
##矩阵
b=np.ones([2,3])
b=torch.from_numpy(b)
print(b)
tensor([2.0000, 3.3000], dtype=torch.float64)
tensor([[1., 1., 1.],
[1., 1., 1.]], dtype=torch.float64)
方法二 Import from List
数据量不大
tensor接受现有的数据
Tensor\FloatTensor接受数据的维度
Tensor代表默认类型torch中的默认类型为FloatTensor
增强学习一般使用double,其他一般使用float
Tensor接受现有的数据,需要用([ ])少用
接受数据维度(d1,d2,d3)常用
#import from list
c=torch.tensor([2.,3.2])
print(c)
d=torch.FloatTensor([2.,3.2])
print(d)
e=torch.tensor([[2.,3.2],[1.,23.369]])
print(e)
tensor([2.0000, 3.2000])
tensor([2.0000, 3.2000])
tensor([[ 2.0000, 3.2000],
[ 1.0000, 23.3690]])
2.生成未初始化的数据 uninitialized
只作为容器,后续写数据,不然会出现奇奇怪怪的数据
torch.empty()
torch.FloatTensor(d1,d2,d3) 推荐使用,通常作为载体
PS:torch.FloatTensor([2,3])=torch.tensor([2,3])
torch.IntTensr(d1,d2,d3)
##uninitialized
f=torch.empty(1)
g=torch.Tensor(2,4)
h=torch.IntTensor(2,3)
i=torch.FloatTensor(2,3)
print(f)
print(g)
print(h)
print(i)
tensor([nan])
tensor([[0.0000e+00, 0.0000e+00, 2.8026e-45, 0.0000e+00],
[1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00]])
tensor([[0, 0, 2],
[0, 1, 0]], dtype=torch.int32)
tensor([[ 0.0000e+00, 2.0000e+00, -1.0842e-19],
[ 2.1500e+00, 0.0000e+00, 1.8750e+00]])
3.set default type
Tensor: FloatTensor IntTensor
Tensor代表默认类型torch中的默认类型为FloatTensor
增强学习一般使用double,其他一般使用float
##set default type
k=torch.tensor([1.2,3])
print(k.type())
torch.set_default_tensor_type(torch.DoubleTensor)
l=torch.tensor((1.2,3))
print(l.type())
torch.FloatTensor
torch.DoubleTensor
4.随机初始化 rand/rand_like, randint
rand [0,1] 不包括1
rand_like 接受一个tensor,再生成rand
randint 均匀采样0~10的tensor
要用x=10*torch.rand(d1,d2) randint只能采样整数额
torch.randint(min,max,[d1,d2])
##随机生成
m=torch.rand(3,3)
print(m)
n=torch.rand_like(m)
print(n)
o=torch.randint(1,10,[3,3])
print(o)
tensor([[0.5955, 0.8998, 0.2884],
[0.7077, 0.3734, 0.6439],
[0.7762, 0.7977, 0.2372]])
tensor([[0.9584, 0.1130, 0.0939],
[0.2921, 0.6334, 0.6916],
[0.7637, 0.8667, 0.4734]])
tensor([[7, 6, 2],
[3, 4, 7],
[5, 4, 3]])
正态分布
##正态分布
p=torch.randn(3,3)
print(p)
##自定义均值和方差 【3,3】
##自定义均值和方差 【3,3】
q=torch.normal(mean=torch.full([10],0.),std=torch.arange(1,0,-0.1),)
print(q)
tensor([[ 1.6483, 1.5185, 1.0172],
[ 0.0258, 0.6111, 1.5577],
[-1.1258, 1.7654, 0.7265]])
tensor([ 0.9079, -1.1444, 0.8479, 0.1685, -0.3165, 0.1183, -0.0872, 0.0806,
-0.0075, 0.0147])
full 全部赋值
##full
r=torch.full([2,3],7,dtype=torch.long)
print(r)
##full生成标量
s=torch.full([],7,dtype=torch.long)
print(s)
##带括号的向量元素
t=torch.full([1],7,dtype=torch.long)
print(t)
tensor([[7, 7, 7],
[7, 7, 7]])
tensor(7)
tensor([7])
递增递减生成等差数列
##arrange\range递增递减
u=torch.arange(0,10)
print(u)
v=torch.arange(0,10,2)
print(v)
##range不建议使用
w=torch.range(0,10)
print(w)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 2, 4, 6, 8])
tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
生成等分linspace/logspace
linspace包含尾部
logspace的base参数可以使设置为2,10,e等底数,返回*10的值
##生成等分
x=torch.linspace(0,10,4)
print(x)
y=torch.logspace(0,10,4)
print(y)
tensor([ 0.0000, 3.3333, 6.6667, 10.0000])
tensor([1.0000e+00, 2.1544e+03, 4.6416e+06, 1.0000e+10])
生成全部是0或1
##Ones/zeros/eye
z=torch.ones(3,3)
print(z)
A=torch.zeros(3,3)
print(A)
B=torch.eye(3,3)
print(B)
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
随机打散randperm
numpy中用random.shuffle
##随机打散
C=torch.randperm(10)
print(C)
##协同打散,两个参数打散的顺序是一样
D=torch.rand(2,3)
E=torch.rand(2,2)
idx=torch.randperm(2)
print(D[idx])
print(E[idx])
tensor([1, 3, 0, 2, 9, 7, 5, 4, 8, 6])
tensor([[0.4215, 0.8444, 0.9073],
[0.7867, 0.2199, 0.9827]])
tensor([[0.8101, 0.3268],
[0.9829, 0.5908]])
1.索引Indexing
从左边开始索引
##索引与切片
##图片
F=torch.rand(4,3,28,28)
#第一张图片的维度
print(F[0].shape)
#第一张图片的第1个通道
print(F[0,0].shape)
#第一张图片第一个通道,2行4列的像素点
print(F[0,0,2,4])
torch.Size([3, 28, 28])
torch.Size([28, 28])
tensor(0.5149)
取连续片段 select first/last N
图片R、G、B通道
##select first\last N
##第1张和第2张图片,包头不包尾
print(F[:2].shape)
##第一张,第二张图片,第1个通道
print(F[:2,:1,:,:].shape)
##第1、第2张图片,从后面数两个通道
print(F[:2,1:,:,:].shape)
##第1、第2张图片,从倒数第1个通道
print(F[:2,-1:,:,:].shape)
torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 2, 28, 28])
torch.Size([2, 1, 28, 28])
选取不连续,其中有一定的间隔 select by index
0:28: 等同于 0:28:1
: 当前的维度全取
n: \ :n n→最后 \ 0→n
:: 从头到尾,包头不包尾
0:28:2 0→28,步长2
通用方式:startstep
##连续有间隔
print(F[:,:,0:28:2,0:28:2].shape)
print(F[:,:,::2,::2].shape)
torch.Size([4, 3, 14, 14])
torch.Size([4, 3, 14, 14])
具体的索引select by specific index
#第一个维度,第二张、第三张图片,
print(F.index_select(0,torch.tensor([0,2])).shape)
#4张图片,第二个通道,list→tensor
print(F.index_select(1,torch.tensor([1,2])).shape)
#所有的行
print(F.index_select(2,torch.arange(28)).shape)
#0-7行
print(F.index_select(2,torch.arange(8)).shape)
torch.Size([2, 3, 28, 28])
torch.Size([4, 2, 28, 28])
torch.Size([4, 3, 28, 28])
torch.Size([4, 3, 8, 28])
任意多的维度…
有…出现时,右边的索引需要理解为最右边,为了方便
a[0,…]=a[0]
print(F[...].shape)
print(F[0,...,::2].shape)
print(F[0,:,:,::2].shape)
print(F[...,:2].shape)
print(F[:,:,:,:2].shape)
torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 14])
torch.Size([3, 28, 14])
torch.Size([4, 3, 28, 2])
torch.Size([4, 3, 28, 2])
select by mask
.masked_select(),mask会把数据打平,因为大于0.5的元素个数是根据内容才能确定的
ByteTensor类型
##mask
print('$$'*10)
G=torch.randn(3,4)
print(G)
#将大于0.5的数据替换为1,其他为0
mask=G.ge(0.5)
print(mask)
I=torch.masked_select(G,mask)
print(I)
print(I.shape)
tensor([[-0.6143, 1.0364, 0.1573, -0.4340],
[-0.6342, 0.1852, 0.5040, 0.0891],
[ 0.2618, 0.8350, 1.4133, 2.6830]])
tensor([[False, True, False, False],
[False, False, True, False],
[False, True, True, True]])
tensor([1.0364, 0.5040, 0.8350, 1.4133, 2.6830])
torch.Size([5])
把数组打平,取打平后的编码select by flatten index
使用少
##take
src=torch.tensor([[4,3,5],[6,7,8]])
print(src)
ss=torch.take(src,torch.tensor([0,2,5]))
print(ss)
tensor([[4, 3, 5],
[6, 7, 8]])
tensor([4, 5, 8])
1.常用的API
view/reshape
类似numpy中的reshape,改变shape
squeeze/unsqueeze
挤压删减、增加维度
transpose/t/permute
转置、单次交换、多次交换
expand/repeat
拓展
view&reshape,可以通用
view:前提,保证numel()一致即可
J=torch.rand(4,1,28,28)
print(J.shape)
#将通道和长宽打通在一起,常用作全连接层
K=J.view(4,28*28)
print(K)
print(K.shape)
torch.Size([4, 1, 28, 28])
tensor([[0.4228, 0.2521, 0.7874, …, 0.0730, 0.1397, 0.2620],
[0.3582, 0.2127, 0.5870, …, 0.4519, 0.1197, 0.6349],
[0.0022, 0.8474, 0.6384, …, 0.7844, 0.4403, 0.5294],
[0.0521, 0.6313, 0.8732, …, 0.1590, 0.4304, 0.9521]])
torch.Size([4, 784])
##将前三个通道合在一起,4*1*28,只关注所有行的数据信息
L=J.view(4,28*28)
print(L.shape)
##只关注来自哪个照片
M=J.view(4*28,28)
print(M.shape)
##丢失数据信息
N=J.view(4,784)
print(N.shape)
torch.Size([4, 784])
torch.Size([112, 28])
torch.Size([4, 784])
squeeze&unsqueeze
插入unsqueeze 取值范围 [-a.dim()-1,a.dim()+1]
数组顺序 [0 1 2 3 4 …… n]
倒序 [-n,…… -3 -2 -1]
包头不包尾
O=torch.rand(4,1,28,28)
print(O.shape)
print(O.unsqueeze(0).shape)
print(O.unsqueeze(-1).shape)
print(O.unsqueeze(4).shape)
print(O.unsqueeze(-4).shape)
print(O.unsqueeze(-5).shape)
torch.Size([4, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 1, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
P=torch.tensor([1.2,2.3])
print(P)
print(P.shape)
Q=P.unsqueeze(-1)
print(Q)
print(Q.shape)
R=P.unsqueeze(0)
print(R)
print(R.shape)
tensor([1.2000, 2.3000])
torch.Size([2])
tensor([[1.2000],
[2.3000]])
torch.Size([2, 1])
tensor([[1.2000, 2.3000]])
torch.Size([1, 2])
For example
bias相当于给每个channel上的所有像素增加一个偏置
bias=32
S=torch.rand(32)
T=torch.rand(4,32,14,14)
S=S.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(S.shape)
后续讲扩张,在于T相加、
torch.Size([1, 32, 1, 1])
squeeze
##不给参数能挤压的全部挤压
print(S.squeeze().shape)
print(S.squeeze(0).shape)
print(S.squeeze(-1).shape)
#由于第二个位置不为1,所以不能挤压,返回原来的样子
print(S.squeeze(1).shape)
print(S.squeeze(-4).shape)
torch.Size([32])
torch.Size([32, 1, 1])
torch.Size([1, 32, 1])
torch.Size([1, 32, 1, 1])
torch.Size([32, 1, 1])
Expand/repeat
Expand:broadcasting 扩展,并没有增加数据 节约内存,执行速度快,推荐
repeat:memory copied 复制,增加数据,会主动复制数据,不推荐
扩展的前提,dim一致
U=torch.rand(4,32,14,14)
print(U.shape)
V=torch.rand(1,32,1,1)
print(V.shape)
##只能拓展原来为1的地方
V=V.expand(4,32,14,14)
print(V.shape)
##为-1时,表示此处不做拓展
V=V.expand(-1,32,-1,-1)
print(V.shape)
torch.Size([4, 32, 14, 14])
torch.Size([1, 32, 1, 1])
torch.Size([4, 32, 14, 14])
torch.Size([4, 32, 14, 14])
repeat
在某一个维度上拷贝多少次
W=torch.rand(1,32,1,1)
print(W.repeat(4,32,1,1).shape)
print(W.repeat(4,1,1,1).shape)
print(W.repeat(4,1,32,32).shape)
torch.Size([4, 1024, 1, 1])
torch.Size([4, 32, 1, 1])
torch.Size([4, 32, 32, 32])
transpose/t/permute
转置操作
.t 只适用于矩阵
X=torch.randn(3,4)
print(X)
print(X.t())
tensor([[ 0.7207, -1.3770, -0.7291, -1.1218],
[-0.3498, -1.6037, -2.2684, -0.4572],
[ 0.6358, 1.4500, 0.4431, 0.2253]])
tensor([[ 0.7207, -0.3498, 0.6358],
[-1.3770, -1.6037, 1.4500],
[-0.7291, -2.2684, 0.4431],
[-1.1218, -0.4572, 0.2253]])
transpose交换维度
数据的维度顺序必须和存储顺序一致
view会导致维度顺序关系变模糊,所以需要认为跟踪
##将维度进行交换后,如果view后,需要人为的换回来,防止数据丢失的情况
Y=torch.rand(4,3,32,32)
##contiguous将数据重新变成连续的数据
Z1=Y.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
Z2=Y.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
print(Y.shape)
print(Z1.shape)
print(Z2.shape)
##Z2破坏了维度的数据
Z=Y.transpose(1,3)
print(Z.shape)
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 32, 32, 3])
permute
[b h w c]是numpyu存储图片的格式,需要这一步才能到处numpy
aa=torch.rand(4,3,28,32)
print(aa.shape)
print(aa.transpose(1,3).transpose(1,2).shape)
print(aa.permute(0,2,3,1).shape)
torch.Size([4, 3, 28, 32])
torch.Size([4, 28, 32, 3])
torch.Size([4, 28, 32, 3])
Expand
without copying data
size一致,可进行对应位置元素相加
API:cat
stack
split
chunk
cat
cat的维度可以不一样,非cat的维度必须一样
a=torch.rand(4,32,8)
b=torch.rand(5,32,8)
c=torch.cat([a,b],dim=0)
print(c.shape)
torch.Size([9, 32, 8])
stack
维度必须完全一致
会创建一个新的维度,新维度等于0时,为第一个数组,新维度等于1时,为第二个数组
a=torch.rand(5,32,8)
b=torch.rand(5,32,8)
c=torch.cat([a,b],dim=0)
print(c.shape)
d=torch.stack([a,b],dim=2)
print(d.shape)
torch.Size([10, 32, 8])
torch.Size([5, 32, 2, 8])
split&chunk
类似拆分的操作
split
根据长度来拆分
e=torch.rand(4,32,8)
f,g=e.split([3,1],dim=0)
print(f.shape)
print(g.shape)
h,i=e.split(2,dim=0)
print(h.shape)
print(i.shape)
torch.Size([3, 32, 8])
torch.Size([1, 32, 8])
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])
chunk
根据数量来拆分
l,m=e.chunk(2,dim=0)
print(l.shape)
print(m.shape)
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])
API:
add/minus/multiply/divide
matmul
pow
sqrt/rsqrt
round
±*/ 可以直接使用运算符号
n=torch.rand(3,4)
o=torch.rand(4)
print(n+o)
print(torch.add(n,o))
print(torch.all(torch.eq(n-o,torch.sub(n,o))))
print(torch.all(torch.eq(n*o,torch.mul(n,o))))
print(torch.all(torch.eq(n/o,torch.div(n,o))))
tensor([[1.1736, 0.6848, 0.7556, 0.5737],
[1.5681, 0.8083, 0.3853, 0.9843],
[1.0794, 0.7089, 1.1611, 0.8224]])
tensor([[1.1736, 0.6848, 0.7556, 0.5737],
[1.5681, 0.8083, 0.3853, 0.9843],
[1.0794, 0.7089, 1.1611, 0.8224]])
tensor(True)
tensor(True)
tensor(True)
matmul
矩阵相乘
torch.mm
只适合2d,不推荐
torch.matmul=@ @使用较少
p=torch.ones(2,2)
q=torch.ones(2,2)
print(p)
print(q)
print(torch.mm(p,q))
print(torch.matmul(p,q))
print(p@q)
tensor([[1., 1.],
[1., 1.]])
tensor([[1., 1.],
[1., 1.]])
tensor([[2., 2.],
[2., 2.]])
tensor([[2., 2.],
[2., 2.]])
tensor([[2., 2.],
[2., 2.]])
利用matmul矩阵相乘时,若前面几维相同,只对后面两维进行相乘,其他保持不变
若前面几维不相同,先利用broadcasting,再进行后面两维相乘。
r=torch.rand(4,3,28,64)
s=torch.rand(4,3,64,32)
print(torch.matmul(r,s).shape)
t=torch.rand(4,1,64,32)
print(torch.matmul(r,t).shape)
torch.Size([4, 3, 28, 32])
torch.Size([4, 3, 28, 32])
Power 次方的运算= **
.sqrt()平方根
u=torch.full([2,2],3,dtype=torch.long)
v=u.pow(2)
print(v)
print(u**2)
##Long类型的数据不支持log对数运算,为什么Tensor是Long类型?
#因为创建List数组时默认使用的是int,suoyi List转成torch.Tensor后,数据类型变成了Long
##解决方法:提前将数据类型指定为浮点型
w=torch.sqrt(v.to(torch.double))
print(w)
print(v**(0.5))
tensor([[9, 9],
[9, 9]])
tensor([[9, 9],
[9, 9]])
tensor([[3., 3.],
[3., 3.]], dtype=torch.float64)
tensor([[3., 3.],
[3., 3.]])
log
幂次方根
x=torch.exp(torch.ones(2,2))
print(x)
print(torch.log(x))
tensor([[2.7183, 2.7183],
[2.7183, 2.7183]])
tensor([[1., 1.],
[1., 1.]])
Approximation 近似
.floor()向下近似
.ceil*()向上近似
.round()四舍五入
.trunc()保留整数
.frac()保留小数
y=torch.tensor(3.14)
print(y.floor())
print(y.ceil())
print(y.trunc())
print(y.frac())
z=torch.tensor(3.4899)
print(z.round())
aa=torch.tensor(3.5)
print(aa.round())
tensor(3.)
tensor(4.)
tensor(3.)
tensor(0.1400)
tensor(3.)
tensor(4.)
clamp
裁剪
场景:gradient clipping梯度裁剪
.clamp(min)
.clamp(min,max)
grad=torch.rand(2,3)*15
print(grad.max())
print(grad.clamp(10))
print(grad.clamp(0,10))
tensor(13.5087)
tensor([[10.0000, 10.0000, 10.0000],
[13.5087, 10.8398, 10.6826]])
tensor([[ 1.4251, 4.0614, 0.2534],
[10.0000, 10.0000, 10.0000]])
statistics
API:
norm
mean sum
prod
max,min, argmin, argmax
kthvalue, topk
norm范数
norm范数vs normalize正则化
norm-p
bb=torch.full([8],1,dtype=torch.long)
cc=bb.view(2,4)
dd=bb.view(2,2,2)
print(cc)
print(dd)
bb=bb.to(torch.double)
cc=cc.to(torch.double)
dd=dd.to(torch.double)
print(bb.norm(1))
print(cc.norm(1))
print(dd.norm(1))
print(cc.norm(1,dim=1))
print(cc.norm(2,dim=1))
print(dd.norm(1,dim=1))
print(dd.norm(2,dim=1))
tensor([[1, 1, 1, 1],
[1, 1, 1, 1]])
tensor([[[1, 1],
[1, 1]],
[[1, 1],
[1, 1]]])
tensor(8., dtype=torch.float64)
tensor(8., dtype=torch.float64)
tensor(8., dtype=torch.float64)
tensor([4., 4.], dtype=torch.float64)
tensor([2., 2.], dtype=torch.float64)
tensor([[2., 2.],
[2., 2.]], dtype=torch.float64)
tensor([[1.4142, 1.4142],
[1.4142, 1.4142]], dtype=torch.float64)
mean, sum, min, max, prod,argmax,argmin
统计属性
argmax,argmin 最大值、最小值的位置
max,min 返回最大值最小值,以及其位置
print('**.**')
ee=torch.arange(8).view(2,4).float()
print(ee.min())
print(ee.max())
print(ee.mean())
print(ee.prod())
print(ee.sum())
print(ee.argmax())
print(ee.argmin())
tensor(0.)
tensor(7.)
tensor(3.5000)
tensor(0.)
tensor(28.)
tensor(7)
tensor(0)
如果不打平,需要指定维度
.argmax(dim= )
dim keepdim
keepdim:使得dim保持一致
dim指定维度进行统计值
Top-k k-th
top-k 返回挨着顺序返回,最大值,次大值,第三大……
k-thvalue 第k个的value
ff=torch.rand(4,10)
print(ff.topk(3,dim=1))
##从最小的来
print(ff.topk(3,dim=1,largest=False))
print(ff.kthvalue(8,dim=1))
print(ff.kthvalue(3))
print(ff.kthvalue(3,dim=1))
torch.return_types.topk(
values=tensor([[0.9040, 0.8903, 0.8289],
[0.9646, 0.9002, 0.8611],
[0.9200, 0.7341, 0.6944],
[0.9485, 0.9281, 0.8831]]),
indices=tensor([[4, 8, 7],
[7, 0, 6],
[0, 1, 8],
[1, 3, 5]]))
torch.return_types.topk(
values=tensor([[0.0208, 0.1013, 0.2037],
[0.0050, 0.0072, 0.1985],
[0.0533, 0.0768, 0.0842],
[0.0490, 0.1635, 0.1881]]),
indices=tensor([[0, 2, 9],
[3, 2, 1],
[7, 3, 4],
[7, 8, 9]]))
torch.return_types.kthvalue(
values=tensor([0.8289, 0.8611, 0.6944, 0.8831]),
indices=tensor([7, 6, 8, 5]))
torch.return_types.kthvalue(
values=tensor([0.2037, 0.1985, 0.0842, 0.1881]),
indices=tensor([9, 1, 4, 9]))
torch.return_types.kthvalue(
values=tensor([0.2037, 0.1985, 0.0842, 0.1881]),
indices=tensor([9, 1, 4, 9]))
compare
, < >= ,<= , != ,==
.eq(,) 判断是否两个元素中的数组是否相等
.equal 判断两个元素是否完全相等
print(torch.gt(ff,0))
print(ff!=0)
gg=torch.ones(2,3)
hh=torch.randn(2,3)
print(torch.eq(gg,hh))
print(torch.eq(gg,gg))
print(torch.equal(gg,gg))
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
tensor([[False, False, False],
[False, False, False]])
tensor([[True, True, True],
[True, True, True]])
True
Process finished with exit code 0
where gather
where
cond=torch.rand(2,2)
print(cond)
ii=torch.ones(2,2)
jj=torch.zeros(2,2)
kk=torch.where(cond>0.5,ii,jj)
print(kk)
tensor([[0.7220, 0.5608],
[0.3790, 0.5327]])
tensor([[1., 1.],
[0., 1.]])
gather
查表的过程
prob=torch.randn(4,10)
idx=prob.topk(dim=1,k=3)
print(idx)
idx=idx[1]
print(idx)
label=torch.arange(10)+100
print(label)
print(torch.gather(label.expand(4,10),dim=1,indx=idx.long()))
torch.return_types.topk(
values=tensor([[1.4600, 0.9766, 0.7980],
[1.6359, 0.9185, 0.8949],
[1.7059, 1.1027, 0.0398],
[1.1878, 1.1326, 0.6822]]),
indices=tensor([[0, 4, 3],
[0, 4, 5],
[0, 6, 7],
[2, 5, 8]]))
tensor([[0, 4, 3],
[0, 4, 5],
[0, 6, 7],
[2, 5, 8]])
tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])