PyTorch:tensor常用操作

tensor常用操作

    • 一、维度重新排列
    • 二、Tensor.expand()
    • 三、torch.cat()
    • 四、torch.split()
    • 五、torch.chunk()

一、维度重新排列

view

变换数据的维度,首先将tensor展平为一维,然后进行维度变换操作
view要求数据内存地址是连续的

input = torch.randn(2, 3, 4)
print(input.view(3, 2, 4))

transpose

将数据的两个维度进行变化,变化后内存不连续

permuate

将数据的多个维度进行变化,变化后内存不连续

如果在使用transpose或者permuate之后使用view会报错,这是因为维度变换后内存不连续导致的,只需要在view之前使用contiguous函数即可,示例如下

input = torch.randn(2, 3, 4)
input = input.permuate(2, 1, 0)
print(input.is_contiguous())
# input.view(2, 3, 4)  # wrong
input = input.contiguous()
print(input.view())  # correct

二、Tensor.expand()

Tensor.expand(sizes)

将tensor进行扩展,size为扩展后的维度,-1表示对这一维度不进行拓展

Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

Passing -1 as the size for a dimension means not changing the size of that dimension.

Tensor can be also expanded to a larger number of dimensions, and the new ones will be appended at the front. For the new dimensions, the size cannot be set to -1.

Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.

x = torch.tensor([[1], [2], [3]])
print(x)
print(x.size())
'''
tensor([[1],
        [2],
        [3]])
torch.Size([3, 1])
'''


x_ex = x.expand(3, 4)
print(x_ex)
x_ex2 = x.expand(-1, 4)   # -1 means not changing the size of that dimension
print(x_ex2)
'''
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
'''

三、torch.cat()

torch.cat(tensors, dim=0)

将tensor进行拼接,拼接的维度根据dim设置,默认为0(行拼接)

Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk()

x = torch.randn(2, 3)
print(x)
'''
tensor([[ 0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191]])
'''

x_row = torch.cat((x, x, x), 0)
print(x_row)
x_col = torch.cat((x, x), 1)
print(x_col)
'''
tensor([[ 0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191],
        [ 0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191],
        [ 0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191]])
tensor([[ 0.7004, -0.0935, -0.2668,  0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191,  0.7922,  0.9567,  1.4191]])
'''

四、torch.split()

torch.split(tensor, split_size_or_sections, dim=0)

将tensor进行切片,split_size_or_sections表示切片的大小,可以为整型或者列表,dim为切片维度,默认为0对行进行切片

Splits the tensor into chunks. Each chunk is a view of the original tensor.

If split_size_or_sections is an integer type, then tensorwill be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.

If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

a = torch.arange(10).reshape(5,2)
print(a)
'''
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
'''


a1 = torch.split(a, 2)
print(a1)
print(a1[1]) # 因为分为了三组,所以可以选择任一项输出
'''
(tensor([[0, 1],
         [2, 3]]), 
 tensor([[4, 5],
         [6, 7]]), 
 tensor([[8, 9]]))
tensor([[4, 5],
        [6, 7]])
'''


a2 = torch.split(a, [1,4])
print(a2)
'''
(tensor([[0, 1]]), 
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))
'''

五、torch.chunk()

torch.chunk(tensor, chunks, dim=0)

  • input - [Tensor] – the tensor to split
  • chunks - [int] – number of chunks to return
  • dim - [int] – dimension along which to split the tensor

torch.chunk()torch.split()功能完全一致,唯一的区别是参数chunks输入只能是整型

Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.

Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.

b = torch.arange(10).reshape(2,5)
print(b)
'''
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
'''


b1 = torch.chunk(b, 2, dim=1)
print(b1)
print(b1[1])
'''
(tensor([[0, 1, 2],
         [5, 6, 7]]), 
 tensor([[3, 4],
         [8, 9]]))
tensor([[3, 4],
        [8, 9]])
'''


b2 = torch.chunk(b, 2, dim=1)
print(b2)
'''
(tensor([[0, 1, 2],
         [5, 6, 7]]), 
 tensor([[3, 4],
         [8, 9]]))
'''

你可能感兴趣的:(PyTorch,pytorch,深度学习,python)