pytorch框架--数据方面--实际使用版

应用于图像分类

from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader

# 图像目标路径
ROOT_TRAIN = './data/train'
# 一组2个图像打包
batch_size = 2

# 预处理打包:  可加入更多预处理操作
train_transform = transforms.Compose([
    # 缩放
    transforms.Resize((224, 224)),
    # ------------------数据增强 开始---------------------
    # 随机旋转,-45度到45度之间随机选
    transforms.RandomRotation(45),
    # 从中心开始裁剪
    transforms.CenterCrop(224),
    # 随机水平翻转 选择概率值为 p=0.5
    transforms.RandomHorizontalFlip(p=0.5),
    # 随机垂直翻转
    transforms.RandomVerticalFlip(p=0.5),
    # 参数:亮度、对比度、饱和度、色相
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    # 转为3通道灰度图 R=G=B 概率设定0.025
    transforms.RandomGrayscale(p=0.025),
    # ------------------数据增强 结束---------------------
    
    # 数据格式转换为rensor  必备步骤
    transforms.ToTensor(),
    # 归一化
    # 将图像的像素值归一化到【-1, 1】之间   均值、标准差 是计算后得出的,例如VGG:[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    # 这0.5是随便给的
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

# 数据路径、模型预处理打包
train_data = ImageFolder(ROOT_TRAIN, transform=train_transform)

# 格式打包
# 参数:数据、1组几个、下一轮轮是否打乱、进程个数、最后一组是否凑成一组
train_datas = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)

# 返回数据是 图像对应标签的可迭代对象
#     for data in train_datas:
#        imgs, targets = data

你可能感兴趣的:(Pytorch框架,python,pytorch,计算机视觉)