# 格式转换 (format conversion)
# torch.tensor与numpy.ndarray/list的转换 (torch.Tensor <-> numpy.ndarray / list)
# Round-trip a Python list through numpy.ndarray and torch.Tensor.
import numpy as np
# torch must be imported here: this section runs before the later `import torch`.
import torch

print('\ntorch.tensor and numpy.ndarray / list')
nda_list = [1, 2, 3]
nda = np.array(nda_list)
print(nda)
# torch.from_numpy does not copy: the tensor shares memory with `nda`.
nda_tensor = torch.from_numpy(nda)
print(nda_tensor)
# Tensor.numpy() likewise returns an ndarray backed by the same CPU storage.
nda_arr = nda_tensor.numpy()
print(nda_arr)
# tolist() builds nested Python lists of plain Python scalars.
print(nda_tensor.tolist())
# 基本操作 (basic operations)
import torch

# Fixed seed so the random tensors created below are reproducible across runs.
SEED = 2020
torch.manual_seed(SEED)
# The CUDA seeding functions take a required `seed` argument; calling them
# with no argument raises TypeError. They are silently ignored when CUDA is
# unavailable, so this is safe on CPU-only machines.
torch.cuda.manual_seed(SEED)      # current GPU
torch.cuda.manual_seed_all(SEED)  # all visible GPUs
# 创建tensor (creating tensors)
# Common tensor factory functions; each result is printed in creation order.
a = torch.tensor([1, 2, 3])                     # from a Python list
b = torch.empty(1, 3)                           # uninitialized: printed values are arbitrary
c = torch.zeros(1, 3)                           # filled with 0.
d = torch.ones(1, 3)                            # filled with 1.
e = torch.full((1, 3), 2)                       # every element set to 2
f = torch.rand(1, 3)                            # uniform samples in [0, 1)
g = torch.randn(1, 3)                           # standard-normal samples
h = torch.randint(low=1, high=10, size=(1, 3))  # integers in [1, 10)
for created in (a, b, c, d, e, f, g, h):
    print(created)
# 数据处理操作 (data processing operations)
# Two small integer tensors used (and mutated in place) by the arithmetic demos.
x, y = torch.tensor([-1, 1, 1, 1]), torch.tensor([3, 3, 3, 3])
print('\nsize:')
# Tensor.size() returns a torch.Size, a tuple subclass.
shape = x.size()
print(shape, type(shape))
print(x.abs())
# 加法 (addition)
print('\naddition:')
# Out-of-place forms (operator, functional, method); x is left unchanged.
print(x + y)
print(torch.add(x, y, alpha=1))
print(x.add(y), x)
# The trailing-underscore variant updates x in place and returns x itself.
print(x.add_(y))
# 减法(类比加法) — subtraction, analogous to addition
print('\nsubtraction:')
# Functional and method forms of y - x (alpha scales the subtrahend).
print(torch.sub(y, x, alpha=1))
print(y.sub(x))
# 元素乘法 (element-wise multiplication)
print('\nelement-wise multiplication:')
# Out-of-place forms; x is not modified by any of these.
print(x * y)
print(torch.mul(x, y))
print(x.mul(y), x)
# In-place variant: multiplies x by y and returns x.
print(x.mul_(y))
# 除法：类比元素乘法 — division, analogous to element-wise multiplication
print('\ndivision:')
# Out-of-place torch.div performs true division, so integer inputs
# yield a floating-point result.
print(torch.div(x, y))
# x is an integer tensor at this point. In-place true division cannot store
# a float result back into it, and current PyTorch raises RuntimeError.
# Convert x to float first so the in-place form works; x then holds the
# float quotient.
x = x.float()
print(x.div_(y))
# 矩阵乘法 (matrix multiplication)
print('\ntensor multiplication:')
# Batched operands: three 4x5 matrices and three 5x4 matrices.
m, n = torch.rand(3, 4, 5), torch.rand(3, 5, 4)
# torch.matmul and the @ operator are equivalent; both give a (3, 4, 4) result.
for batched_product in (torch.matmul(m, n), m @ n):
    print(batched_product.size())
