Python view()和repeat()函数的用法

前言:今天在看了一套代码,里面多次用到view和repeat函数,暂且记录一下这俩函数的用法吧。

1. view函数

在pytorch中,view函数可用于重构tensor张量的维度,类似于array数组的resize操作,但用法上存在一定的区别。

>>> import torch
>>> x = torch.randn(2,3)  # 随机生成2行3列的数组
tensor([[-0.5826,  0.2356,  1.5760],
        [ 0.7678,  0.3738, -0.1174]])
>>> x.shape
torch.Size([2, 3])

>>> y = torch.ones(4)     # 生成长度为4的全1数组
tensor([1., 1., 1., 1.])
>>> y.shape
torch.Size([4])

1.1 Tensor.view(rows, cols,...)

含义:将Tensor张量冲构成rows * cols的张量形式(此处Tensor表示需要重构的任一张量)

>>> xx1 = x.view(3,2)
>>> xx1
tensor([[-0.5826,  0.2356],
        [ 1.5760,  0.7678],
        [ 0.3738, -0.1174]])
>>> xx1.shape
torch.Size([3, 2])

>>> yy1 = y.view(2,2)
>>> yy1
tensor([[1., 1.],
        [1., 1.]])
>>> yy1.shape
torch.Size([2, 2])

1.2 Tensor.view(-1)或Tensor.view(-1, 1)

含义:当参数中存在-1时,表示该位置的数值需要经过计算(参数列表里只能存在至多一个-1)

>>> xx2 = x.view(-1,1)
>>> xx2
tensor([[-0.5826],
        [ 0.2356],
        [ 1.5760],
        [ 0.7678],
        [ 0.3738],
        [-0.1174]])
>>> xx2.shape
torch.Size([6, 1])  

>>> yy2 = y.view(1, -1)
>>> yy2
tensor([[1., 1., 1., 1.]])
>>> yy2.shape
torch.Size([1, 4])

2. repeat函数

Tensor.repeat(*size): 用于在指定维度对张量进行复制操作

参数 sizes (torch.Size or int...) 为沿着各维度复制的次数。注:size参数的个数不能少于Tensor的维度

2.1 一维向量的复制

当size参数只有一个时,即 Tensor.repeat(num),表示在Tensor的列方向上复制num倍;

当size参数有两个时,Tensor.repeat(num_row, num_col),第一个num_row表示行方向的复制倍数,第二个num_col表示列方向的复制倍数;

当size参数有三个时,Tensor.repeat(num_channel, num_row, num_col),第一个num_channel表示通道数的复制倍数,num_row表示行方向的复制倍数,第二个num_col表示列方向的复制倍数。

>>> x = torch.tensor([1, 2, 3])  
>>> x.repeat(3)
tensor([1, 2, 3, 1, 2, 3, 1, 2, 3])   # 在列方向上复制三倍

>>> x.repeat(4, 2)
tensor([[ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3]])    # 在行方向上复制4倍,列方向上复制2倍

>>> x.repeat(4, 2, 2) 
tensor([[[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],
        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],
        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],
        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]]])        #在通道方向复制4倍,行列方向各复制2倍

>>> x.repeat(4, 2, 2).size()
torch.Size([4, 2, 6])

2.2 二维向量的复制

与一维向量相似,Tensor.repeat(num_row, num_col)表示在行方向上复制num_row倍,在列方向上复制num_col倍;

Tensor.repeat(num_channel, num_row, num_col),表示在通道方向上复制num_channel倍,在行方向上复制num_row倍,在列方向上复制num_col倍。

>>> y = torch.rand(3, 2)
tensor([[0.0027, 0.6878],
        [0.5797, 0.0492],
        [0.8220, 0.5905]])

>>> y.repeat(3,2)       # 在行方向复制3倍,在列方向复制2倍
tensor([[0.0027, 0.6878, 0.0027, 0.6878],
        [0.5797, 0.0492, 0.5797, 0.0492],
        [0.8220, 0.5905, 0.8220, 0.5905],
        [0.0027, 0.6878, 0.0027, 0.6878],
        [0.5797, 0.0492, 0.5797, 0.0492],
        [0.8220, 0.5905, 0.8220, 0.5905],
        [0.0027, 0.6878, 0.0027, 0.6878],
        [0.5797, 0.0492, 0.5797, 0.0492],
        [0.8220, 0.5905, 0.8220, 0.5905]])


>>> y.repeat(2,3,2)  # 在通道方向复制两倍,在行方向复制3倍,在列方向复制2倍
tensor([[[0.0027, 0.6878, 0.0027, 0.6878],
         [0.5797, 0.0492, 0.5797, 0.0492],
         [0.8220, 0.5905, 0.8220, 0.5905],
         [0.0027, 0.6878, 0.0027, 0.6878],
         [0.5797, 0.0492, 0.5797, 0.0492],
         [0.8220, 0.5905, 0.8220, 0.5905],
         [0.0027, 0.6878, 0.0027, 0.6878],
         [0.5797, 0.0492, 0.5797, 0.0492],
         [0.8220, 0.5905, 0.8220, 0.5905]],
        [[0.0027, 0.6878, 0.0027, 0.6878],
         [0.5797, 0.0492, 0.5797, 0.0492],
         [0.8220, 0.5905, 0.8220, 0.5905],
         [0.0027, 0.6878, 0.0027, 0.6878],
         [0.5797, 0.0492, 0.5797, 0.0492],
         [0.8220, 0.5905, 0.8220, 0.5905],
         [0.0027, 0.6878, 0.0027, 0.6878],
         [0.5797, 0.0492, 0.5797, 0.0492],
         [0.8220, 0.5905, 0.8220, 0.5905]]])

3. 总结示例

>>> x = torch.tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])

>>> y = torch.rand(2,3)
tensor([[0.5423, 0.4213, 0.9477],
        [0.0385, 0.8777, 0.5502]])

>>> xx = x.view(-1, 1).repeat(1, len(y)).view(-1)
tensor([1, 1, 2, 2, 3, 3, 4, 4])

>>> yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
>>> yy
tensor([0.5423, 0.5423, 0.5423, 0.5423, 0.4213, 0.4213, 0.4213, 0.4213, 0.9477,
        0.9477, 0.9477, 0.9477, 0.0385, 0.0385, 0.0385, 0.0385, 0.8777, 0.8777,
        0.8777, 0.8777, 0.5502, 0.5502, 0.5502, 0.5502])

你可能感兴趣的:(杂碎小记,python,pytorch)