# 逐张图像数据扩充
import torch
from PIL import Image
import torchvision
from torchvision import transforms
# Augmentation pipeline: resize -> center crop -> (always) horizontal flip ->
# purple border -> 3-channel grayscale -> random rotation -> tensor ->
# (always) erase one small patch.
img_transfrom = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.CenterCrop(size=200),
    transforms.RandomHorizontalFlip(p=1),  # p=1: the flip is always applied
    transforms.Pad(padding=20, fill=(124, 20, 187)),  # 200 + 2 * 20 -> 240 x 240
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomAffine(degrees=(30, 90)),  # rotation angle sampled from [30, 90]
    transforms.ToTensor(),  # PIL image -> tensor of shape C*H*W
    # RandomErasing operates on tensors, so it has to come after ToTensor().
    # `value` must be a number, a per-channel sequence or the string 'random'
    # (any other string raises ValueError); (255/255, 0, 0) paints the patch red.
    transforms.RandomErasing(
        p=1,
        scale=(0.02, 0.03),
        ratio=(0.3, 0.3),
        value=(255 / 255, 0, 0),
    ),
])

# Read the image with PIL and convert it to RGB: the 3-element Pad fill above
# only matches 3-band images, so RGBA / palette / grayscale PNGs would fail.
# The context manager closes the file once the converted copy is made.
with Image.open('1.png') as img_file:
    img_input = img_file.convert('RGB')
img_output = img_transfrom(img_input)  # output is a tensor, C*H*W
# ToPILImage is a transform class: instantiate it, then call the instance.
img_show = torchvision.transforms.ToPILImage()(img_output)
img_show.show()  # PIL's own show() displays the image
print(img_show)  # print image info
print(img_show.mode)  # RGB
print(img_show.size)  # (240, 240)
# 批量数据扩充,使用DataLoader
import torch
import torchvision
from torchvision import datasets, transforms
# Same augmentation pipeline as the single-image example.
img_transfrom = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.CenterCrop(size=200),
    transforms.RandomHorizontalFlip(p=1),  # p=1: the flip is always applied
    transforms.Pad(padding=20, fill=(124, 20, 187)),  # 200 + 2 * 20 -> 240 x 240
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomAffine(degrees=(30, 90)),  # rotation angle sampled from [30, 90]
    transforms.ToTensor(),  # PIL image -> tensor of shape C*H*W
    # `value` must be a number, a per-channel sequence or the string 'random'
    # (any other string raises ValueError); (255/255, 0, 0) paints the patch red.
    transforms.RandomErasing(
        p=1,
        scale=(0.02, 0.03),
        ratio=(0.3, 0.3),
        value=(255 / 255, 0, 0),
    ),
])
imgSet = datasets.ImageFolder('/root/PycharmProjects/Pytorch_S/images', transform=img_transfrom)
# NOTE: num_workers > 0 starts worker processes; with the 'spawn' start method
# this loader code must run under an `if __name__ == "__main__":` guard.
imgLoader = torch.utils.data.DataLoader(imgSet, num_workers=2, batch_size=4, shuffle=True)
to_pil = torchvision.transforms.ToPILImage()  # stateless, so build it once
for img_show, _labels in imgLoader:
    # img_show is a batch tensor; iterating over its first dimension yields one
    # C*H*W image at a time, so no squeeze is needed.
    for img_signal in img_show:
        print(img_signal.shape)
        to_pil(img_signal).show()
# 自定义DataSet读取文件,并进行数据扩充
import os
import torch
from PIL import Image
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets, transforms
class MyDataset(Dataset):
    """Flat-directory image dataset yielding ``(image, file_name)`` pairs.

    Every regular file directly inside ``img_dir`` is one sample. The "label"
    is just the file name, because this example does not need class labels.
    """

    def __init__(self, img_dir, transform=None):
        """
        :param img_dir: str, directory containing the images (not searched recursively)
        :param transform: callable applied to each PIL image in ``__getitem__``
            (e.g. a ``transforms.Compose``); if None the PIL image is returned as-is
        """
        # data_info holds (image_path, label) pairs; DataLoader fetches samples by
        # index. Images are only opened in __getitem__, so no file handles are
        # kept open for the lifetime of the dataset.
        self.data_info = []
        self.get_img(img_dir)
        self.transform = transform

    def __getitem__(self, index):
        """Load sample ``index`` as an RGB image, apply the transform, return ``(img, label)``."""
        img_path, label = self.data_info[index]
        # Decode into an independent RGB copy and close the file immediately.
        # RGB conversion keeps 3-element fills (e.g. Pad(fill=(r, g, b))) valid
        # for RGBA / palette / grayscale inputs.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        if self.transform is not None:
            # The DataLoader's default collation cannot batch PIL images, so when
            # batching, the transform should end in a tensor-producing step
            # such as ToTensor().
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_info)

    def get_img(self, img_dir):
        """Append an ``(image_path, file_name)`` pair for every regular file in ``img_dir``."""
        # sorted() makes the index -> file mapping reproducible (os.listdir order
        # is arbitrary); sub-directories are skipped because they are not images.
        for fn in sorted(os.listdir(img_dir)):
            path = os.path.join(img_dir, fn)
            if os.path.isfile(path):
                self.data_info.append((path, fn))  # label is just the file name here
# Same augmentation pipeline as the previous examples.
img_transfrom = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.CenterCrop(size=200),
    transforms.RandomHorizontalFlip(p=1),  # p=1: the flip is always applied
    transforms.Pad(padding=20, fill=(124, 20, 187)),  # 200 + 2 * 20 -> 240 x 240
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomAffine(degrees=(30, 90)),  # rotation angle sampled from [30, 90]
    transforms.ToTensor(),  # PIL image -> tensor of shape C*H*W
    # `value` must be a number, a per-channel sequence or the string 'random'
    # (any other string raises ValueError); (255/255, 0, 0) paints the patch red.
    transforms.RandomErasing(
        p=1,
        scale=(0.02, 0.03),
        ratio=(0.3, 0.3),
        value=(255 / 255, 0, 0),
    ),
])
imgSet = MyDataset('/root/PycharmProjects/Pytorch_S/images/train', transform=img_transfrom)
# NOTE: num_workers > 0 starts worker processes; with the 'spawn' start method
# this loader code must run under an `if __name__ == "__main__":` guard.
imgLoader = torch.utils.data.DataLoader(imgSet, num_workers=2, batch_size=4, shuffle=True)
to_pil = torchvision.transforms.ToPILImage()  # stateless, so build it once
for i, (img_show, label) in enumerate(imgLoader):
    print(i, label)  # label: the file names of the images in this batch
    # Iterating the batch tensor yields one C*H*W image at a time.
    for img_signal in img_show:
        to_pil(img_signal).show()
    print('----- one batch -----')