PyTorch中张量的创建方法的选择 | Pytorch系列(五)

点击上方“AI算法与图像处理”,选择加"星标"或“置顶”

重磅干货,第一时间送达

文 |AI_study

欢迎回到PyTorch神经网络编程系列。在这篇文章中,我们将仔细研究将数据转换成PyTorch张量的主要方法之间的区别。

PyTorch中张量的创建方法的选择 | Pytorch系列(五)_第1张图片

在这篇文章的最后,我们将知道主要选项之间的区别,以及应该使用哪些选项和何时使用。言归正传,我们开始吧。

我们已经见过的PyTorch张量就是PyTorch类torch.Tensor 的实例。张量和PyTorch张量之间的抽象概念的区别在于PyTorch张量给了我们一个具体的实现,我们可以在代码中使用它。

PyTorch中张量的创建方法的选择 | Pytorch系列(五)_第2张图片

在上一篇文章中《Pytorch中张量讲解 | Pytorch系列(四)》,我们了解了如何使用Python列表、序列和NumPy ndarrays等数据在PyTorch中创建张量。给定一个numpy.ndarray,我们发现有四种方法可以创建 torch.Tensor 对象。

下面是快速回顾一下:

> data = np.array([1,2,3])
> type(data)
numpy.ndarray


> o1 = torch.Tensor(data)
> o2 = torch.tensor(data)
> o3 = torch.as_tensor(data)
> o4 = torch.from_numpy(data)


> print(o1)
> print(o2)
> print(o3)
> print(o4)
tensor([1., 2., 3.])
tensor([1, 2, 3], dtype=torch.int32)
tensor([1, 2, 3], dtype=torch.int32)
tensor([1, 2, 3], dtype=torch.int32)

我们在这篇文章的任务是探索这些选项之间的区别,并为我们的张量创建需求提出一个最佳的选择。


生成张量的操作:有什么区别?

让我们开始并找出这些不同之处。

一、Uppercase/Lowercase: torch.Tensor() Vs torch.tensor()

注意到第一个选项torch.Tensor()是大写的T,而第二个选项torch.tensor()是小写的T。

第一个选项(即包含大写T的)是torch.Tensor 类的构造函数。第二个选项是我们所谓的工厂函数( factory function),该函数构造torch.Tensor对象并将其返回给调用者。


PyTorch中张量的创建方法的选择 | Pytorch系列(五)_第3张图片

你可以将torch.tensor()函数看作是在给定一些参数输入的情况下构建张量的工厂。工厂函数是用于创建对象的软件设计模式。

如果您想了解更多关于它的信息,请点击这里。https://en.wikipedia.org/wiki/Factory_(object-oriented_programming)

好的。那是大写字母T和小写字母t之间的区别,但是两者之间哪种方法更好?答案是可以使用其中之一。但是,工厂函数torch.tensor() 具有更好的文档和更多的配置选项,因此现在它可以赢得胜利。

二、Default dtype Vs Inferred dtype


好了,在我们把torch.Tensor()构造函数从我们的列表中删除之前,让我们复习一下打印出来的张量输出的不同之处。

区别在于每个张量的 dtype。让我们看看:

> print(o1.dtype)
> print(o2.dtype)
> print(o3.dtype)
> print(o4.dtype)
torch.float32
torch.int32
torch.int32
torch.int32


此处的差异是由于在构建张量时,torch.Tensor() 构造函数使用的默认的dtype不同。我们可以使用torch.get_default_dtype() 方法验证默认的dtype:

> torch.get_default_dtype()
torch.float32

为了验证代码,我们可以这样做:

> o1.dtype == torch.get_default_dtype()
True

其他调用根据传入的数据来选择 dtype。这称为类型推断(type inference)。dtype 根据传入的数据来推断。请注意,也可以通过给 dtype 指定参数来为这些调用显示设置 dtype。

> torch.tensor(data, dtype=torch.float32)
> torch.as_tensor(data, dtype=torch.float32)


使用torch.Tensor(),我们无法将 dtype 传递给构造函数。这是torch.Tensor() 构造函数缺少配置选项的示例。这也是使用 torch.tensor() 工厂函数创建张量的原因之一。

