PyTorch应用实战四:基于PyTorch构建复杂应用

文章目录

  • 实验环境
  • 1.PyTorch数据加载
    • 1.1 数据预处理
    • 1.2 数据加载
  • 2.PyTorch模型搭建
    • 2.1 经典模型
    • 2.2 模型加载与保存
  • 3.PyTorch优化器
    • 3.1 torch.optim
    • 3.2 学习率调整
  • 常见函数
  • 附:系列文章

实验环境

torch1.8.0+torchvision0.9.0

import torch
import torchvision
print(torch.__version__)
print(torchvision.__version__)
1.8.0
0.9.0+cpu

1.PyTorch数据加载

import torchvision.transforms as tfm
from PIL import Image
img = Image.open('volleyball.png')
img_1 = tfm.RandomCrop(200, padding=50)(img)  #随机裁剪图片
img_1.show()
img_1.save('crop.png')
img_2 = tfm.RandomHorizontalFlip()(img)       #随机水平翻转图片
img_2.show()
img_2.save('flip.png')

1.1 数据预处理

torchvision.transforms

transfrom_train = tfm.Compose([
    tfm.RandomCrop(32, padding=4),
    tfm.RandomHorizontalFlip(),   
    tfm.ToTensor(),     #将图片转换为Tensor张量                       
    tfm.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))  #标准化
])

1.2 数据加载

torch.utils.data

loader = torch.utils.data.DataLoader(
    datasets, batch_size=32, shuffle=True, sampler=None,
    num_workers=2, collate_fn=None, pin_memory=True, drop_last=False
)
  • datasets:传入的数据集,可以是自定义的dataset对象或者torchvision中的预定义数据集对象。
  • batch_size:每个batch中包含的样本数量。
  • shuffle:是否打乱数据集。
  • sampler:样本抽样器,如果指定了sampler,则忽略shuffle参数。
  • num_workers:用于数据加载的子进程数量。
  • collate_fn:对样本进行批处理前的预处理函数,可用于对样本进行排序、padding等操作。
  • pin_memory:是否将数据加载到GPU的显存中。
  • drop_last:如果数据集样本数量不能被batch_size整除,则是否舍弃剩余的不足一个batch的样本。

2.PyTorch模型搭建

2.1 经典模型

torchvision.models

from torchvision import models
net1 = models.resnet50()
net2 = models.resnet50(pretrained=True)

2.2 模型加载与保存

model.load_state_dict(torch.load('pretrained_weights.pth'))
torch.save(model.state_dict(), 'model_weights.pth')

3.PyTorch优化器

3.1 torch.optim

optimizer = optim.SGD([       #SGD随机梯度下降算法
    {'params':model.base.parameters()},
    {'params':model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
# 训练过程 
model = init_model_function()               #模型构建
optimizer = optim.SomeOptimizer(            #设置优化器
    model.parameters(), lr, mm
)

for data, label in train_dataloader:
    optimizer.zero_grad()                #前向计算前,清空原有梯度
    output = model(data)                 #前向计算
    loss = loss_function(output, label)  #损失函数
    loss.backward()                      #反向传播 
    optimizer.step()                     #更新参数

3.2 学习率调整

scheduler = optim.lr_scheduler.SomeScheduler(optimizer, *args)
for epoch in range(epochs):
    train()
    test()
    scheduler.step()

常见函数

激活单元类型

ELU MultiheadAttention SELU softshrink Softmin
Hardshrink PReLU CELU Softsign Softmax
Hardtanh ReLU GELU Tanh Softmax2d
LeakyReLU ReLU6 Sigmoid Tanhshrink LogSoftmax
LogSigmoid RReLU Softplus Threshold

损失函数层类型

L1Loss PoissonNLLLoss HingeEmbeddingLoss CosineEmbeddingLoss
MSELoss KLDivLoss MultiLabelMarginLoss MultiMarginLoss
CrossEntropyLoss BCELoss SmoothL1Loss TripletMarginLoss
CTCLoss BCEWithLogitsLoss SoftMarginLoss
NLLLoss MarginRankingLoss MultiLabelSoftMarginLoss

优化器类型

Adadelta AdamW ASGD Rprop
Adagrad SparseAdam LBFGS SGD
Adam Adamax RMSprop

变换操作类型

Compose RandomAffine RandomOrder Resize ToTensor
CenterCrop RandomApply RandomPerspective Scale Lambda
ColorJitter RandomChoice RandomResizedCrop TenCrop
FiveCrop RandomCrop RandomRotation LinearTransformation
Grayscale RandomGrayscale RandomSizedCrop Normalize
Pad RandomHorizontalFlip RandomVerticalFlip ToPILImage

数据集名称

MNIST CocoCaptions CIFAR10 Flickr8k USPS
FashionMNIST cocoDetection CIFAR100 Flickr30k Kinetics400
KMNIST LSUN STL10 VOCSegmentation HMDB51
EMNIST ImageFolder SVHN VOCDetection UCF101
QMNIST DatasetFolder PhotoTour Cityscape CelebA
FakeData ImageNet SBU SBDataset

torchvision.models中所有实现的分类模型

AlexNet VGG-13-bn ResNet-101 Densenet-201 ResNeXt-50-32x4d
VGG-11 VGG-16-bn ResNet-152 Densenet-161 ResNeXt-101-32x8d
VGG-13 VGG-19-bn SqueezeNet Inception-V3 Wide ResNet-50-2
VGG-16 ResNet-18 GoogleNet Wide ResNet-101-2
VGG-19 ResNet-34 Densenet-121 ShuffleNet-V2 MNASNet 1.0
VGG-11-bn ResNet-50 Densenet-169 MobileNet-V2

附:系列文章

序号 文章目录 直达链接
1 PyTorch应用实战一:实现卷积操作 https://want595.blog.csdn.net/article/details/132575530
2 PyTorch应用实战二:实现卷积神经网络进行图像分类 https://want595.blog.csdn.net/article/details/132575702
3 PyTorch应用实战三:构建神经网络 https://want595.blog.csdn.net/article/details/132575758
4 PyTorch应用实战四:基于PyTorch构建复杂应用 https://want595.blog.csdn.net/article/details/132625270
5 PyTorch应用实战五:实现二值化神经网络 https://want595.blog.csdn.net/article/details/132625348
6 PyTorch应用实战六:利用LSTM实现文本情感分类 https://want595.blog.csdn.net/article/details/132625382

你可能感兴趣的:(《,深度学习,》,pytorch,人工智能,python)