2021李宏毅机器学习(2):PyTorch

2021李宏毅机器学习(2):PyTorch

  • 1 基础内容
    • 1.1 产生tensor
    • 1.2 squeeze压缩维度
    • 1.3 unsqueeze展出维度
    • 1.4 transpose转置
    • 1.5 cat指定dimension连接多个tensor
    • 1.6 计算梯度
  • 2 神经网络
    • 2.1 读取data
    • 2.2 torch.nn
    • 2.3 torch.optim
  • 3 整个流程
    • 3.1 training
    • 3.2 validation
    • 3.3 testing
  • 4 下载和加载
    • 4.1 Save
    • 4.2 Load

1 基础内容

1.1 产生tensor

import torch
import numpy as np
x = torch.tensor([[1, -1], [-1, 1]])
y = torch.from_numpy(np.array([[1, -1], [-1, 1]]))
x, y
(tensor([[ 1, -1],
         [-1,  1]]),
 tensor([[ 1, -1],
         [-1,  1]], dtype=torch.int32))

1.2 squeeze压缩维度

x = torch.zeros([1, 2, 3])
y = x.squeeze(0)
x, y, x.shape, y.shape
(tensor([[[0., 0., 0.],
          [0., 0., 0.]]]),
 tensor([[0., 0., 0.],
         [0., 0., 0.]]),
 torch.Size([1, 2, 3]),
 torch.Size([2, 3]))

1.3 unsqueeze展出维度

x = torch.zeros([2, 3])
y = x.unsqueeze(1)  # dim = 1
z = x.unsqueeze(2)  # dim = 2
x, y, z, x.shape, y.shape, z.shape
(tensor([[0., 0., 0.],
         [0., 0., 0.]]),
 tensor([[[0., 0., 0.]],
 
         [[0., 0., 0.]]]),
 tensor([[[0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.]]]),
 torch.Size([2, 3]),
 torch.Size([2, 1, 3]),
 torch.Size([2, 3, 1]))

1.4 transpose转置

x = torch.zeros([2, 3])
y = x.transpose(0, 1)
x.shape, y.shape
(torch.Size([2, 3]), torch.Size([3, 2]))

1.5 cat指定dimension连接多个tensor

x = torch.zeros(2,1,3)
y = torch.zeros(2,3,3)
z = torch.zeros(2,2,3)
w = torch.cat([x, y, z], dim=1)
w.shape
torch.Size([2, 6, 3])

1.6 计算梯度

2021李宏毅机器学习(2):PyTorch_第1张图片

2 神经网络

2021李宏毅机器学习(2):PyTorch_第2张图片

2.1 读取data

2021李宏毅机器学习(2):PyTorch_第3张图片
2021李宏毅机器学习(2):PyTorch_第4张图片
二者是包含关系。

2.2 torch.nn

layer = torch.nn.Linear(32, 64)
layer.weight.shape, layer.bias.shape
(torch.Size([64, 32]), torch.Size([64]))
nn.Sigmoid()
nn.ReLU()
nn.MSELoss()  # 多用于线性回归
nn.CrossEntropyLoss()  #多用于分类

2021李宏毅机器学习(2):PyTorch_第5张图片

2.3 torch.optim

SGD:

torch.optin.sGD(params, lr , momentum = 0)

3 整个流程

3.1 training

2021李宏毅机器学习(2):PyTorch_第6张图片
2021李宏毅机器学习(2):PyTorch_第7张图片

3.2 validation

2021李宏毅机器学习(2):PyTorch_第8张图片

3.3 testing

2021李宏毅机器学习(2):PyTorch_第9张图片

4 下载和加载

4.1 Save

torch.save( model.state_dict(), path)

4.2 Load

ckpt = torch.load(path)
model.load_state_dict(ckpt)

你可能感兴趣的:(深度学习,pytorch,机器学习,深度学习,神经网络,人工智能)