让我们看一下这些替代创建方法之间的最后隐藏的区别。

三、共享内存以提高性能:复制与共享


第三个区别是隐藏的区别。为了揭示差异,我们需要在使用ndarray创建张量之后,对numpy.ndarray中的原始输入数据进行更改。

让我们这样做,看看会得到什么:

> print('old:', data)
old: [1 2 3]


> data[0] = 0


> print('new:', data)
new: [0 2 3]


> print(o1)
> print(o2)
> print(o3)
> print(o4)


tensor([1., 2., 3.])
tensor([1, 2, 3], dtype=torch.int32)
tensor([0, 2, 3], dtype=torch.int32)
tensor([0, 2, 3], dtype=torch.int32)


请注意,一开始data [0] = 1,并且还注意到我们只更改了原始numpy.ndarray中的数据。注意,我们没有明确地对张量(o1,o2,o3,o4)进行任何更改。

但是,在设置data [0] = 0后,我们可以看到一些张量发生了变化。对于索引0,前两个o1和o2仍具有原始值1,而对于索引0,后两个 o3 和 o4 具有新值0。

发生这种情况是因为torch.Tensor() 和torch.tensor() 复制了它们的输入数据,而torch.as_tensor() 和torch.from_numpy() 与原始输入对象共享了它们在内存中的输入数据。


这种共享仅仅意味着内存中的实际数据存在于一个地方。因此,基础数据中发生的任何更改都将反映在两个对象中,即torch.Tensor和numpy.ndarray。

与复制数据相比,共享数据更高效,占用的内存更少,因为数据不是写在内存中的两个位置。

如果我们有 torch.Tensor 的话,我们要把它转换成一个numpy.ndarray,我们是这样做的:

> print(o3.numpy())
> print(o4.numpy())
[0 2 3]
[0 2 3]

这给出:

> print(type(o3.numpy()))
> print(type(o4.numpy()))


这样可以确定torch.as_tensor() 和torch.from_numpy() 都与它们的输入数据共享内存。但是,我们应该使用哪一个,它们有何不同?

torch.from_numpy() 函数仅接受 numpy.ndarrays,而torch.as_tensor()  函数则接受包括其他PyTorch张量在内的各种数组对象。因此,torch.as_tensor() 是内存共享比赛中的获胜选择。

在PyTorch中创建张量的最佳选择

考虑到所有这些细节,这两个是最佳选择

  • torch.tensor()

  • torch.as_tensor()

torch.tensor() 调用是一种 go-to 调用,而在调整代码性能时应使用torch.as_tensor()。

PyTorch中张量的创建方法的选择 | Pytorch系列(五)_第4张图片

关于内存共享,要记住一些注意事项(它可以在某些地方起作用):

  1. 由于numpy.ndarray对象是在CPU上分配的,因此在使用GPU时,as_tensor() 函数必须将数据从CPU复制到GPU。

  2. as_tensor() 的内存共享不适用于内置Python数据结构(如列表)。

  3. 调用as_tensor() 要求开发人员了解共享功能。这是必要的,因此我们不会在未意识到更改会影响多个对象的情况下无意间对基础数据进行不必要的更改。

  4. 如果在numpy.ndarray对象和张量对象之间进行大量来回操作,则as_tensor() 的性能提高会更大。但是,如果仅执行一次加载操作,则从性能角度来看不会有太大影响。

总结:

至此,我们现在应该对PyTorch张量创建选项有了更好的了解。我们已经了解了工厂函数,并且了解了内存共享与复制如何影响性能和程序行为。 

文章中内容都是经过仔细研究的,本人水平有限,翻译无法做到完美,但是真的是费了很大功夫,希望小伙伴能动动你性感的小手,分享朋友圈或点个“在看”,支持一下我 ^_^

英文原文链接是:

https://deeplizard.com/learn/video/AglLTlms7HU

加群交流

欢迎小伙伴加群交流,目前已有交流群的方向包括:AI学习交流群,目标检测,秋招互助,资料下载等等;加群可扫描并回复感兴趣方向即可(注明:地区+学校/企业+研究方向+昵称)

 我的生活不能没有你! ????

你可能感兴趣的:(高效入门PyTorch系列)