torch的拼接函数_pytorch常用函数总结(持续更新)

pytorch常用函数总结

torch.max(input,dim)

求取指定维度上的最大值,,返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。比如:

demo.shape

Out[7]: torch.Size([10, 3, 10, 10])

torch.max(demo,1)[0].shape

Out[8]: torch.Size([10, 10, 10])

torch.max(demo,1)[0]这其中的[0]取得就是返回的最大值,torch.max(demo,1)[1]就是返回的最大值对应的位置索引。例子如下:

a

Out[8]:

tensor([[1., 2., 3.],

[4., 5., 6.]])

a.max(1)

Out[9]:

torch.return_types.max(

values=tensor([3., 6.]),

indices=tensor([2, 2]))

class torch.nn.ParameterList(parameters=None)

将submodules保存在一个list中。

ParameterList可以像一般的Python list一样被索引。而且ParameterList中包含的parameters已经被正确的注册,对所有的module method可见。

参数说明:

modules (list, optional) – a list of nn.Parameter

例子:

class MyModule(nn.Module):

def __init__(self):

super(MyModule, self).__init__()

self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

def forward(self, x):

# ModuleList can act as an iterable, or be indexed using ints

for i, p in enumerate(self.params):

x = self.params[i // 2].mm(x) + p.mm(x)

return x

torch.cat()函数

cat是concatnate的意思:拼接,联系在一起。

先说cat( )的普通用法

如果我们有两个tensor是A和B,想把他们拼接在一起,需要如下操作:

C = torch.cat( (A,B),0 ) #按维数0拼接(竖着拼)

C = torch.cat( (A,B),1 ) #按维数1拼接(横着拼)

相当于将tensor按照指定维度进行拼接,比如A的shape为128*64*32*32,B的shape为 128*32*64*64,那么按照 torch.cat( (A,B),1)拼接的之后的形状为 128*96*64*64。

注意:

两个tensor要想进行拼接,必须保证除了指定拼接的维度以外其他的维度形状必须相同,比如上面的例子,拼接A和B时,A的形状为128*64*32*32,B的形状为128*32*64*64,只有第二个维度的维数数值不同,其他的维度的维数都是相同的,所以拼接时可按维度1进行拼接(注意,维度的下标是从0开始的,比如 A 的形状对应的维度下标为:\(128_0*64_1*32_2*32_3\))

contiguous()函数的使用

contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形(如:tensor_var.contiguous().view() ),示例如下:

x = torch.Tensor(2,3)

y = x.permute(1,0) # permute:二维tensor的维度变换,此处功能相当于转置transpose

y.view(-1) # 报错,view使用前需调用contiguous()函数

y = x.permute(1,0).contiguous()

y.view(-1) # OK

具体原因有两种说法:

1 transpose、permute等维度变换操作后,tensor在内存中不再是连续存储的,而view操作要求tensor的内存连续存储,所以需要contiguous来返回一个contiguous copy;

2 维度变换后的变量是之前变量的浅拷贝,指向同一区域,即view操作会连带原来的变量一同变形,这是不合法的,所以也会报错;---- 这个解释有部分道理,也即contiguous返回了tensor的深拷贝contiguous copy数据;

tensor.repeat()函数

该函数传入的参数个数不少于tensor的维数,其中每个参数代表的是对该维度重复多少次,也就相当于复制的倍数,结合例子更好理解,如下:

>>> import torch

>>>

>>> a = torch.randn(33, 55)

>>> a.size()

torch.Size([33, 55])

>>>

>>> a.repeat(1, 1).size()

torch.Size([33, 55])

>>>

>>> a.repeat(2,1).size()

torch.Size([66, 55])

>>>

>>> a.repeat(1,2).size()

torch.Size([33, 110])

>>>

>>> a.repeat(1,1,1).size()

torch.Size([1, 33, 55])

>>>

>>> a.repeat(2,1,1).size()

torch.Size([2, 33, 55])

>>>

>>> a.repeat(1,2,1).size()

torch.Size([1, 66, 55])

>>>

>>> a.repeat(1,1,2).size()

torch.Size([1, 33, 110])

>>>

>>> a.repeat(1,1,1,1).size()

torch.Size([1, 1, 33, 55])

>>>

>>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数,

>>> # 下面是一些错误示例

>>> a.repeat(2).size() # 1D < 2D, error

Traceback (most recent call last):

File "", line 1, in

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

>>>

>>> b = torch.randn(5,6,7)

>>> b.size() # 3D

torch.Size([5, 6, 7])

>>>

>>> b.repeat(2).size() # 1D < 3D, error

Traceback (most recent call last):

File "", line 1, in

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

>>>

>>> b.repeat(2,1).size() # 2D < 3D, error

Traceback (most recent call last):

File "", line 1, in

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

>>>

>>> b.repeat(2,1,1).size() # 3D = 3D, okay

torch.Size([10, 6, 7])

>>>

torch.masked_select()函数

a = torch.Tensor([[4,5,7], [3,9,8],[2,3,4]])

b = torch.Tensor([[1,1,0], [0,0,1],[1,0,1]]).type(torch.ByteTensor)

c = torch.masked_select(a,b)

print(c)

用法:torch.masked_select(x, mask),mask必须转化成torch.ByteTensor类型。

torch.sort

torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)

对输入张量input沿着指定维按升序排序。如果不给定dim,则默认为输入的最后一维。如果指定参数descending为True,则按降序排序

返回元组 (sorted_tensor, sorted_indices) , sorted_indices 为原始输入中的下标。

参数:

input (Tensor) – 要对比的张量

dim (int, optional) – 沿着此维排序

descending (bool, optional) – 布尔值,控制升降排序

out (tuple, optional) – 输出张量。必须为ByteTensor或者与第一个参数tensor相同类型。

例子:

>>> x = torch.randn(3, 4)

>>> sorted, indices = torch.sort(x)

>>> sorted

-1.6747 0.0610 0.1190 1.4137

-1.4782 0.7159 1.0341 1.3678

-0.3324 -0.0782 0.3518 0.4763

[torch.FloatTensor of size 3x4]

>>> indices

0 1 3 2

2 1 0 3

3 1 0 2

[torch.LongTensor of size 3x4]

>>> sorted, indices = torch.sort(x, 0)

>>> sorted

-1.6747 -0.0782 -1.4782 -0.3324

0.3518 0.0610 0.4763 0.1190

1.0341 0.7159 1.4137 1.3678

[torch.FloatTensor of size 3x4]

>>> indices

0 2 1 2

2 0 2 0

1 1 0 1

[torch.LongTensor of size 3x4]

你可能感兴趣的:(torch的拼接函数)