Lesson 1. 张量(Tensor)的创建和常用方法
一、张量(Tensor)的基本创建及其类型
//
import torch #导入PyTorch包
import numpy as np
torch.__version__ #查看版本号
1.张量(Tensor)函数创建方法
// 张量(Tensor)函数创建方法
t = torch.tensor([1, 2]) # 通过列表创建张量
torch.tensor((1, 2)) # 通过元组创建张量
a = np.array((1, 2))
t1 = torch.tensor(a) # 通过数组创建张量
type(t) #查看张量的类型
2.PyTorch中Tensor类型
数据类型 | dtype |
---|---|
32bit浮点数 | torch.float32或torch.float |
64bit浮点数 | torch.float64或torch.double |
16bit浮点数 | torch.float16或torch.half |
8bit无符号整数 | torch.unit8 |
8bit有符号整数 | torch.int8 |
16bit有符号整数 | torch.int16或torch.short |
16bit有符号整数 | torch.int16或torch.short |
32bit有符号整数 | torch.int32或torch.int |
64bit有符号整数 | torch.int64或torch.long |
布尔型 | torch.bool |
复数型 | torch.complex64 |
dtype使用方法
// 张量(Tensor)函数创建方法
t.dtype #查看张量的类型
torch.tensor(np.array([1.1, 2.2])).dtype
torch.tensor([1.11, 2.2]).dtype
torch.tensor([1.1, 2.7], dtype = torch.int16)
3.张量类型的转化
// 张量类型的转化
t.float() # 转化为默认浮点型(32位)
t.double() # 转化为双精度浮点型
t.short() # 转化为16位整数
二、张量的维度与形变
// 张量的维度
t1 = torch.tensor([1, 2])
t1.ndim
t1.shape
t1.size()
len(t1) #返回拥有几个(N-1)维元素
t1.numel() #返回总共拥有几个数
#注:一维张量len和numel返回结果相同,但更高维度张量则不然
t2 = torch.tensor([[1, 2], [3, 4]])
len([[1, 2], [3, 4]]) #2
t2.numel() #4
// 张量的形变
t2 = torch.tensor([[1, 2], [3, 4]])
t2.flatten() #按行排列,拉平。
t2.reshape(1,4)
t2.reshape(1, 1, 4)
torch.zeros([2, 3]) # 创建全是0的,两行、三列的张量(矩阵)
torch.ones([2, 3])
torch.eye(5)
torch.diag(t1)
torch.rand(2, 3) #rand:服从0-1均匀分布的张量
torch.randn(2, 3) #randn:服从标准正态分布的张量
torch.normal(2, 3, size = (2, 2))
# 均值为2,标准差为3的张量 normal:服从指定正态分布的张量
torch.randint(1, 10, [2, 4])
# 在1-10之间随机抽取整数,组成两行四列的矩阵
torch.arange(5) #arange/linspace:生成数列
torch.arange(1, 5, 0.5) # 从1到5(左闭右开),每隔0.5取值一个
torch.linspace(1, 5, 3) # 从1到5(左右都包含),等距取三个数
torch.empty(2, 3) #empty:生成未初始化的指定形状矩阵
torch.full([2, 4], 2) #full:根据指定形状,填充指定数值
torch.full_like(t1, 2) # 根据t1形状,填充数值2
torch.randint_like(t2, 1, 10)
torch.zeros_like(t1)
// 张量(Tensor)和其他相关类型之间的转化方法
t1 = torch.arange(1,11)
t1.numpy()
np.array(t1)
t1.tolist()
list(t1)
torch.tensor(1).item()
t11 = t1 # t11是t1的浅拷贝
t11 = t1.clone() #t11不随t1对象改变而改变,则需要对t11进行深拷贝