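Preface: today I was reading through a codebase that uses the view and repeat functions many times, so I'm writing down how these two functions work.
In PyTorch, view reshapes a tensor, much like resizing an array in NumPy, but the two differ in usage: view never changes the total number of elements, and the reshaped tensor shares its data with the original.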
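To make the "element count" difference concrete, here is a small sketch comparing NumPy's np.resize with Tensor.view. It assumes NumPy is installed alongside PyTorch; the variable names are illustrative only.

```python
import numpy as np
import torch

a = np.arange(6)                  # 6 elements: 0..5
# np.resize may change the total size: it fills the new shape by repeating the data
resized = np.resize(a, (3, 3))    # 9 elements
print(resized)

t = torch.arange(6)
print(t.view(2, 3))               # fine: 2 * 3 == 6 elements

view_failed = False
try:
    t.view(3, 3)                  # 9 != 6, so view raises instead of padding
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)
```

In short, view only reinterprets the existing elements; it never invents or drops any.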
>>> import torch
>>> x = torch.randn(2,3) # random tensor with 2 rows and 3 columns
>>> x
tensor([[-0.5826, 0.2356, 1.5760],
[ 0.7678, 0.3738, -0.1174]])
>>> x.shape
torch.Size([2, 3])
>>> y = torch.ones(4) # 1-D tensor of four ones
>>> y
tensor([1., 1., 1., 1.])
>>> y.shape
torch.Size([4])
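Tensor.view(rows, cols): reshapes Tensor into a tensor of shape rows × cols (Tensor stands for whatever tensor you want to reshape). rows * cols must equal the number of elements, and the elements are filled in row by row, as the xx1 output below shows.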
>>> xx1 = x.view(3,2)
>>> xx1
tensor([[-0.5826, 0.2356],
[ 1.5760, 0.7678],
[ 0.3738, -0.1174]])
>>> xx1.shape
torch.Size([3, 2])
>>> yy1 = y.view(2,2)
>>> yy1
tensor([[1., 1.],
[1., 1.]])
>>> yy1.shape
torch.Size([2, 2])
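Because view only changes how existing memory is interpreted, writing through the reshaped tensor also changes the original. The sketch below checks this, and also shows the common failure case: a transposed tensor is not contiguous, so view refuses it while reshape (which copies when it has to) succeeds. The names x, v and mt are just illustrations.

```python
import torch

x = torch.arange(6, dtype=torch.float32)   # tensor([0., 1., 2., 3., 4., 5.])
v = x.view(3, 2)                            # same 6 elements, new shape (3, 2)

v[0, 0] = 100.0                             # write through the view...
print(x[0])                                 # ...and the original sees it: tensor(100.)
print(v.data_ptr() == x.data_ptr())         # True: same underlying memory

mt = torch.arange(6).view(2, 3).t()         # transpose: shape (3, 2), not contiguous
view_ok = True
try:
    mt.view(6)
except RuntimeError:
    view_ok = False                         # strides are incompatible with a flat view
flat = mt.reshape(6)                        # reshape falls back to copying
print(flat)                                 # tensor([0, 3, 1, 4, 2, 5])
```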
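Tensor.view(-1, n) / Tensor.view(n, -1): a -1 means the size of that dimension is inferred from the total number of elements and the other dimensions. At most one -1 may appear in the argument list.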
>>> xx2 = x.view(-1,1)
>>> xx2
tensor([[-0.5826],
[ 0.2356],
[ 1.5760],
[ 0.7678],
[ 0.3738],
[-0.1174]])
>>> xx2.shape
torch.Size([6, 1])
>>> yy2 = y.view(1, -1)
>>> yy2
tensor([[1., 1., 1., 1.]])
>>> yy2.shape
torch.Size([1, 4])
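A short sketch of how the -1 inference behaves, including the two ways it can fail: two -1 values at once, or a known size that does not divide the element count. The tensor here is arbitrary random data.

```python
import torch

x = torch.randn(2, 3)            # 6 elements
print(x.view(-1).shape)          # torch.Size([6]): flatten to 1-D
print(x.view(3, -1).shape)       # torch.Size([3, 2]): 6 / 3 = 2 is inferred

errors = 0
try:
    x.view(-1, -1)               # two unknowns cannot both be inferred
except RuntimeError as err:
    errors += 1
    print("error:", err)

try:
    x.view(4, -1)                # 6 is not divisible by 4
except RuntimeError as err:
    errors += 1
    print("error:", err)
```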
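Tensor.repeat(*sizes): tiles the tensor, copying it along the specified dimensions.
The parameter sizes (torch.Size or int...) gives the number of copies along each dimension. Note: the number of sizes arguments must not be smaller than the number of dimensions of Tensor; if there are more, the tensor is treated as if leading dimensions of size 1 had been added.
For a 1-D tensor:
With one argument, Tensor.repeat(num) repeats the tensor num times along the column direction;
With two arguments, Tensor.repeat(num_row, num_col), num_row is the number of copies along the rows and num_col the number of copies along the columns;
With three arguments, Tensor.repeat(num_channel, num_row, num_col), num_channel is the number of copies along the channel dimension, num_row the number along the rows, and the third, num_col, the number along the columns.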
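The shape rule above can be written down directly. repeat_shape is a hypothetical helper (not part of PyTorch) that left-pads the input shape with 1s and multiplies elementwise; the sketch compares its prediction with what repeat actually returns, and shows the error raised when there are fewer sizes than dimensions.

```python
import torch

def repeat_shape(shape, sizes):
    # Hypothetical helper: predicts Tensor.repeat's output shape by padding
    # `shape` with leading 1s up to len(sizes), then multiplying elementwise.
    if len(sizes) < len(shape):
        raise ValueError("need at least as many repeat counts as dimensions")
    padded = (1,) * (len(sizes) - len(shape)) + tuple(shape)
    return tuple(dim * count for dim, count in zip(padded, sizes))

x = torch.tensor([1, 2, 3])
for sizes in [(3,), (4, 2), (4, 2, 2)]:
    actual = tuple(x.repeat(*sizes).shape)
    print(sizes, actual, repeat_shape(x.shape, sizes))

y = torch.rand(3, 2)
too_few_failed = False
try:
    y.repeat(3)          # one count for a 2-D tensor: too few
except RuntimeError as err:
    too_few_failed = True
    print("error:", err)
```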
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(3)
tensor([1, 2, 3, 1, 2, 3, 1, 2, 3]) # 3 copies along the columns
>>> x.repeat(4, 2)
tensor([[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3]]) # 4 copies along the rows, 2 along the columns
>>> x.repeat(4, 2, 2)
tensor([[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]],
[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]],
[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]],
[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]]]) # 4 copies along the channels, 2 each along the rows and columns
>>> x.repeat(4, 2, 2).size()
torch.Size([4, 2, 6])
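Two quick checks on the 1-D example: giving three counts to a 1-D tensor gives the same result as first using view to make it (1, 1, 3), and, unlike view, repeat returns newly allocated data, so editing the result leaves the original untouched.

```python
import torch

x = torch.tensor([1, 2, 3])

# A 1-D tensor given three repeat counts behaves like a (1, 1, 3) tensor
a = x.repeat(4, 2, 2)
b = x.view(1, 1, 3).repeat(4, 2, 2)
print(torch.equal(a, b), a.shape)     # True torch.Size([4, 2, 6])

# repeat copies: editing the result does not touch x
a[0, 0, 0] = 99
print(x)                              # tensor([1, 2, 3])
```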
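The same applies to a 2-D tensor: Tensor.repeat(num_row, num_col) makes num_row copies along the rows and num_col copies along the columns;
Tensor.repeat(num_channel, num_row, num_col) makes num_channel copies along the channel dimension, num_row along the rows, and num_col along the columns.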
>>> y = torch.rand(3, 2)
>>> y
tensor([[0.0027, 0.6878],
[0.5797, 0.0492],
[0.8220, 0.5905]])
>>> y.repeat(3,2) # 3 copies along the rows, 2 along the columns
tensor([[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905],
[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905],
[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905]])
>>> y.repeat(2,3,2) # 2 copies along the channels, 3 along the rows, 2 along the columns
tensor([[[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905],
[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905],
[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905]],
[[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905],
[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905],
[0.0027, 0.6878, 0.0027, 0.6878],
[0.5797, 0.0492, 0.5797, 0.0492],
[0.8220, 0.5905, 0.8220, 0.5905]]])
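One way to convince yourself of the 2-D behaviour is to rebuild the same result by concatenation: stack 3 copies vertically, then place 2 of those side by side. The sketch below checks this equivalence with torch.cat and torch.stack on random data; it illustrates the semantics and is not how repeat is implemented internally.

```python
import torch

y = torch.rand(3, 2)

# repeat(3, 2): 3 copies stacked vertically, then 2 copies side by side
tiled = y.repeat(3, 2)
manual = torch.cat([torch.cat([y] * 3, dim=0)] * 2, dim=1)
print(tiled.shape, torch.equal(tiled, manual))   # torch.Size([9, 4]) True

# repeat(2, 3, 2): the same (9, 4) block, stacked twice along a new leading dim
tiled3 = y.repeat(2, 3, 2)
print(tiled3.shape, torch.equal(tiled3, torch.stack([tiled, tiled])))  # torch.Size([2, 9, 4]) True
```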
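In practice the two functions are often chained. Below, view(-1, 1) turns each element into its own row, repeat(1, n) copies that column n times, and view(-1) flattens the result, so every element appears n times in a row. Note that len(y) is the size of y's first dimension (2 here), not its total number of elements.
>>> x = torch.tensor([1, 2, 3, 4])
>>> x
tensor([1, 2, 3, 4])
>>> y = torch.rand(2,3)
>>> y
tensor([[0.5423, 0.4213, 0.9477],
[0.0385, 0.8777, 0.5502]])
>>> xx = x.view(-1, 1).repeat(1, len(y)).view(-1)
>>> xx
tensor([1, 1, 2, 2, 3, 3, 4, 4])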
>>> yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
>>> yy
tensor([0.5423, 0.5423, 0.5423, 0.5423, 0.4213, 0.4213, 0.4213, 0.4213, 0.9477,
0.9477, 0.9477, 0.9477, 0.0385, 0.0385, 0.0385, 0.0385, 0.8777, 0.8777,
0.8777, 0.8777, 0.5502, 0.5502, 0.5502, 0.5502])
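For reference, the view/repeat/view chain produces the same result as Tensor.repeat_interleave, which repeats each element in place (flattening the input when no dim is given). The sketch below checks that equivalence and contrasts it with plain repeat, which tiles the whole sequence instead. Whether the original codebase could use repeat_interleave instead is an assumption; the chained form is what it actually uses.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.rand(2, 3)

xx = x.view(-1, 1).repeat(1, len(y)).view(-1)
yy = y.view(-1, 1).repeat(1, len(x)).view(-1)

# Same results via repeat_interleave (input is flattened when dim is omitted)
print(torch.equal(xx, x.repeat_interleave(len(y))))   # True
print(torch.equal(yy, y.repeat_interleave(len(x))))   # True

# Contrast: plain repeat tiles the whole sequence rather than each element
print(x.repeat(len(y)))                               # tensor([1, 2, 3, 4, 1, 2, 3, 4])
```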