pytorch 合集: pytorch的函数之torch

tensor 类型

参考pytorch tensor

  • pytorch 中定义了9种 CPU张量类型和对应的GPU 张量类型,具体如下图。GPU 类型就是 CPU类型中间加上cuda.
  • torch.Tensor , torch.rand(), torch.randn() 默认生成torch.FloatTensor类型。
  • 相同数据类型的tensor 才能做运算
  • tensor 数据类型可以做转换, 可以使用独立的函数如 int(), float()等进行转换,通过torch.type() 可以直接显示输入需要转换的类型,使用type_as()函数,将该tensor转换为另一个tensor的 type

import torch
"""
type()函数
type(new_type=None, async=False)如果未提供new_type,则返回类型,否则将此对象转换为指定的类型。 如果已经是正确的类型,则不会执行且返回原对象,用法如下:
"""
t1 = torch.LongTensor(3, 5)
print(t1.type())
# 转换为其他类型
t2=t1.type(torch.FloatTensor)
print(t2.type())

torch.LongTensor
torch.FloatTensor


### type_as
"""
type_as()
这个函数的作用是将该tensor转换为另一个tensor的type,可以同步完成转换CPU类型和GPU类型,如torch.IntTensor-->torch.cuda.floatTendor.
如果张量已经是指定类型,则不会进行转换
"""
t1=torch.Tensor(2,3)
t2=torch.IntTensor(3,5)
t3=t1.type_as(t2)
print(t3.type())


torch.IntTensor

pytorch 合集: pytorch的函数之torch_第1张图片

tensor 操作

  • cat 拼接
    将给定的tensor 按照指定的维度拼接 (可以看成是torch.split() 和torch.chunk()的反操作)
    torch.cat(tensors:sequence of Tensors, dim=0: int, out=None: Tensor) → Tensor

    >>> x = torch.randn(2, 3)
    >>> x
    tensor([[ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497]])
    >>> torch.cat((x, x, x), 0) # 沿着dim=0拼接
    tensor([[ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497],
            [ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497],
            [ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497]])
    >>> torch.cat((x, x, x), 1) # 沿着dim=1拼接
    tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
             -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
             -0.5790,  0.1497]])
    
  • torch.where 条件选择
    torch.where(condition, x, y) -> tensor 这个就是pytorch的条件选择语句,输出OUT= x if condition else y (这里的x和y都应该是tensor)

    >>> x = torch.randn(3, 2)
    >>> y = torch.ones(3, 2)
    >>> x
    tensor([[-0.4620,  0.3139],
            [ 0.3898, -0.7197],
            [ 0.0478, -0.1657]])
    >>> torch.where(x > 0, x, y)
    tensor([[ 1.0000,  0.3139],
            [ 0.3898,  1.0000],
            [ 0.0478,  1.0000]])
    

tensor.contiguous()

在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。

这些操作是:

narrow(),view(),expand()和transpose()

举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。转置的tensor和原tensor的内存是共享的!

为了证明这一点,我们来看下面的代码:

x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0]) -->233

可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。

也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!

当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:

RuntimeError: input is not contiguous

只要看到这个错误提示,加上contiguous()就好啦~

张量运算
pytorch.contiguous理解

你可能感兴趣的:(pytorch,python,python,pytorch)