tensor 基础数据结构 2022-09-05

数据结构,基础中的基础:

PyTorch 中的核心数据结构是 tensor(张量)。

import torch
import numpy as np

# A small nested Python list; the float literal 1.1 makes NumPy infer dtype float64.
data = [[1.1, 2], [3, 4]]
data_np = np.array(data)

# torch.from_numpy wraps the NumPy buffer without copying; dtype stays float64.
x_np = torch.from_numpy(data_np)

# "\n" (not a bare "n") so the tensor values start on their own line.
print(f" tensor matrix:\n {x_np}\n")

x_one = torch.ones_like(x_np)  # keeps x_np's shape, dtype (float64) and device
x_rand = torch.rand_like(x_np, dtype=torch.float)  # same shape, dtype overridden to float32

print(f"tensor_one: \n {x_one}\n")
print(f"tensor_rand: \n {x_rand}\n")


# Shape tuple shared by the three constructors below.
shape = (2, 3)
rand_tensor = torch.rand(shape)  # uniform samples in [0, 1)
one_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)

# Attributes every tensor carries.
print(f"shape of tensor:\n {zeros_tensor.shape}")
print(f"datatype of tensor:\n {zeros_tensor.dtype}")
print(f"device tensor is stored on:\n {zeros_tensor.device}")

# NumPy-style indexing; assigning through a slice mutates rand_tensor in place.
rand_tensor[:, 1] = 0
print(f"rand_tensor is:\n {rand_tensor}")
print(f"rand_tensor first row: {rand_tensor[0]}")
print(f"rand_tensor first column: {rand_tensor[:, 0]}")
# `...` stands for all leading dimensions, so [..., -1] is the LAST COLUMN.
print(f"rand_tensor last column: {rand_tensor[..., -1]}")


# Concatenate along dim 0 (rows): two (2, 3) tensors give one (4, 3) tensor.
cat_tensor = torch.cat([one_tensor, zeros_tensor], dim=0)
print(f" concatenate_tensor\n {cat_tensor}\n")

tensor的存储

tensor 可以存储在 CPU 或 GPU 上:默认在 CPU 上创建,如果有可用的 GPU,可以用 `.to("cuda")` 把它移过去。

# Wrap the NumPy array as a tensor and place it on the GPU when one is available.
# On the CPU-only path, Tensor.to("cpu") returns the very same tensor (no copy),
# because it is already on the requested device with the requested dtype.
tensor = torch.from_numpy(data_np).to("cuda" if torch.cuda.is_available() else "cpu")

tensor的基本运算

注意区分 `@`(即 `matmul(...)`,矩阵乘法)和 `*`(即 `mul(...)`,逐元素乘法)。

# --- Matrix multiplication: three equivalent spellings, so y1, y2 and y3 end up equal ---
y1 = tensor @ tensor.T  # operator form
y2 = tensor.matmul(tensor.T)  # method form
print(f"y1= \n {y1}")
# Functional form writing into a pre-allocated buffer; the buffer's random
# initial contents are overwritten by the result.
y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)


# --- Element-wise product: again three equivalent spellings, z1 == z2 == z3 ---
z1 = tensor * tensor  # operator form
z2 = tensor.mul(tensor)  # method form
z3 = torch.rand_like(tensor)  # output buffer, overwritten below
torch.mul(tensor, tensor, out=z3)

# --- Reduce to a one-element tensor, then extract it as a plain Python number ---
agg = z3.sum()
agg_item = agg.item()
print(agg_item, type(agg_item))


重要的“_”后缀(原地操作,in-place):

torch.add(tensor, 1)  # out-of-place, same as tensor.add(1): returns a NEW tensor, `tensor` is untouched
tensor.add_(1)  # in-place (trailing "_"): adds 1 directly into `tensor` itself
完全是两个概念:`add(1)` 返回一个新的 tensor,原 tensor 保持不变;`add_(1)` 则直接原地修改 tensor 本身。

官网解释:
In-place operations: Operations that store the result into the operand are called in-place. They are denoted by a `_` suffix. For example: `x.copy_(y)`, `x.t_()`, will change `x`.


n = torch.ones((2, 2))
# Convert tensor -> NumPy. For a CPU tensor, n.numpy() shares memory with n,
# so the in-place update below is also visible through nn.
# NOTE(review): `nn` is the conventional alias for torch.nn; the name is kept
# here for compatibility, but something like `n_np` would be less confusing.
nn = n.numpy()
print(f"n=\n {n}")
n.add_(1)  # in-place: every element of n (and therefore of nn) becomes 2.0
# The label names the method actually called (add_, not add).
print(f"n.add_(1)=\n {n}")
print(f"nn=\n {nn}")

tensor和numpy的相互转换

# NumPy -> tensor: torch.from_numpy wraps data_np's memory (no copy).
n = torch.from_numpy(data_np)
# tensor -> NumPy: Tensor.numpy() returns an array over the same memory
# (CPU tensors only). The result is not stored here.
n.numpy()

你可能感兴趣的:(tensor 基础数据结构 2022-09-05)