pytorch 实现施密特正交化

对于一个给定的pytorch张量,如果想要对这里面的每一个行向量进行施密特正交化,可以使用下面的代码:

import numpy as np
from sympy.matrices import Matrix, GramSchmidt
import torch
import torch.nn.functional as F

def orthogo_tensor(x):
    m, n = x.size()
    x_np = x.t().numpy()
    matrix = [Matrix(col) for col in x_np.T]
    gram = GramSchmidt(matrix)
    ort_list = []
    for i in range(m):
        vector = []
        for j in range(n):
            vector.append(float(gram[i][j]))
        ort_list.append(vector)
    ort_list = np.mat(ort_list)
    ort_list = torch.from_numpy(ort_list)
    ort_list = F.normalize(ort_list,dim=1)
    return ort_list

x = torch.randn(4,6)
x = orthogo_tensor(x)
print(x)
print(x.matmul(x.t()))

程序运行结果如下所示:

tensor([[-0.1633,  0.5601, -0.6081, -0.2640, -0.4215, -0.2061],
        [ 0.0386, -0.0236, -0.0674,  0.4131, -0.5576,  0.7154],
        [-0.3210,  0.6896,  0.5430, -0.1018,  0.1794,  0.2898],
        [-0.3445, -0.3674,  0.3962, -0.5217, -0.5563, -0.0886]],
       dtype=torch.float64)
tensor([[ 1.0000e+00, -3.6843e-09,  1.5446e-08, -7.1455e-08],
        [-3.6843e-09,  1.0000e+00, -2.1965e-08, -3.5087e-08],
        [ 1.5446e-08, -2.1965e-08,  1.0000e+00, -2.0042e-07],
        [-7.1455e-08, -3.5087e-08, -2.0042e-07,  1.0000e+00]],
       dtype=torch.float64)

你可能感兴趣的:(pytorch)