参考pytorch tensor
import torch
"""
type()函数
type(new_type=None, async=False)如果未提供new_type,则返回类型,否则将此对象转换为指定的类型。 如果已经是正确的类型,则不会执行且返回原对象,用法如下:
"""
t1 = torch.LongTensor(3, 5)
print(t1.type())
# 转换为其他类型
t2=t1.type(torch.FloatTensor)
print(t2.type())
torch.LongTensor
torch.FloatTensor
### type_as
"""
type_as()
这个函数的作用是将该tensor转换为另一个tensor的type,可以同步完成转换CPU类型和GPU类型,如torch.IntTensor-->torch.cuda.floatTendor.
如果张量已经是指定类型,则不会进行转换
"""
t1=torch.Tensor(2,3)
t2=torch.IntTensor(3,5)
t3=t1.type_as(t2)
print(t3.type())
torch.IntTensor
cat 拼接
将给定的tensor 按照指定的维度拼接 (可以看成是torch.split() 和torch.chunk()的反操作)
torch.cat(tensors:sequence of Tensors, dim=0: int, out=None: Tensor) → Tensor
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0) # 沿着dim=0拼接
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1) # 沿着dim=1拼接
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])
torch.where 条件选择
torch.where(condition, x, y) -> tensor 这个就是pytorch的条件选择语句,输出OUT= x if condition else y (这里的x和y都应该是tensor)
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620, 0.3139],
[ 0.3898, -0.7197],
[ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000, 0.3139],
[ 0.3898, 1.0000],
[ 0.0478, 1.0000]])
在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。
这些操作是:
narrow(),view(),expand()和transpose()
举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。转置的tensor和原tensor的内存是共享的!
为了证明这一点,我们来看下面的代码:
x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0]) -->233
可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。
也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。
在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!
当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。
一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:
RuntimeError: input is not contiguous
只要看到这个错误提示,加上contiguous()就好啦~
张量运算
pytorch.contiguous理解