目录
1--torch.nn.init.normal()和torch.nn.init.constant_()函数
2--torch.from_numpy()函数
3--torch.index_select()函数
torch.nn.init.normal()函数:torch.nn.init.normal(tensor, mean, std)基于输入参数(均值mean和标准差std)初始化输入张量tensor; torch.nn.init.constant_()函数:torch.nn.init.constant_(tensor, val)基于输入参数(val)初始化输入张量tensor,即tensor的值均初始化为val。
torch.from_numpy()函数:torch.from_numpy(numpy_array)基于输入numpy数组(numpy_array)返回一个tensor张量(数据不变,类型转换),作用是转换numpy_array -> tensor。
torch.index_select()函数:torch.index_select(input, dim, index, out=None),从input(tensor类型)指定维度dim中,根据索引号集合(index)返回tensor数据。代码示例如下:
import torch
A = torch.randn(3,3)
print(A)
B = torch.index_select(input = A, dim = 0, index = torch.tensor([2, 1, 0]))
print(B)
C = torch.index_select(input = A, dim = 1, index = torch.tensor([2, 1, 0]))
print(C)
# Result
# A
tensor([[-0.3668, -0.3146, -0.6733],
[ 0.9665, 0.0659, 0.3673],
[ 0.5270, 0.4577, 0.4240]])
# B
tensor([[ 0.5270, 0.4577, 0.4240],
[ 0.9665, 0.0659, 0.3673],
[-0.3668, -0.3146, -0.6733]])
# C
tensor([[-0.6733, -0.3146, -0.3668],
[ 0.3673, 0.0659, 0.9665],
[ 0.4240, 0.4577, 0.5270]])