利用PyTorch的transforms进行数据扩充

# 逐张图像数据扩充

import torch
from PIL import Image
import torchvision
from torchvision import transforms

# Augmentation pipeline for a single PIL image. The PIL-based steps run first;
# ToTensor then produces a C*H*W tensor, which RandomErasing operates on.
img_transfrom = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.CenterCrop(size=200),
    transforms.RandomHorizontalFlip(p=1),    # p=1: the flip is always applied
    transforms.Pad(padding=20, fill=(124, 20, 187)),    # 200 + 2 * 20 -> 240 x 240
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomAffine(degrees=(30, 90)),
    transforms.ToTensor(),
    # `value` has to be a number, a per-channel sequence or the string "random".
    # The original passed the tuple as a string, which current torchvision
    # rejects, so the per-channel fill is passed as a real tuple.
    transforms.RandomErasing(1, scale=(0.02, 0.03), ratio=(0.3, 0.3), value=(255 / 255, 0, 0)),
])

# Read the image with PIL. Converting to RGB keeps the 3-element Pad fill valid
# for PNG files stored as RGBA / palette / grayscale, and fully loads the pixel
# data so the file handle can be closed right away.
with Image.open('1.png') as img_file:
    img_input = img_file.convert('RGB')

img_output = img_transfrom(img_input)   # tensor output, C*H*W
# ToPILImage is a transform class: instantiate it first, then call the instance.
img_show = torchvision.transforms.ToPILImage()(img_output)
img_show.show()   # PIL's own viewer hook displays the picture
print(img_show)   # print the image info
print(img_show.mode)    # RGB
print(img_show.size)    # (240, 240)
# 批量数据扩充,使用DataLoader

import torch
import torchvision
from torchvision import datasets, transforms

# Augmentation pipeline applied to every sample that the DataLoader fetches.
img_transfrom = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.CenterCrop(size=200),
    transforms.RandomHorizontalFlip(p=1),    # p=1: the flip is always applied
    transforms.Pad(padding=20, fill=(124, 20, 187)),
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomAffine(degrees=(30, 90)),
    transforms.ToTensor(),
    # `value` has to be a number, a per-channel sequence or the string "random".
    # The original passed the tuple as a string, which current torchvision
    # rejects, so the per-channel fill is passed as a real tuple.
    transforms.RandomErasing(1, scale=(0.02, 0.03), ratio=(0.3, 0.3), value=(255 / 255, 0, 0)),
])

# ImageFolder expects one sub-directory per class below the given root.
imgSet = datasets.ImageFolder('/root/PycharmProjects/Pytorch_S/images', transform=img_transfrom)

# NOTE: num_workers=2 loads samples in worker processes. On platforms that
# start workers with "spawn" (e.g. Windows), this script has to be run under an
# `if __name__ == '__main__':` guard.
imgLoader = torch.utils.data.DataLoader(imgSet, num_workers=2, batch_size=4, shuffle=True)

# Build the tensor -> PIL converter once instead of once per image.
to_pil = torchvision.transforms.ToPILImage()

for img_show, label in imgLoader:
    # Iterating over the batch tensor yields one C*H*W image at a time.
    for img_signal in img_show:
        print(img_signal.shape)
        img_aug = to_pil(img_signal)
        img_aug.show()
# 自定义DataSet读取文件,并进行数据扩充

import os
import torch
from PIL import Image
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets, transforms

class MyDataset(Dataset):
    """Dataset over the files that sit directly inside one directory.

    Only the file paths are collected when the dataset is built; an image is
    opened, converted to RGB and transformed when it is indexed. The file name
    doubles as the label because no real label is needed here.
    """

    def __init__(self, img_dir, transform=None):
        """
        :param img_dir: str, directory that holds the image files
        :param transform: callable applied to the RGB PIL image before it is
            returned; with None the PIL image itself is returned
        """
        self.transform = transform
        # (path, file name) pairs; samples are looked up here by index.
        self.data_info = []
        self.get_img(img_dir)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        img_path, label = self.data_info[index]
        # convert() loads the pixel data, so the file can be closed immediately
        # instead of keeping one open handle per sample.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        # The default collate function of DataLoader cannot batch PIL images,
        # so a transform that ends in a tensor is needed when batching.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.data_info)

    def get_img(self, img_dir):
        """Append a ``(path, file name)`` pair for every file in ``img_dir``.

        Entries are visited in sorted order so that indices are reproducible;
        sub-directories are skipped.
        """
        for fn in sorted(os.listdir(img_dir)):
            img_path = os.path.join(img_dir, fn)
            if os.path.isfile(img_path):
                self.data_info.append((img_path, fn))

# Augmentation pipeline handed to MyDataset; it runs on each sample on access.
img_transfrom = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.CenterCrop(size=200),
    transforms.RandomHorizontalFlip(p=1),    # p=1: the flip is always applied
    transforms.Pad(padding=20, fill=(124, 20, 187)),
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomAffine(degrees=(30, 90)),
    transforms.ToTensor(),
    # `value` has to be a number, a per-channel sequence or the string "random".
    # The original passed the tuple as a string, which current torchvision
    # rejects, so the per-channel fill is passed as a real tuple.
    transforms.RandomErasing(1, scale=(0.02, 0.03), ratio=(0.3, 0.3), value=(255 / 255, 0, 0)),
])

imgSet = MyDataset('/root/PycharmProjects/Pytorch_S/images/train', transform=img_transfrom)

# NOTE: num_workers=2 loads samples in worker processes. On platforms that
# start workers with "spawn" (e.g. Windows), this script has to be run under an
# `if __name__ == '__main__':` guard.
imgLoader = torch.utils.data.DataLoader(imgSet, num_workers=2, batch_size=4, shuffle=True)

# Build the tensor -> PIL converter once instead of once per image.
to_pil = torchvision.transforms.ToPILImage()

for i, (img_show, label) in enumerate(imgLoader):
    print(i, label)
    # Iterating over the batch tensor yields one C*H*W image at a time.
    for img_signal in img_show:
        img_aug = to_pil(img_signal)
        img_aug.show()
    print('----- one batch -----')

 

你可能感兴趣的:(实验室,python)