Matlab与Python的reshape使用区别

经过测试,发现二维的话,Python需要先转置再用reshape。

三维的话,Matlab则要对每一页先转置展开为一维,然后再把每一页拼起来,然后再按列往新数组中填充个,具体如下代码,Python的结果和Matlab一致,函数支持一维转二维,二维转二维,二维转三维,三维转二维,三维转三维。注意:在Maltab中假设数组形状为a,b,c,则在Python要改为c,a,b。

import numpy as np


def myreshape(x: np.ndarray, dim: tuple) -> np.ndarray:
    flag = False
    if np.iscomplexobj(x):
        flag = True
    if flag:
        res = np.zeros(dim, dtype=complex)
    else:
        res = np.zeros(dim)
    if len(x.shape) == 1:
        if len(dim) == 2:
            m, n = dim
            temp = x.flatten()
            for i in range(n):
                res[:, i] = temp[m * i:m * (i + 1)]
    elif len(x.shape) == 2:
        if len(dim) == 2:
            m, n = dim
            res = x.T.reshape((m, n))
        else:
            l, m, n = dim
            temp = x.T.flatten()
            idx = 0
            for i in range(l):
                for j in range(n):
                    res[i, :, j] = temp[m * idx:m * (idx + 1)]
                    idx += 1
    else:
        if len(dim) == 2:
            m, n = dim
            l1, m1, n1 = x.shape
            if flag:
                temp = np.zeros(l1 * m1 * n1, dtype=complex)
            else:
                temp = np.zeros(l1 * m1 * n1)
            for i in range(l1):
                temp[(m1 * n1) * i:(m1 * n1) * (i + 1)] = x[i, :, :].T.ravel()
            for i in range(n):
                res[:, i] = temp[m * i:m * (i + 1)]
        else:
            l, m, n = dim
            l1, m1, n1 = x.shape
            if flag:
                temp = np.zeros(l1 * m1 * n1, dtype=complex)
            else:
                temp = np.zeros(l1 * m1 * n1)
            for i in range(l1):
                temp[(m1 * n1) * i:(m1 * n1) * (i + 1)] = x[i, :, :].T.ravel()
            idx = 0
            for i in range(l):
                for j in range(n):
                    res[i, :, j] = temp[m * idx:m * (idx + 1)]
                    idx += 1
    return res


a = np.array([[[1j, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]])
b = myreshape(a, (2, 3, 2))
aa = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
print(aa.T.reshape(1, -1))
print(aa.reshape(-1, 1))
c = np.iscomplexobj(a)

你可能感兴趣的:(Matlab-Python,python,matlab,numpy)