Pytorch Tensor Slice

1. 普通的slice

In [2]: x = torch.arange(12).reshape(4,3)

In [3]: x
Out[3]:
tensor([[  0.,   1.,   2.],
        [  3.,   4.,   5.],
        [  6.,   7.,   8.],
        [  9.,  10.,  11.]])

In [4]: x.dtype
Out[4]: torch.float32

In [5]: y = x[2:, :]

In [6]: y
Out[6]:
tensor([[  6.,   7.,   8.],
        [  9.,  10.,  11.]])

这个时候,变量xy共享内存位置,如果将 y 的值改变, x的值也会改变:

改变方式 1

In [15]: y[:,:] = 666

In [16]: y
Out[16]:
tensor([[ 666.,  666.,  666.],
        [ 666.,  666.,  666.]])

In [17]: x
Out[17]:
tensor([[   0.,    1.,    2.],
        [   3.,    4.,    5.],
        [ 666.,  666.,  666.],
        [ 666.,  666.,  666.]])

改变方式 2

In [12]: y.fill_(0)
Out[12]:
tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])

In [13]: y
Out[13]:
tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])

In [14]: x
Out[14]:
tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])

2. Mask(dtype=torch.uint8) 作为slice的时候,不会有上述效果

In [2]: x = torch.arange(12).reshape(4, -1)

In [3]: x
Out[3]:
tensor([[  0.,   1.,   2.],
        [  3.,   4.,   5.],
        [  6.,   7.,   8.],
        [  9.,  10.,  11.]])

In [4]: mask = x > 5

In [5]: mask
Out[5]:
tensor([[ 0,  0,  0],
        [ 0,  0,  0],
        [ 1,  1,  1],
        [ 1,  1,  1]], dtype=torch.uint8)

In [6]: y = x[mask]

In [7]: y
Out[7]: tensor([  6.,   7.,   8.,   9.,  10.,  11.])

mask的数据类型为 torch.uint8, 用其作为slice的时候,得到的结果就会 展开成一个一维的数组, 并且改变 y的值, x的值也不会发生变化。

In [8]: y[:] = 0

In [9]: y
Out[9]: tensor([ 0.,  0.,  0.,  0.,  0.,  0.])

In [10]: x
Out[10]:
tensor([[  0.,   1.,   2.],
        [  3.,   4.,   5.],
        [  6.,   7.,   8.],
        [  9.,  10.,  11.]])

In [11]: y.fill_(666)
Out[11]: tensor([ 666.,  666.,  666.,  666.,  666.,  666.])

In [12]: y
Out[12]: tensor([ 666.,  666.,  666.,  666.,  666.,  666.])

In [13]: x
Out[13]:
tensor([[  0.,   1.,   2.],
        [  3.,   4.,   5.],
        [  6.,   7.,   8.],
        [  9.,  10.,  11.]])

你可能感兴趣的:(Pytorch Tensor Slice)