# torch.mm is strictly 2-D: (1, 2) x (2, 1) -> (1, 1).
p = torch.tensor([[1, 2]])
q = torch.tensor([[2], [1]])
print(torch.mm(p, q))
# torch.bmm is strictly 3-D (batch of matrices).
print(torch.bmm(m, n).size())
# 幂运算 (powers and square roots)
print('\npower:')
# Squaring p out of place three ways; p is not modified.
print(p ** 2)
print(torch.pow(p, 2))
print(p.pow(2), p)
# In-place square: updates p and returns it.
print(p.pow_(2))
r = torch.tensor([4., 4.])
# Square roots: functional, method (r unchanged), then in place.
print(torch.sqrt(r))
print(r.sqrt(), r)
print(r.sqrt_())
# 指数和对数运算 (exponent and logarithm)
print('\nexponent and logarithm:')
# e ** r: functional, method (r unchanged), then in place.
print(torch.exp(r))
print(r.exp(), r)
print(r.exp_())
# Natural log undoes the in-place exp above, restoring r's previous values.
print(torch.log(r))
print(r.log(), r)
print(r.log_())
# 近似计算 (rounding, truncation and clamping)
print('\napproximate:')
z = torch.tensor([3.49, 3.51])
# Functional forms, then the equivalent Tensor methods on the same values.
print(torch.floor(z), torch.ceil(z), torch.trunc(z), torch.frac(z))
print(*(getattr(z, op)() for op in ('floor', 'ceil', 'trunc', 'frac')))
print(torch.round(z), z.round())
# Clamp every element into [3.4, 3.5].
print(torch.clamp(z, min=3.4, max=3.5), z.clamp(min=3.4, max=3.5))
# 统计信息 (statistics)
# Whole-tensor reductions over z: functional form first, then method form.
# NOTE: var() defaults to the unbiased (N-1) estimator, and median() of an
# even-length tensor returns the lower of the two middle values, not their mean.
reductions = ('max', 'min', 'mean', 'var', 'median')
print(*(getattr(torch, name)(z) for name in reductions))
print(*(getattr(z, name)() for name in reductions))
# 类型转换 (type conversion)
# int/float 转换 (integer and floating-point dtypes)
import torch
# Convert one float32 tensor of standard-normal samples to each dtype;
# each shorthand method is equivalent to .to(<dtype>).
tensor = torch.randn(3, 5)
print(tensor)
long_tensor = tensor.to(torch.int64)      # tensor.long()
print(long_tensor)
half_tensor = tensor.to(torch.float16)    # tensor.half()
print(half_tensor)
int_tensor = tensor.to(torch.int32)       # tensor.int()
print(int_tensor)
double_tensor = tensor.to(torch.float64)  # tensor.double()
print(double_tensor)
float_tensor = tensor.to(torch.float32)   # tensor.float(); already float32, so no copy
print(float_tensor)
char_tensor = tensor.to(torch.int8)       # tensor.char()
print(char_tensor)
# NOTE(review): randn yields negative values; casting negative floats to an
# unsigned dtype (uint8) may be platform-dependent — confirm before relying on it.
byte_tensor = tensor.to(torch.uint8)      # tensor.byte()
print(byte_tensor)
short_tensor = tensor.to(torch.int16)     # tensor.short()
print(short_tensor)
# 参考链接 (references):
# - Pytorch中Tensor与各种图像格式的相互转化
# - pytorch中tensor的基本数学运算
# - pytorch-张量-张量的计算-统计相关的计算
# - PyTorch 笔记(02)— 常用创建 Tensor 方法(torch.Tensor、ones、zeros、eye、arange、linspace、rand、randn、new)
# - python获取字典的key和value
# - pytorch: tensor类型的构建与相互转换
# - torch学习 (41):torch中的tensor初始化操作