张量:张量为我们提供了描述具有任意数量轴的n维数组的通用方法,是矩阵的扩展与延伸。张量类似于Numpy的多维数组(ndarray)的概念,可以具有任意多的维度。
算子:算子是构建复杂机器学习模型的基础组件,包含一个函数f(x)的前向函数和反向函数。深度学习算法由一个个计算单元组成,我们称这些计算单元为算子。
创建一个张量可以有多种方式,如:指定数据创建、指定形状创建、指定区间创建等。
通过给定Python列表数据,可以创建任意维度的张量。
(1)通过指定的Python列表数据[2.0, 3.0, 4.0],创建一个一维张量。
import torch
X = torch.tensor([2.0, 3.0, 4.0])
print(X)
运行结果:
(2)通过指定的Python列表数据来创建类似矩阵(matrix)的二维张量。
import torch
X = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
print(X)
运行结果:
(3)同样地,还可以创建维度为3、4...N等更复杂的多维张量。
import torch
X = torch.tensor([[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10]],
[[11, 12, 13, 14, 15],
[16, 17, 18, 19, 20]]])
print(X)
运行结果:
(注:需要注意的是,张量在任何一个维度上的元素数量必须相等。下面尝试定义一个在同一维度上元素数量不等的张量。 )
如果要创建一个指定形状、元素数据相同的张量,可以使用torch.zeros、torch.ones
、torch.full
等。
import torch
m, n = 2, 3
zeros_tensor = torch.zeros([m, n])
ones_tensor = torch.ones([m, n])
full_tensor = torch.full([m, n], 10)
print('zeros Tensor:\n', zeros_tensor)
print('ones Tensor:\n', ones_tensor)
print('full Tensor:\n', full_tensor)
如果要在指定区间内创建张量,可以使用torch.arange
、torch.linspace
等。
import torch
arange_tensor = torch.arange(start=1, end=5, step=1)
linspace_tensor = torch.linspace(start=1, end=5, steps=5)
print('arange Tensor:\n', arange_tensor)
print('linspace Tensor:\n', linspace_tensor)
运行结果:
张量具有如下形状属性:
Tensor.ndim
:张量的维度,例如向量的维度为1,矩阵的维度为2。Tensor.shape
: 张量每个维度上元素的数量。Tensor.shape[n]
:张量第n维的大小。第n维也称为轴。Tensor.numel()
:张量中全部元素的个数。import torch
X = torch.ones([2, 3, 4, 5])
print("Number of dimensions:", X.ndim)
print("Shape of Tensor:", X.shape)
print("Elements number along axis 0 of Tensor:", X.shape[0])
print("Elements number along the last axis of Tensor:", X.shape[-1])
print('Number of elements in Tensor: ', X.numel())
运行结果:
(注:pytorch中Tensor.size()显示的是矩阵的规模)
除了查看张量的形状外,重新设置张量的在实际编程中也具有重要意义。
import torch
X = torch.tensor([[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10]],
[[11, 12, 13, 14, 15],
[16, 17, 18, 19, 20]],
[[21, 22, 23, 24, 25],
[26, 27, 28, 29, 30]]])
print('the shape of X:', X.shape)
print('after reshape:\n', X.view([2, 5, 3]))
运行结果:
(注:张量数据形状发生改变,而张量内的数据和元素顺序则没有发生改变)
分别对上文定义的X进行reshape为[-1]操作和reshape为[1, 5, 6]两种操作,观察新张量的形状。
import torch
ndim_3_Tensor = torch.tensor([[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10]],
[[11, 12, 13, 14, 15],
[16, 17, 18, 19, 20]],
[[21, 22, 23, 24, 25],
[26, 27, 28, 29, 30]]])
new_Tensor1 = ndim_3_Tensor.reshape([-1])
print('new Tensor 1 shape: ', new_Tensor1.shape)
new_Tensor2 = ndim_3_Tensor.reshape([1, 5, 6])
print('new Tensor 2 shape: ', new_Tensor2.shape)
运行结果:
从输出结果看,第一行代码中的第一个reshape操作将张量reshape
为元素数量为30的一维向量;第四行代码中的第二个reshape
操作中,0对应的维度的元素个数与原张量在该维度上的元素个数相同。
除使用torch.reshape
进行张量形状的改变外,还可以通过torch.unsqueeze
将张量中的一个或多个维度中插入尺寸为1的维度。
import torch
ones_Tensor = torch.ones([5, 10])
new_Tensor1 = torch.unsqueeze(ones_Tensor, dim=0)
print('new Tensor 1 shape: ', new_Tensor1.shape)
new_Tensor2 = torch.unsqueeze(ones_Tensor, dim=1)
print('new Tensor 2 shape: ', new_Tensor2.shape)
运行结果:
(注:torch.unsqueeze()中,dim表示插入维度的索引,即在哪里插入)
pytorch中可以通过Tensor.dtype
来查看数据类型。
Tensor的最基本数据类型有:
1)通过Python元素创建的张量,如果未指定:
2)通过Numpy数组创建的张量,则与其原来的数据类型保持相同。通过torch.tensor()函数可以将Numpy数组转化为张量。
import torch
print('Tensor dtype from python integers:', torch.tensor(1).dtype)
print('Tensor dtype from python floating point:', torch.tensor(1.0).dtype)
运行结果:
如果想改变张量的数据类型,可以通过调用Tensor.type来实现。
import torch
float32_tensor = torch.tensor(1.0)
int64_tensor = float32_tensor.type(torch.int64)
print('Tensor after cast to int64:', int64_tensor.dtype)
运行结果:
初始化张量时可以通过place来指定其分配的设备位置,可支持的设备位置有三种:CPU、GPU和固定内存。固定内存也称为不可分页内存或锁页内存,它与GPU之间具有更高的读写效率,并且支持异步传输,这对网络整体性能会有进一步提升,但它的缺点是分配空间过多时可能会降低主机系统的性能,因为它减少了用于存储虚拟内存数据的可分页内存。当未指定设备位置时,张量默认设备位置和安装的pytorch版本一致,如安装了GPU版本的pytorch,则设备位置默认为GPU。
如下代码创建了CPU上的张量,并通过Tensor.device
查看张量所在的设备位置。
import torch
# 创建CPU上的Tensor
cpu_Tensor = torch.tensor(1, device=torch.device('cpu'))
# 通过Tensor.place查看张量所在设备位置
print('cpu Tensor: ', cpu_Tensor.device)
运行结果:
张量和Numpy数组可以相互转换。第1.2.2.3节中我们了解到torch.tensor()函数可以将Numpy数组转化为张量,也可以通过Tensor.numpy()
函数将张量转化为Numpy数组。
import torch
ndim_1_Tensor = torch.tensor([1., 2.])
# 将当前Tensor转化为numpy.ndarray
print('Tensor to convert:', ndim_1_Tensor.numpy())
运行结果:
我们可以通过索引或切片方便地访问或修改张量。pytorch使用标准的Python索引规则与Numpy索引规则,具有以下特点:
针对一维张量,对单个轴进行索引和切片。
import torch
ndim_1_Tensor = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
print("Origin Tensor:", ndim_1_Tensor)
print("First element:", ndim_1_Tensor[0])
print("Last element:", ndim_1_Tensor[-1])
print("All element:", ndim_1_Tensor[:])
print("Before 3:", ndim_1_Tensor[:3])
print("Interval of 3:", ndim_1_Tensor[::3])
运行结果:
=========================================================================
(注:Pytorch不支持负数步长)
import torch
ndim_1_Tensor = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
print("Reverse:", ndim_1_Tensor[::-1])
运行结果:
由运行结果可见,在pytorch中step不可为负。
=========================================================================
针对二维及以上维度的张量,在多个维度上进行索引或切片。索引或切片的第一个值对应第0维,第二个值对应第1维,以此类推,如果某个维度上未指定索引,则默认为“:”。
import torch
ndim_2_Tensor = torch.tensor([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]])
print("Origin Tensor:\n", ndim_2_Tensor)
print("First row:", ndim_2_Tensor[0])
print("First row:", ndim_2_Tensor[0, :])
print("First column:", ndim_2_Tensor[:, 0])
print("Last column:", ndim_2_Tensor[:, -1])
print("All element:\n", ndim_2_Tensor[:])
print("First row and second column:", ndim_2_Tensor[0, 1])
运行结果:
与访问张量类似,可以在单个或多个轴上通过索引或切片操作来修改张量。
(注:索引或切片操作是在原地修改该张量的数值,原值并不会被保存,因此在修改张量时,应慎重使用索引或切片操作)
import torch
ndim_2_Tensor = torch.ones([2, 3], dtype=torch.float32)
print('Origin Tensor:\n ', ndim_2_Tensor)
# 修改第1维为0
ndim_2_Tensor[0] = 0
print('change Tensor:\n ', ndim_2_Tensor)
# 修改第1维为2.1
ndim_2_Tensor[0:1] = 2.1
print('change Tensor: \n', ndim_2_Tensor)
# 修改全部Tensor
ndim_2_Tensor[...] = 3
print('change Tensor:\n ', ndim_2_Tensor)
运行结果:
张量支持包括基础数学运算、逻辑运算、矩阵运算等100余种运算操作,以加法为例,有如下两种实现方式:
1)使用pytorch torch.add(x,y)
。
2)使用张量类成员函数x.add(y)
。
import torch
x = torch.tensor([[1.1, 2.2], [3.3, 4.4]], dtype=torch.float64)
y = torch.tensor([[5.5, 6.6], [7.7, 8.8]], dtype=torch.float64)
print('Method 1:\n', torch.add(x, y))
print('Method 2:\n', x.add(y))
运行结果:
从输出结果看,使用张量类成员函数x.add(y)和torch.add(x,y)具有相同的效果。
张量类的基础数学函数如下:
x.abs() # 逐元素取绝对值
x.ceil() # 逐元素向上取整
x.floor() # 逐元素向下取整
x.round() # 逐元素四舍五入
x.exp() # 逐元素计算自然常数为底的指数
x.log() # 逐元素计算x的自然对数
x.reciprocal() # 逐元素求倒数
x.square() # 逐元素计算平方
x.sqrt() # 逐元素计算平方根
x.sin() # 逐元素计算正弦
x.cos() # 逐元素计算余弦
x.add(y) # 逐元素加
x.subtract(y) # 逐元素减
x.multiply(y) # 逐元素乘(积)
x.divide(y) # 逐元素除
x.mod(y) # 逐元素除并取余
x.pow(y) # 逐元素幂
x.max() # 指定维度上元素最大值,默认为全部维度
x.min() # 指定维度上元素最小值,默认为全部维度
x.prod() # 指定维度上元素累乘,默认为全部维度
x.sum() # 指定维度上元素的和,默认为全部维度
同时,为了更方便地使用张量,pytorch对Python数学运算相关的魔法函数进行了重写,以下操作与上述结果相同。
x + y -> x.add(y) # 逐元素加
x - y -> x.subtract(y) # 逐元素减
x * y -> x.multiply(y) # 逐元素乘(积)
x / y -> x.divide(y) # 逐元素除
x % y -> x.mod(y) # 逐元素除并取余
x ** y -> x.pow(y) # 逐元素幂
张量类的逻辑运算函数如下:
x.isfinite() # 判断Tensor中元素是否是有限的数字,即不包括inf与nan
x.equal_all(y) # 判断两个Tensor的全部元素是否相等,并返回形状为[1]的布尔类Tensor
x.equal(y) # 判断两个Tensor的每个元素是否相等,并返回形状相同的布尔类Tensor
x.not_equal(y) # 判断两个Tensor的每个元素是否不相等
x.less_than(y) # 判断Tensor x的元素是否小于Tensor y的对应元素
x.less_equal(y) # 判断Tensor x的元素是否小于或等于Tensor y的对应元素
x.greater_than(y) # 判断Tensor x的元素是否大于Tensor y的对应元素
x.greater_equal(y) # 判断Tensor x的元素是否大于或等于Tensor y的对应元素
x.allclose(y) # 判断两个Tensor的全部元素是否接近
同样地,pytorch对Python逻辑比较相关的魔法函数也进行了重写,这里不再赘述。
张量类还包含了矩阵运算相关的函数,如矩阵的转置、范数计算和乘法等。
x.t() # 矩阵转置
x.transpose([1, 0]) # 交换第 0 维与第 1 维的顺序
x.norm('fro') # 矩阵的弗罗贝尼乌斯范数
x.dist(y, p=2) # 矩阵(x-y)的2范数
x.matmul(y) # 矩阵乘法
有些矩阵运算中也支持大于两维的张量,比如matmul函数,对最后两个维度进行矩阵乘。比如x是形状为[j,k,n,m]的张量,另一个y是[j,k,m,p]的张量,则x.matmul(y)输出的张量形状为[j,k,n,p]。
pytorch的一些API在计算时支持广播(Broadcasting)机制,允许在一些运算时使用不同形状的张量。通常来讲,如果有一个形状较小和一个形状较大的张量,会希望多次使用较小的张量来对较大的张量执行某些操作,看起来像是形状较小的张量首先被扩展到和较大的张量形状一致,然后再做运算。
广播机制的条件
pytorch的广播机制主要遵循如下规则(参考Numpy广播机制):
1)每个张量至少为一维张量。
2)从后往前比较张量的形状,当前维度的大小要么相等,要么其中一个等于1,要么其中一个不存在。
import torch
x = torch.ones((2, 3, 4))
y = torch.ones((2, 3, 4))
z = x + y
print('broadcasting with two same shape tensor: ', z.shape)
x = torch.ones((2, 3, 1, 5))
y = torch.ones((3, 4, 1))
# 从后往前依次比较:
# 第一次:y的维度大小是1
# 第二次:x的维度大小是1
# 第三次:x和y的维度大小相等,都为3
# 第四次:y的维度不存在
# 所以x和y是可以广播的
z = x + y
print('broadcasting with two different shape tensor:', z.shape)
运行结果:
从输出结果看,x与y在上述两种情况中均遵循广播规则,因此在张量相加时可以广播。我们再定义两个shape分别为[2, 3, 4]和[2, 3, 6]的张量,观察这两个张量是否能够通过广播操作相加。
import torch
x = torch.ones((2, 3, 4))
y = torch.ones((2, 3, 6))
z = x + y
运行结果:
从输出结果看,此时x和y是不能广播的,因为在第一次从后往前的比较中,4和6不相等,不符合广播规则。
广播机制的计算规则
现在我们知道在什么情况下两个张量是可以广播的。两个张量进行广播后的结果张量的形状计算规则如下:
1)如果两个张量shape的长度不一致,那么需要在较小长度的shape前添加1,直到两个张量的形状长度相等。
2) 保证两个张量形状相等之后,每个维度上的结果维度就是当前维度上较大的那个。
以张量x和y进行广播为例,x的shape为[2, 3, 1,5],张量y的shape为[3,4,1]。首先张量y的形状长度较小,因此要将该张量形状补齐为[1, 3, 4, 1],再对两个张量的每一维进行比较。从第一维看,x在一维上的大小为2,y为1,因此,结果张量在第一维的大小为2。以此类推,对每一维进行比较,得到结果张量的形状为[2, 3, 4, 5]。
由于矩阵乘法函数torch.matmul在深度学习中使用非常多,这里需要特别说明一下它的广播规则:
1)如果两个张量均为一维,则获得点积结果。
2) 如果两个张量都是二维的,则获得矩阵与矩阵的乘积。
3) 如果张量x是一维,y是二维,则将x的shape转换为[1, D],与y进行矩阵相乘后再删除前置尺寸。
4) 如果张量x是二维,y是一维,则获得矩阵与向量的乘积。
5) 如果两个张量都是N维张量(N > 2),则根据广播规则广播非矩阵维度(除最后两个维度外其余维度)。比如:如果输入x是形状为[j,1,n,m]的张量,另一个y是[k,m,p]的张量,则输出张量的形状为[j,k,n,p]。
import torch
x = torch.ones([10, 1, 5, 2])
y = torch.ones([3, 2, 5])
z = torch.matmul(x, y)
print('After matmul:', z.shape)
运行结果:
从输出结果看,计算张量乘积时会使用到广播机制。
一、house_tiny.csv数据集处理:
(1)读取数据集 house_tiny.csv
import pandas as pd
data = pd.read_csv('house_tiny.csv')
print(data)
运行结果:
(2) 处理缺失值
import pandas as pd
data = pd.read_csv('house_tiny.csv')
# 将data分成inputs和outputs
inputs, outputs = data.iloc[:, 0:2], data.iloc[:, 2]
# 对于inputs中缺少的数值,用同一列的均值替换“NaN”项
# mean()求均值
inputs = inputs.fillna(inputs.mean())
print(inputs)
inputs = pd.get_dummies(inputs, dummy_na=True)
print(inputs)
运行结果:
(3)转换为张量格式
import torch
import pandas as pd
data = pd.read_csv('house_tiny.csv')
inputs, outputs = data.iloc[:, 0:2], data.iloc[:, 2]
inputs = inputs.fillna(inputs.mean())
inputs = pd.get_dummies(inputs, dummy_na=True)
x, y = torch.tensor(inputs.values), torch.tensor(outputs.values)
print(x, '\n', y)
运行结果:
二、boston_house_prices.csv数据集处理:
(1)读取数据集 boston_house_prices.csv
import pandas as pd
data = pd.read_csv('boston_house_prices.csv')
print(data)
运行结果:
(2)无缺失值
(3)转换为张量格式
import torch
import pandas as pd
data = pd.read_csv('boston_house_prices.csv')
x = torch.tensor(data.values)
print(x)
运行结果:
三、Iris.csv数据集处理:
(1)读取数据集 Iris.csv
import pandas as pd
data = pd.read_csv('Iris.csv')
print(data)
运行结果:
(2)无缺失值
(3)转换为张量格式
import torch
import pandas as pd
data = pd.read_csv('Iris.csv')
inputs, outputs = data.iloc[:, 0:5], data.iloc[:, 5]
outputs = pd.get_dummies(outputs, dummy_na=True)
x, y = torch.tensor(inputs.values), torch.tensor(outputs.values)
print(x)
print('\n')
print(y)
运行结果:
tensor([[1.0000e+00, 5.1000e+00, 3.5000e+00, 1.4000e+00, 2.0000e-01],
[2.0000e+00, 4.9000e+00, 3.0000e+00, 1.4000e+00, 2.0000e-01],
[3.0000e+00, 4.7000e+00, 3.2000e+00, 1.3000e+00, 2.0000e-01],
[4.0000e+00, 4.6000e+00, 3.1000e+00, 1.5000e+00, 2.0000e-01],
[5.0000e+00, 5.0000e+00, 3.6000e+00, 1.4000e+00, 2.0000e-01],
[6.0000e+00, 5.4000e+00, 3.9000e+00, 1.7000e+00, 4.0000e-01],
[7.0000e+00, 4.6000e+00, 3.4000e+00, 1.4000e+00, 3.0000e-01],
[8.0000e+00, 5.0000e+00, 3.4000e+00, 1.5000e+00, 2.0000e-01],
[9.0000e+00, 4.4000e+00, 2.9000e+00, 1.4000e+00, 2.0000e-01],
[1.0000e+01, 4.9000e+00, 3.1000e+00, 1.5000e+00, 1.0000e-01],
[1.1000e+01, 5.4000e+00, 3.7000e+00, 1.5000e+00, 2.0000e-01],
[1.2000e+01, 4.8000e+00, 3.4000e+00, 1.6000e+00, 2.0000e-01],
[1.3000e+01, 4.8000e+00, 3.0000e+00, 1.4000e+00, 1.0000e-01],
[1.4000e+01, 4.3000e+00, 3.0000e+00, 1.1000e+00, 1.0000e-01],
[1.5000e+01, 5.8000e+00, 4.0000e+00, 1.2000e+00, 2.0000e-01],
[1.6000e+01, 5.7000e+00, 4.4000e+00, 1.5000e+00, 4.0000e-01],
[1.7000e+01, 5.4000e+00, 3.9000e+00, 1.3000e+00, 4.0000e-01],
[1.8000e+01, 5.1000e+00, 3.5000e+00, 1.4000e+00, 3.0000e-01],
[1.9000e+01, 5.7000e+00, 3.8000e+00, 1.7000e+00, 3.0000e-01],
[2.0000e+01, 5.1000e+00, 3.8000e+00, 1.5000e+00, 3.0000e-01],
[2.1000e+01, 5.4000e+00, 3.4000e+00, 1.7000e+00, 2.0000e-01],
[2.2000e+01, 5.1000e+00, 3.7000e+00, 1.5000e+00, 4.0000e-01],
[2.3000e+01, 4.6000e+00, 3.6000e+00, 1.0000e+00, 2.0000e-01],
[2.4000e+01, 5.1000e+00, 3.3000e+00, 1.7000e+00, 5.0000e-01],
[2.5000e+01, 4.8000e+00, 3.4000e+00, 1.9000e+00, 2.0000e-01],
[2.6000e+01, 5.0000e+00, 3.0000e+00, 1.6000e+00, 2.0000e-01],
[2.7000e+01, 5.0000e+00, 3.4000e+00, 1.6000e+00, 4.0000e-01],
[2.8000e+01, 5.2000e+00, 3.5000e+00, 1.5000e+00, 2.0000e-01],
[2.9000e+01, 5.2000e+00, 3.4000e+00, 1.4000e+00, 2.0000e-01],
[3.0000e+01, 4.7000e+00, 3.2000e+00, 1.6000e+00, 2.0000e-01],
[3.1000e+01, 4.8000e+00, 3.1000e+00, 1.6000e+00, 2.0000e-01],
[3.2000e+01, 5.4000e+00, 3.4000e+00, 1.5000e+00, 4.0000e-01],
[3.3000e+01, 5.2000e+00, 4.1000e+00, 1.5000e+00, 1.0000e-01],
[3.4000e+01, 5.5000e+00, 4.2000e+00, 1.4000e+00, 2.0000e-01],
[3.5000e+01, 4.9000e+00, 3.1000e+00, 1.5000e+00, 1.0000e-01],
[3.6000e+01, 5.0000e+00, 3.2000e+00, 1.2000e+00, 2.0000e-01],
[3.7000e+01, 5.5000e+00, 3.5000e+00, 1.3000e+00, 2.0000e-01],
[3.8000e+01, 4.9000e+00, 3.1000e+00, 1.5000e+00, 1.0000e-01],
[3.9000e+01, 4.4000e+00, 3.0000e+00, 1.3000e+00, 2.0000e-01],
[4.0000e+01, 5.1000e+00, 3.4000e+00, 1.5000e+00, 2.0000e-01],
[4.1000e+01, 5.0000e+00, 3.5000e+00, 1.3000e+00, 3.0000e-01],
[4.2000e+01, 4.5000e+00, 2.3000e+00, 1.3000e+00, 3.0000e-01],
[4.3000e+01, 4.4000e+00, 3.2000e+00, 1.3000e+00, 2.0000e-01],
[4.4000e+01, 5.0000e+00, 3.5000e+00, 1.6000e+00, 6.0000e-01],
[4.5000e+01, 5.1000e+00, 3.8000e+00, 1.9000e+00, 4.0000e-01],
[4.6000e+01, 4.8000e+00, 3.0000e+00, 1.4000e+00, 3.0000e-01],
[4.7000e+01, 5.1000e+00, 3.8000e+00, 1.6000e+00, 2.0000e-01],
[4.8000e+01, 4.6000e+00, 3.2000e+00, 1.4000e+00, 2.0000e-01],
[4.9000e+01, 5.3000e+00, 3.7000e+00, 1.5000e+00, 2.0000e-01],
[5.0000e+01, 5.0000e+00, 3.3000e+00, 1.4000e+00, 2.0000e-01],
[5.1000e+01, 7.0000e+00, 3.2000e+00, 4.7000e+00, 1.4000e+00],
[5.2000e+01, 6.4000e+00, 3.2000e+00, 4.5000e+00, 1.5000e+00],
[5.3000e+01, 6.9000e+00, 3.1000e+00, 4.9000e+00, 1.5000e+00],
[5.4000e+01, 5.5000e+00, 2.3000e+00, 4.0000e+00, 1.3000e+00],
[5.5000e+01, 6.5000e+00, 2.8000e+00, 4.6000e+00, 1.5000e+00],
[5.6000e+01, 5.7000e+00, 2.8000e+00, 4.5000e+00, 1.3000e+00],
[5.7000e+01, 6.3000e+00, 3.3000e+00, 4.7000e+00, 1.6000e+00],
[5.8000e+01, 4.9000e+00, 2.4000e+00, 3.3000e+00, 1.0000e+00],
[5.9000e+01, 6.6000e+00, 2.9000e+00, 4.6000e+00, 1.3000e+00],
[6.0000e+01, 5.2000e+00, 2.7000e+00, 3.9000e+00, 1.4000e+00],
[6.1000e+01, 5.0000e+00, 2.0000e+00, 3.5000e+00, 1.0000e+00],
[6.2000e+01, 5.9000e+00, 3.0000e+00, 4.2000e+00, 1.5000e+00],
[6.3000e+01, 6.0000e+00, 2.2000e+00, 4.0000e+00, 1.0000e+00],
[6.4000e+01, 6.1000e+00, 2.9000e+00, 4.7000e+00, 1.4000e+00],
[6.5000e+01, 5.6000e+00, 2.9000e+00, 3.6000e+00, 1.3000e+00],
[6.6000e+01, 6.7000e+00, 3.1000e+00, 4.4000e+00, 1.4000e+00],
[6.7000e+01, 5.6000e+00, 3.0000e+00, 4.5000e+00, 1.5000e+00],
[6.8000e+01, 5.8000e+00, 2.7000e+00, 4.1000e+00, 1.0000e+00],
[6.9000e+01, 6.2000e+00, 2.2000e+00, 4.5000e+00, 1.5000e+00],
[7.0000e+01, 5.6000e+00, 2.5000e+00, 3.9000e+00, 1.1000e+00],
[7.1000e+01, 5.9000e+00, 3.2000e+00, 4.8000e+00, 1.8000e+00],
[7.2000e+01, 6.1000e+00, 2.8000e+00, 4.0000e+00, 1.3000e+00],
[7.3000e+01, 6.3000e+00, 2.5000e+00, 4.9000e+00, 1.5000e+00],
[7.4000e+01, 6.1000e+00, 2.8000e+00, 4.7000e+00, 1.2000e+00],
[7.5000e+01, 6.4000e+00, 2.9000e+00, 4.3000e+00, 1.3000e+00],
[7.6000e+01, 6.6000e+00, 3.0000e+00, 4.4000e+00, 1.4000e+00],
[7.7000e+01, 6.8000e+00, 2.8000e+00, 4.8000e+00, 1.4000e+00],
[7.8000e+01, 6.7000e+00, 3.0000e+00, 5.0000e+00, 1.7000e+00],
[7.9000e+01, 6.0000e+00, 2.9000e+00, 4.5000e+00, 1.5000e+00],
[8.0000e+01, 5.7000e+00, 2.6000e+00, 3.5000e+00, 1.0000e+00],
[8.1000e+01, 5.5000e+00, 2.4000e+00, 3.8000e+00, 1.1000e+00],
[8.2000e+01, 5.5000e+00, 2.4000e+00, 3.7000e+00, 1.0000e+00],
[8.3000e+01, 5.8000e+00, 2.7000e+00, 3.9000e+00, 1.2000e+00],
[8.4000e+01, 6.0000e+00, 2.7000e+00, 5.1000e+00, 1.6000e+00],
[8.5000e+01, 5.4000e+00, 3.0000e+00, 4.5000e+00, 1.5000e+00],
[8.6000e+01, 6.0000e+00, 3.4000e+00, 4.5000e+00, 1.6000e+00],
[8.7000e+01, 6.7000e+00, 3.1000e+00, 4.7000e+00, 1.5000e+00],
[8.8000e+01, 6.3000e+00, 2.3000e+00, 4.4000e+00, 1.3000e+00],
[8.9000e+01, 5.6000e+00, 3.0000e+00, 4.1000e+00, 1.3000e+00],
[9.0000e+01, 5.5000e+00, 2.5000e+00, 4.0000e+00, 1.3000e+00],
[9.1000e+01, 5.5000e+00, 2.6000e+00, 4.4000e+00, 1.2000e+00],
[9.2000e+01, 6.1000e+00, 3.0000e+00, 4.6000e+00, 1.4000e+00],
[9.3000e+01, 5.8000e+00, 2.6000e+00, 4.0000e+00, 1.2000e+00],
[9.4000e+01, 5.0000e+00, 2.3000e+00, 3.3000e+00, 1.0000e+00],
[9.5000e+01, 5.6000e+00, 2.7000e+00, 4.2000e+00, 1.3000e+00],
[9.6000e+01, 5.7000e+00, 3.0000e+00, 4.2000e+00, 1.2000e+00],
[9.7000e+01, 5.7000e+00, 2.9000e+00, 4.2000e+00, 1.3000e+00],
[9.8000e+01, 6.2000e+00, 2.9000e+00, 4.3000e+00, 1.3000e+00],
[9.9000e+01, 5.1000e+00, 2.5000e+00, 3.0000e+00, 1.1000e+00],
[1.0000e+02, 5.7000e+00, 2.8000e+00, 4.1000e+00, 1.3000e+00],
[1.0100e+02, 6.3000e+00, 3.3000e+00, 6.0000e+00, 2.5000e+00],
[1.0200e+02, 5.8000e+00, 2.7000e+00, 5.1000e+00, 1.9000e+00],
[1.0300e+02, 7.1000e+00, 3.0000e+00, 5.9000e+00, 2.1000e+00],
[1.0400e+02, 6.3000e+00, 2.9000e+00, 5.6000e+00, 1.8000e+00],
[1.0500e+02, 6.5000e+00, 3.0000e+00, 5.8000e+00, 2.2000e+00],
[1.0600e+02, 7.6000e+00, 3.0000e+00, 6.6000e+00, 2.1000e+00],
[1.0700e+02, 4.9000e+00, 2.5000e+00, 4.5000e+00, 1.7000e+00],
[1.0800e+02, 7.3000e+00, 2.9000e+00, 6.3000e+00, 1.8000e+00],
[1.0900e+02, 6.7000e+00, 2.5000e+00, 5.8000e+00, 1.8000e+00],
[1.1000e+02, 7.2000e+00, 3.6000e+00, 6.1000e+00, 2.5000e+00],
[1.1100e+02, 6.5000e+00, 3.2000e+00, 5.1000e+00, 2.0000e+00],
[1.1200e+02, 6.4000e+00, 2.7000e+00, 5.3000e+00, 1.9000e+00],
[1.1300e+02, 6.8000e+00, 3.0000e+00, 5.5000e+00, 2.1000e+00],
[1.1400e+02, 5.7000e+00, 2.5000e+00, 5.0000e+00, 2.0000e+00],
[1.1500e+02, 5.8000e+00, 2.8000e+00, 5.1000e+00, 2.4000e+00],
[1.1600e+02, 6.4000e+00, 3.2000e+00, 5.3000e+00, 2.3000e+00],
[1.1700e+02, 6.5000e+00, 3.0000e+00, 5.5000e+00, 1.8000e+00],
[1.1800e+02, 7.7000e+00, 3.8000e+00, 6.7000e+00, 2.2000e+00],
[1.1900e+02, 7.7000e+00, 2.6000e+00, 6.9000e+00, 2.3000e+00],
[1.2000e+02, 6.0000e+00, 2.2000e+00, 5.0000e+00, 1.5000e+00],
[1.2100e+02, 6.9000e+00, 3.2000e+00, 5.7000e+00, 2.3000e+00],
[1.2200e+02, 5.6000e+00, 2.8000e+00, 4.9000e+00, 2.0000e+00],
[1.2300e+02, 7.7000e+00, 2.8000e+00, 6.7000e+00, 2.0000e+00],
[1.2400e+02, 6.3000e+00, 2.7000e+00, 4.9000e+00, 1.8000e+00],
[1.2500e+02, 6.7000e+00, 3.3000e+00, 5.7000e+00, 2.1000e+00],
[1.2600e+02, 7.2000e+00, 3.2000e+00, 6.0000e+00, 1.8000e+00],
[1.2700e+02, 6.2000e+00, 2.8000e+00, 4.8000e+00, 1.8000e+00],
[1.2800e+02, 6.1000e+00, 3.0000e+00, 4.9000e+00, 1.8000e+00],
[1.2900e+02, 6.4000e+00, 2.8000e+00, 5.6000e+00, 2.1000e+00],
[1.3000e+02, 7.2000e+00, 3.0000e+00, 5.8000e+00, 1.6000e+00],
[1.3100e+02, 7.4000e+00, 2.8000e+00, 6.1000e+00, 1.9000e+00],
[1.3200e+02, 7.9000e+00, 3.8000e+00, 6.4000e+00, 2.0000e+00],
[1.3300e+02, 6.4000e+00, 2.8000e+00, 5.6000e+00, 2.2000e+00],
[1.3400e+02, 6.3000e+00, 2.8000e+00, 5.1000e+00, 1.5000e+00],
[1.3500e+02, 6.1000e+00, 2.6000e+00, 5.6000e+00, 1.4000e+00],
[1.3600e+02, 7.7000e+00, 3.0000e+00, 6.1000e+00, 2.3000e+00],
[1.3700e+02, 6.3000e+00, 3.4000e+00, 5.6000e+00, 2.4000e+00],
[1.3800e+02, 6.4000e+00, 3.1000e+00, 5.5000e+00, 1.8000e+00],
[1.3900e+02, 6.0000e+00, 3.0000e+00, 4.8000e+00, 1.8000e+00],
[1.4000e+02, 6.9000e+00, 3.1000e+00, 5.4000e+00, 2.1000e+00],
[1.4100e+02, 6.7000e+00, 3.1000e+00, 5.6000e+00, 2.4000e+00],
[1.4200e+02, 6.9000e+00, 3.1000e+00, 5.1000e+00, 2.3000e+00],
[1.4300e+02, 5.8000e+00, 2.7000e+00, 5.1000e+00, 1.9000e+00],
[1.4400e+02, 6.8000e+00, 3.2000e+00, 5.9000e+00, 2.3000e+00],
[1.4500e+02, 6.7000e+00, 3.3000e+00, 5.7000e+00, 2.5000e+00],
[1.4600e+02, 6.7000e+00, 3.0000e+00, 5.2000e+00, 2.3000e+00],
[1.4700e+02, 6.3000e+00, 2.5000e+00, 5.0000e+00, 1.9000e+00],
[1.4800e+02, 6.5000e+00, 3.0000e+00, 5.2000e+00, 2.0000e+00],
[1.4900e+02, 6.2000e+00, 3.4000e+00, 5.4000e+00, 2.3000e+00],
[1.5000e+02, 5.9000e+00, 3.0000e+00, 5.1000e+00, 1.8000e+00]],
dtype=torch.float64)
tensor([[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0]], dtype=torch.uint8)