numpy与pytorch常用对应算子一览

最近将之前用numpy写的解码头写入了网络模型,顺便将迁移时用到的numpy与pytorch的一些对应算子记录于此

np.meshgrid 与 torch.meshgrid

numpy 中: X, Y = np.meshgrid(np.arange(x), np.arange(y))

若 x == 5 与 y == 4

X == np.array([[0., 1., 2., 3., 4.],
               [0., 1., 2., 3., 4.],
               [0., 1., 2., 3., 4.],
               [0., 1., 2., 3., 4.]])

Y == np.array([[0., 0., 0., 0., 0.],
               [1., 1., 1., 1., 1.],
               [2., 2., 2., 2., 2.],
               [3., 3., 3., 3., 3.]])

与之对应的

pytorch 中: Y, X = torch.meshgrid(torch.arange(y), torch.arange(x), indexing='ij')

np.concatenate 与 torch.cat

np.concatenate((a, b), axis=-1) 与 torch.cat((a, b), dim=-1) 效果相同

np.expand_dims 与 torch.unsqueeze

np.expand_dims(x, axis=-1) 与 torch.unsqueeze(x, dim=-1) 效果相同

np.squeeze 与 torch.squeeze

np.squeeze(x, axis=-1) 与 torch.squeeze(x, dim=-1) 效果相同

np.split 与 torch.split 和 torch.chunk

试图在某一维上等分成x分时,np.split 与 torch.chunk 对应,即 

np.split(array, x, axis=0) 与 torch.chunk(tensor, x, dim=0) 效果相同

试图在某一维上按x(一个列表,numpy中代表索引,pytorch中代表切分长度)切分时,np.split 与 torch.split 对应,即 

np.split(array, [1, 3, 5], axis=0) 与 torch.split(tensor, [1, 2, 2, 1], dim=0) 效果相同

np.dot 与 torch.mm

两者在结果上可能存在转置关系(即可能需要.T),迁移时还需验证

array.repeat 与 tensor.repeat_interleave

array.repeat(t) 与 tensor.repeat_interleave(t) 效果相同, 其中t为重复次数

x = np.array([0, 1])
print(x.repeat(3))
>>> [0 0 0 1 1 1]

后续会继续补充...

你可能感兴趣的:(numpy,pytorch,python)