paddle框架数据增强

为了增加数据的泛化能力,一般在使用使用数据进行模型训练以前都会对数据进行增强
数据增强的方法有随机裁减,改变图像尺寸,图像随机旋转,随机改变亮度,随机混合,随机增加噪声等
paddle支持的数据处理方法可以通过以下命令查看

print("飞桨支持的数据预处理方式:" + str(paddle.vision.transforms.__all__))

运行结果

飞桨支持的数据预处理方式:['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'normalize']

使用的时候框架内置的数据集和自定义数据集略有些不同

1.框架内置数据集:

内置数据集使用数据增强的方式非常简单,我们可以直接定义一个数据预处理的方式,然后将其作为参数,在加载内置数据集的时候,传给 transform 参数即可。

具体代码如下:

import paddle.vision.transforms as T

# 方式一 只对图像进行调整亮度的操作
transform = T.BrightnessTransform(0.4)
# 通过transform参数传递定义好的数据增方法即可完成对自带数据集的数据增强
train_dataset_without_transform = vision.datasets.Cifar10(mode='train')
train_dataset_with_transform = vision.datasets.Cifar10(mode='train', transform=transform)

index = 10
print("未调整亮度的图像")
train_dataset_without_data_0 = np.array(train_dataset_without_transform[index][0])
train_dataset_without_data_0 = train_dataset_without_data_0.astype('float32') / 255.
plt.imshow(train_dataset_without_data_0)

print("调整亮度的图像")
train_dataset_with_data_0 = np.array(train_dataset_with_transform[index][0])
train_dataset_with_data_0 = train_dataset_with_data_0.astype('float32') / 255.
plt.imshow(train_dataset_with_data_0)

而如果想对一个数据集进行多个数据预处理的方式,可以先定义一个 transform 的容器 Compose,将我们需要的数据预处理方法以 list 的格式传入 Compose,然后在加载内置数据集的时候,传给 transform参数即可。

import paddle.vision.transforms as T

# 方式二 对图像进行多种操作
transform = T.Compose([T.BrightnessTransform(0.4), T.ContrastTransform(0.4)])
# 通过transform参数传递定义好的数据增方法即可完成对自带数据集的数据增强
train_dataset_without_compose = vision.datasets.Cifar10(mode='train')
train_dataset_with_compose = vision.datasets.Cifar10(mode='train', transform=transform)

index = 10
print("未调整的图像")
train_dataset_without_compose_data_0 = np.array(train_dataset_without_compose[index][0])
train_dataset_without_compose_data_0 = train_dataset_without_compose_data_0.astype('float32') / 255.
plt.imshow(train_dataset_without_compose_data_0)

print("多种调整后的图像")
train_dataset_with_compose_data_0 = np.array(train_dataset_with_compose[index][0])
train_dataset_with_compose_data_0 = train_dataset_with_compose_data_0.astype('float32') / 255.
plt.imshow(train_dataset_with_compose_data_0)

2.自定义数据图像

针对自定义数据集使用数据增强的方式, 比较直观的方式是在在数据集的构造函数中进行数据增强方法的定义,之后对__getitem__中返回的数据进行应用。我们以上述中FashionMNIST数据集为例来说明,具体如下:

class FashionMNISTDataset(paddle.io.Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, path='./', mode='train', transform='None'):
        """
        步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集
        """
        super(FashionMNISTDataset, self).__init__()

        images_data_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % mode)
        labels_data_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % mode)
        with gzip.open(labels_data_path, 'rb') as lbpath:
            self.labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

        with gzip.open(images_data_path, 'rb') as imgpath:
            self.images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(self.labels), 784)
        self.transform = None
        if transform != 'None':
            self.transform = transform
def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        if self.transform:
            image = self.transform(self.images[index].reshape(28, 28))
        else:
            image = self.images[index]
        label = self.labels[index]

        return image, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return len(self.images)

测试未处理的数据集

fashion_mnist_train_dataset_without_transform = FashionMNISTDataset(mode='train')

# 可视化
fashion_mnist_train_dataset_without_transform = np.array(fashion_mnist_train_dataset_without_transform[0][0])
fashion_mnist_train_dataset_without_transform = fashion_mnist_train_dataset_without_transform.reshape([28, 28])
plt.imshow(fashion_mnist_train_dataset_without_transform, cmap=plt.cm.binary)

测试处理的数据集

from paddle.vision.transforms import RandomVerticalFlip
fashion_mnist_train_dataset_with_transform = FashionMNISTDataset(mode='train', transform=RandomVerticalFlip(0.4))

# 可视化
fashion_mnist_train_dataset_with_transform = np.array(fashion_mnist_train_dataset_with_transform[0][0])
fashion_mnist_train_dataset_with_transform = fashion_mnist_train_dataset_with_transform.reshape([28, 28])
plt.imshow(fashion_mnist_train_dataset_with_transform, cmap=plt.cm.binary)

3.数据加载

定义了数据集后,就需要加载数据集。可以通过 paddle.io.DataLoader 完成数据的加载

train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)

for batch_id, data in enumerate(train_loader()):
    x_data = data[0]
    y_data = data[1]
    print(x_data.numpy().shape)
    print(y_data.numpy().shape)
    break

4.数据采样

飞桨框架提供了多种数据采样器,用于不同的场景,来提升训练模型的泛化性能。飞桨框架包含的采样器如下: paddle.io.BatchSampler 、 paddle.io.DistributedBatchSampler 、paddle.io.RandomSampler、paddle.io.SequenceSampler 等,代码示例如下:

from paddle.io import SequenceSampler, RandomSampler, BatchSampler, DistributedBatchSampler

class RandomDataset(paddle.io.Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples
    
train_dataset = RandomDataset(100)

print('-----------------顺序采样----------------')
sampler = SequenceSampler(train_dataset)
batch_sampler = BatchSampler(sampler=sampler, batch_size=10)

for index in batch_sampler:
    print(index)
    
print('-----------------随机采样----------------')
sampler = RandomSampler(train_dataset)
batch_sampler = BatchSampler(sampler=sampler, batch_size=10)

for index in batch_sampler:
    print(index)

print('-----------------分布式采样----------------')
batch_sampler = DistributedBatchSampler(train_dataset, num_replicas=2, batch_size=10)

for index in batch_sampler:
print(index)

你可能感兴趣的:(paddle)