PyTorch自定义Transform方法

P y T o r c h 自定义 T r a n s f o r m 方法 PyTorch自定义Transform方法 PyTorch自定义Transform方法

图像增强为例

from PIL import Image
from torchvision import transforms
from utils import transform_invert
import random
import numpy as np

class Enhance(object):
    """增加椒盐噪声
    Args:
        x():乘
        y ():"""

    def __init__(self, x=1, y=0):
        self.x = x
        self.y = y

    def __call__(self, img):
        """
        Args:
            img (PIL Image): PIL Image
        Returns:
            PIL Image: PIL image.
        """

        img_ = np.array(img).copy()
        img_ = img_*self.x + self.y
        return Image.fromarray(img_.astype('uint8')).convert('RGB')


if __name__ == '__main__':
    # 1.读取图像
    img = Image.open(r"./cat.png").convert('RGB')


    # 2.确定预处理方式
    img_transform = transforms.Compose([
                        transforms.Grayscale(),
                        Enhance(x=2, y=22),
                        transforms.ToTensor()  # 转Tensor型变量
                                        ])

    img_tensor = img_transform(img)

    # 4.逆Transform变换
    img = transform_invert(img_tensor, img_transform)  # input: shape=[c h w]
    # 5.进行预处理效果展示
    img.show()

PyTorch自定义Transform方法_第1张图片

你可能感兴趣的:(1024程序员节)