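To understand what the argument of transpose means, first look at the shape attribute of X.
X.shape returns the tuple (2,3,4), exactly as you set it with reshape(2,3,4).
The indices of the tuple (2,3,4) are (0,1,2): index 0 corresponds to 2, index 1 to 3, and index 2 to 4.
This tuple of indices into X's shape is what the argument of transpose really means.
So the argument (1,0,2) passed to transpose in the code can be read as a tuple of such indices:
index 1 still corresponds to 3, index 0 still to 2, and index 2 still to 4, so reordering the indices this way gives X the shape (3,2,4).
************** separator **************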
If the argument of transpose is (0,2,1), the shape of X becomes (2,4,3).
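Both shape claims can be checked directly. This is a minimal sketch that assumes X is np.arange(24).reshape(2,3,4), because the original array's contents are not shown and only its shape matters here. The helper `permuted_shape` is a hypothetical name used for illustration. It spells out the rule: axis i of the result takes the size of axis axes[i] of the input.

```python
import numpy as np

# Assumption: X holds 0..23 in shape (2, 3, 4); only the shape matters here
X = np.arange(24).reshape(2, 3, 4)

def permuted_shape(shape, axes):
    # Hypothetical helper: axis i of the result gets the size of axis axes[i]
    return tuple(shape[ax] for ax in axes)

for axes in [(1, 0, 2), (0, 2, 1)]:
    print(axes, X.transpose(axes).shape, permuted_shape(X.shape, axes))
# (1, 0, 2) (3, 2, 4) (3, 2, 4)
# (0, 2, 1) (2, 4, 3) (2, 4, 3)
```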
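************** separator **************
Before the transpose, every number has an index; for example, the index of 13 is (1,0,1).
After the transpose (1,0,2) above, the first two components of each index swap places, so the index of 13 becomes (0,1,1).
If you print the transposed array, you will find 13 at exactly that position. Every other number behaves the same way; for example, 19 had the index (1,1,3) before,
and after the same transpose its index is still (1,1,3): its first two components are both 1, so swapping them changes nothing.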
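You can locate both numbers before and after the transpose to confirm the index changes. This sketch again assumes X is np.arange(24).reshape(2,3,4). That assumption is consistent with 13 sitting at (1,0,1), and it means every value occurs exactly once.

```python
import numpy as np

# Assumption: X is np.arange(24).reshape(2, 3, 4), so every value occurs once
X = np.arange(24).reshape(2, 3, 4)
Y = X.transpose(1, 0, 2)

for value in (13, 19):
    # np.argwhere lists the indices of matching elements; take the only match
    before = tuple(int(i) for i in np.argwhere(X == value)[0])
    after = tuple(int(i) for i in np.argwhere(Y == value)[0])
    print(value, before, after)
# 13 (1, 0, 1) (0, 1, 1)
# 19 (1, 1, 3) (1, 1, 3)
```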
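import numpy as np

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """
    Parameters
    ----------
    input_data : 4-D input array of shape (batch size, channels, height, width)
    filter_h : filter (kernel) height
    filter_w : filter (kernel) width
    stride : stride
    pad : padding

    Returns
    -------
    col : 2-D array of shape (N*out_h*out_w, C*filter_h*filter_w)
    """
    # Shape of the input data
    # N: batch size, C: number of channels, H: input height, W: input width
    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1  # output height
    out_w = (W + 2*pad - filter_w)//stride + 1  # output width

    # Zero-pad the H and W axes only
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    # Zero array of shape (N, C, filter_h, filter_w, out_h, out_w)
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            # Pixels touched by kernel position (y, x), for every output position
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    # Debug output: shape and contents before the transpose
    print("*********")
    print(col.shape)
    print(col[0])
    print("*********")
    # print(col[0][2][4][4][1][1])

    # Reorder the axes of col as (0, 4, 5, 1, 2, 3):
    # (N, C, filter_h, filter_w, out_h, out_w) -> (N, out_h, out_w, C, filter_h, filter_w)
    col = col.transpose(0, 4, 5, 1, 2, 3)
    print(col.shape)
    print(col[0])
    # The element at col[0][2][4][4][1][1] before the transpose is now here:
    # print(col[0][1][1][2][4][4])

    # Then change the shape: one row per output position
    col = col.reshape(N*out_h*out_w, -1)
    return col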
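The effect of transpose(0, 4, 5, 1, 2, 3) on col can be seen without running the whole function. This is a small sketch on a random 6-D array. The sizes N=1, C=3, a 5x5 filter, and a 2x2 output are assumptions chosen so that the indices in the commented-out print lines are valid. It shows that element (n, c, y, x, i, j) moves to (n, i, j, c, y, x). After the reshape, each row then holds one receptive field.

```python
import numpy as np

# Assumed sizes: N=1, C=3, filter_h=filter_w=5, out_h=out_w=2
N, C, fh, fw, oh, ow = 1, 3, 5, 5, 2, 2
rng = np.random.default_rng(0)
col = rng.random((N, C, fh, fw, oh, ow))

t = col.transpose(0, 4, 5, 1, 2, 3)   # -> (N, out_h, out_w, C, filter_h, filter_w)
print(t.shape)                        # (1, 2, 2, 3, 5, 5)

# Element (n, c, y, x, i, j) moves to (n, i, j, c, y, x)
print(t[0, 1, 1, 2, 4, 4] == col[0, 2, 4, 4, 1, 1])   # True

# After the reshape, row n*oh*ow + i*ow + j holds all C*fh*fw values
# for output position (i, j); here n=0, i=1, j=1 gives row 3
rows = t.reshape(N * oh * ow, -1)
print(rows.shape)                                      # (4, 75)
print(np.array_equal(rows[3], col[0, :, :, :, 1, 1].ravel()))  # True
```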
Pick the elements at indices (0,0,1), (0,1,0), and (1,0,0) of the tensor and see where they end up, that is, what their indices become.
As given in the documentation:
numpy.transpose(a, axes=None)
axes : list of ints, optional. By default, reverse the dimensions, otherwise permute the axes according to the values given.
The second argument gives the axes by which the values are permuted. For example, if the index of an element in the original array is (x,y,z) (where x is the index along axis 0, y along axis 1, and z along axis 2), then, based on the argument you provided for axes, the position of that element in the resulting array becomes (z,y,x): axis 2 first, then axis 1, and axis 0 last.
Since you are transposing an array of shape (2,2,2), the transposed shape is also (2,2,2), and the positions change as follows:
(0,0,0) -> (0,0,0)
(1,0,0) -> (0,0,1)
(0,1,0) -> (0,1,0)
(1,1,0) -> (0,1,1)
...
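The mapping in this list can be checked for all eight positions. This is a minimal sketch with two assumptions: the array holds 0..7, since the original array isn't shown, and the axes argument was (2, 1, 0). Per the quoted documentation, (2, 1, 0) is also what np.transpose does when axes is omitted.

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)   # assumed contents; any (2, 2, 2) array works
b = np.transpose(a, (2, 1, 0))      # reversed axes

# Omitting axes reverses the dimensions, so this matches the default (and a.T)
assert np.array_equal(b, np.transpose(a))
assert np.array_equal(b, a.T)

for idx in np.ndindex(a.shape):
    new_idx = idx[::-1]             # (x, y, z) -> (z, y, x)
    assert b[new_idx] == a[idx]
    print(idx, "->", new_idx)
```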
Since the axes you chose are trivial, let's explain this with a different axes order. Example:
In [54]: A = np.arange(30).reshape((2, 3, 5))
In [55]: A
Out[55]:
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]],

       [[15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]])
In [56]: np.transpose(A,(1,2,0))
Out[56]:
array([[[ 0, 15],
        [ 1, 16],
        [ 2, 17],
        [ 3, 18],
        [ 4, 19]],

       [[ 5, 20],
        [ 6, 21],
        [ 7, 22],
        [ 8, 23],
        [ 9, 24]],

       [[10, 25],
        [11, 26],
        [12, 27],
        [13, 28],
        [14, 29]]])
Here, the first element (0,0,0) becomes the (0,0,0) element in the result.
The second element (0,0,1) becomes the (0,1,0) element in the result, and so on.
Here's a little more clarification:
Don't confuse the arguments of np.reshape(A, (z, y, x)) with those of np.transpose(A, (0, 1, 2)).
np.reshape() uses the sizes of our matrix's dimensions, think (sheets, rows, columns), to specify its layout.
np.transpose() uses the integers 0, 1, and 2 to name the axes we want to reorder; they correspond to z, y, and x, respectively.
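Here is what the two kinds of arguments look like in code. The sketch uses the same 2x3x5 data. np.transpose takes the array first and the axis order as a tuple, while the ndarray.transpose method also accepts the axis numbers as separate arguments.

```python
import numpy as np

A = np.arange(30)

# reshape: the arguments are sizes of the new dimensions (sheets, rows, columns)
M = A.reshape(2, 3, 5)
assert np.array_equal(M, np.reshape(A, (2, 3, 5)))

# transpose: the arguments are axis numbers 0, 1, 2, not sizes
T1 = M.transpose(1, 0, 2)          # method form: axes as separate ints
T2 = np.transpose(M, (1, 0, 2))    # function form: array first, axes as a tuple
assert np.array_equal(T1, T2)
print(M.shape, T1.shape)           # (2, 3, 5) (3, 2, 5)
```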
For example, suppose we have data in a matrix of 2 sheets, 3 rows, and 5 columns.
Thinking in terms of nested lists, the z, y, x (sheets, rows, columns) representation of this 2x3x5 matrix is:
[[[ 0, 1, 2, 3, 4],
  [ 5, 6, 7, 8, 9],
  [10, 11, 12, 13, 14]],
 [[15, 16, 17, 18, 19],
  [20, 21, 22, 23, 24],
  [25, 26, 27, 28, 29]]]
...but the module we're feeding this data into requires a layout in which sheet 1 contains the first row of each of our sheets, sheet 2 contains the second row of each, and so on. We then need to transpose our data with np.transpose(A, (1, 0, 2)). This swaps the z and y axes, transposing the data:
[[[ 0, 1, 2, 3, 4],
  [15, 16, 17, 18, 19]],
 [[ 5, 6, 7, 8, 9],
  [20, 21, 22, 23, 24]],
 [[10, 11, 12, 13, 14],
  [25, 26, 27, 28, 29]]]
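The required layout can be stated as a check: new sheet i equals row i taken from every original sheet, which is A[:, i, :]. This is a short sketch using the same 2x3x5 data.

```python
import numpy as np

A = np.arange(30).reshape(2, 3, 5)
T = np.transpose(A, (1, 0, 2))     # shape (3, 2, 5)

# New sheet i stacks row i of every original sheet
for i in range(A.shape[1]):
    assert np.array_equal(T[i], A[:, i, :])
print(T.shape)                     # (3, 2, 5)
```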
Note the difference from np.reshape(A, (3, 2, 5)): reshape doesn't transpose the data; it only regroups the same sequence of elements into the new shape:
[[[ 0, 1, 2, 3, 4],
  [ 5, 6, 7, 8, 9]],
 [[10, 11, 12, 13, 14],
  [15, 16, 17, 18, 19]],
 [[20, 21, 22, 23, 24],
  [25, 26, 27, 28, 29]]]
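One way to see the difference is to flatten both results. Both arrays have shape (3, 2, 5). The reshaped array reads out in the same order as A, while the transposed one does not. This is a small sketch on the same 2x3x5 data.

```python
import numpy as np

A = np.arange(30).reshape(2, 3, 5)
R = A.reshape(3, 2, 5)           # regroups elements; reading order unchanged
T = A.transpose(1, 0, 2)         # moves elements to new positions

print(R.shape == T.shape)                     # True: both are (3, 2, 5)
print(np.array_equal(R.ravel(), A.ravel()))   # True: same reading order as A
print(np.array_equal(T.ravel(), A.ravel()))   # False: the order has changed
```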