# 图像增广 (Image augmentation)
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image
# Load the sample image once and preview it before augmenting.
# Image.open alone is lazy and keeps the file handle open; using it as a
# context manager and calling convert() reads the pixels into memory and
# closes the file immediately. convert('RGB') also guarantees a 3-channel
# image, so imshow and ColorJitter (including hue) treat it uniformly even
# if the JPEG on disk is grayscale or CMYK.
with Image.open(r'D:\pyPro\hhandwritten_web\static\images\bg.jpg') as _src:
    img = _src.convert('RGB')
plt.imshow(img)
def show_img(img, num_rows, num_cols, scale=2):
    """Display a sequence of images as a num_rows x num_cols grid.

    Args:
        img: indexable sequence of at least ``num_rows * num_cols`` images,
            laid out row-major; each item must be accepted by ``imshow``
            (e.g. a PIL image).
        num_rows: number of grid rows.
        num_cols: number of grid columns.
        scale: width/height of each grid cell, in figure inches.

    Returns:
        The 2-D array of matplotlib Axes, shape (num_rows, num_cols).
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False keeps `axes` 2-D even when num_rows or num_cols is 1,
    # so axes[i][j] indexing works for every grid shape.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    for i in range(num_rows):
        for j in range(num_cols):
            ax = axes[i][j]
            # Row-major position of cell (i, j) in the flat image list.
            ax.imshow(img[i * num_cols + j])
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    return axes
def apply(img, aug, rows=2, cols=4, scale=1.5):
    """Run augmentation ``aug`` on ``img`` rows*cols times and plot the results.

    Each call to ``aug`` produces one independent sample; the samples are
    handed to ``show_img`` to be drawn as a rows x cols grid whose cells are
    ``scale`` inches on a side.
    """
    samples = []
    for _ in range(rows * cols):
        samples.append(aug(img))
    show_img(samples, num_rows=rows, num_cols=cols, scale=scale)
# Flip augmentations: each sample is randomly mirrored left-right / top-bottom.
apply(img, torchvision.transforms.RandomHorizontalFlip())
apply(img, torchvision.transforms.RandomVerticalFlip())

# Shape augmentation: crop a random region covering 10%-100% of the image
# area with aspect ratio in [0.5, 2], then resize it to (600, 1000) = (h, w).
shape_aug = torchvision.transforms.RandomResizedCrop(
    (600, 1000), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

# Color augmentation: randomly jitter brightness, contrast, saturation and hue.
color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

# Combined augmentation: random crop followed by color jitter. Reuses the
# transforms defined above instead of re-declaring identical copies, so the
# two demos cannot drift apart. (Variable name kept for compatibility.)
mutil_aug = torchvision.transforms.Compose([shape_aug, color_aug])
apply(img, mutil_aug)