【pytorch tensorflow numpy对照表】

tensorflow

numpy

pytorch

作用

类型

用法

基本操作

  np.array([x,y],type) torch.Tensor([x,y]) 从list转*,类型转化 矩阵  
  np.zeros((x,y))   全部为0,形状是(x,y) 矩阵  
    torch.from_numpy(ndarry) 从ndarry转*,类型转化 矩阵  
    x.numpy() 从*转ndarry,类型转化 矩阵  
tf.expand_dims(x,axis=1) x[:, np.newaxis] x.unsqueeze(0) 添加一个维度    
    x.squeeze(0) 减少一个维度    

统计

tf.reduce_mean(tensor,axis) np.mean() torch.mean(tensor,axis) 取均值 矩阵 tf详情
tf.reduce_sum(tensor,axis) np.sum()

torch.sum(tensor,axis)

求和 矩阵 tf详情
 

np.max()

torch.max(tensor,axis) 最大值 矩阵  
  np.min() torch.min(tensor,axis) 最小值 矩阵  
  np.argmax() x.argmax() 最大值索引 矩阵  
  np.argmin() x.argmin() 最小值索引 矩阵  
  np.std()   标准偏差 矩阵  
  np.var()   方差 矩阵  
  np.cumprod() x.prod() 累乘积值 矩阵  
  np.linalg.norm(x) torch.norm(tensor, p=2)  范数 矩阵  

算法

  x.sort()   排序 矩阵  
  x * b   广播 矩阵  
  np.where(条件,x,y)   满足条件- x,不满足- y    

运算

tf.square(tensor)   torch.square(tensor) x**2求平方 矩阵 tf详情、torch详情
  np.sqrt(x)   x**(1/2) 求平方根 矩阵  
  x.T   x转置 矩阵(axis=2),向量  
  x.transpose((1,0,2)) 或 x.swapaxes(1,2)   高维转置 矩阵(axis>2)  
tf.subtract(x,y) x - y 或 np.subtract(x, y) torch.sub(x, y) (x - y) 矩阵 tf详情、
tf.add(x,y) x + y 或 np.add(x, y) torch.add(x, y) (x + y) 矩阵  
x * y 或 tf.multiply(x, y) x * y torch.mul(x, y) 或 x.mul(y) (x * y)  点乘 各个元素相乘 矩阵  
tf.matmul(x,y) np.dot(w,x) 或 np.multiply(x, y) x.mm(y) (x * y)  叉乘 矩阵  
  x / y 或 np.divide(x, y) torch.div(x, y) (x / y) 矩阵  
tf.maximum(x,y)   torch.maximum(x,y) 求x,y最大值 实数  

矩阵其他操作

tf.stack   torch.stack([a,b]) 拼接,[a,b] 矩阵  
  np.concatenate((a,b],axis) torch.cat([a,b],axis) 拼接,axis必须存在 矩阵  
tf.unstack x[:,:,-1] torch.chunk(tensor,num,dim=0) 分解 矩阵  
tf.reshape np.reshape torch.reshape   矩阵  
tf.sequence_mask()     掩码    

 

你可能感兴趣的:(【pytorch tensorflow numpy对照表】)