np.expand_dims函数

即扩展维度,np.expand_dims(a,axis=)即在相应的axis轴上扩展维度
a = np.array([[1,2],[3,5]])

b=(a==0).astype(np.float)

y = np.expand_dims(a, axis=2)
z = np.expand_dims(a, axis=1)
print(a.shape)
print(y.shape)
print(z.shape)
输出
(2, 2)
(2, 2, 1)
(2, 1, 2)

你可能感兴趣的:(np.expand_dims函数)