mixup Data Augmentation (Custom Dataset)


Mixup provides continuous data samples between different classes, intuitively expanding the distribution of the given training set and thereby making the network more robust at test time.
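
Concretely, mixup draws a weight lambda from a Beta(alpha, alpha) distribution and takes a convex combination of two samples and their labels. Below is a minimal sketch of the classification case, using dummy images and one-hot labels purely for illustration:

import numpy as np

alpha = 1.0
lam = np.random.beta(alpha, alpha)  # mixing weight; alpha=1.0 makes this uniform on [0, 1]

x1, x2 = np.random.rand(32, 32, 3), np.random.rand(32, 32, 3)  # two dummy images
y1, y2 = np.eye(10)[3], np.eye(10)[7]                          # two dummy one-hot labels

mixed_x = lam * x1 + (1 - lam) * x2  # pixel-wise convex combination
mixed_y = lam * y1 + (1 - lam) * y2  # labels are mixed with the same weight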

After mixup, the mixed image's bounding boxes are simply the boxes of both source images combined.
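
There is no meaningful way to average box coordinates, so the convention used in the code below is to mix only the pixels and keep every box from both source images as the target. A tiny sketch, with made-up boxes in [x_min, y_min, x_max, y_max] format:

boxes_1 = [[10, 20, 110, 120], [200, 40, 260, 90]]  # boxes of image 1 (made up)
boxes_2 = [[50, 60, 150, 160]]                      # boxes of image 2 (made up)

mixed_boxes = boxes_1 + boxes_2  # list concatenation: the mixed image keeps all 3 boxes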

import os
import ast

import numpy as np   # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)

import cv2

import torch
from torch.utils.data import DataLoader, Dataset

from matplotlib import pyplot as plt
import matplotlib.patches as patches

def get_bbox(bboxes, col, color='white'):
    for i in range(len(bboxes)):
        # Create a Rectangle patch
        rect = patches.Rectangle(
            (bboxes[i][0], bboxes[i][1]),
            bboxes[i][2] - bboxes[i][0], 
            bboxes[i][3] - bboxes[i][1], 
            linewidth=2, 
            edgecolor=color, 
            facecolor='none')

        # Add the patch to the Axes
        col.add_patch(rect)
        
class WheatDataset(Dataset):
    
    def __init__(self, df):
        self.df = df
        self.image_ids = self.df['image_id'].unique()

    def __len__(self):
        return len(self.image_ids)
    
    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image = cv2.imread(os.path.join(BASE_DIR, 'train', f'{image_id}.jpg'), cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0  # Normalize
        
        # Get bbox coordinates for each wheat head(s)
        bboxes_df = self.df[self.df['image_id'] == image_id]
        boxes, areas = [], []
        n_objects = len(bboxes_df)  # Number of wheat heads in the given image

        for i in range(n_objects):
            x_min = bboxes_df.iloc[i]['x_min']
            x_max = bboxes_df.iloc[i]['x_max']
            y_min = bboxes_df.iloc[i]['y_min']
            y_max = bboxes_df.iloc[i]['y_max']

            boxes.append([x_min, y_min, x_max, y_max])
            areas.append(bboxes_df.iloc[i]['area'])

        return {
            'image_id': image_id,
            'image': image,
            'boxes': boxes,
            'area': areas,
        }
        
def collate_fn(batch):
    images, bboxes, areas, image_ids = ([] for _ in range(4))
    for data in batch:
        images.append(data['image'])
        bboxes.append(data['boxes'])
        areas.append(data['area'])
        image_ids.append(data['image_id'])

    # bboxes and areas are ragged (each image has a different box count), so use object arrays
    return np.array(images), np.array(bboxes, dtype=object), np.array(areas, dtype=object), np.array(image_ids)

def mixup(images, bboxes, areas, alpha=1.0):
    """
    Randomly mixes the given list of images with each other

    :param images: The images to be mixed up
    :param bboxes: The bounding boxes (labels)
    :param areas: The list of areas of all the bboxes
    :param alpha: Required to generate the image weights (lambda) using a beta
        distribution. Here we use alpha=1.0, which is the same as a uniform distribution
    """
    # Generate random indices to shuffle the images
    indices = torch.randperm(len(images)).numpy()
    shuffled_images = images[indices]
    shuffled_bboxes = bboxes[indices]
    shuffled_areas = areas[indices]

    # Generate the image weight, clipped to [0.4, 0.6] so both images stay clearly visible
    lam = np.clip(np.random.beta(alpha, alpha), 0.4, 0.6)
    print(f'lambda: {lam}')

    # Weighted mixup of the pixels
    mixedup_images = lam * images + (1 - lam) * shuffled_images

    # The mixed image keeps the boxes (and areas) of both source images
    mixedup_bboxes, mixedup_areas = [], []
    for bbox, s_bbox, area, s_area in zip(bboxes, shuffled_bboxes, areas, shuffled_areas):
        mixedup_bboxes.append(list(bbox) + list(s_bbox))
        mixedup_areas.append(list(area) + list(s_area))

    return mixedup_images, mixedup_bboxes, mixedup_areas, indices

def read_image(image_id):
    """Read the image from image id"""

    image = cv2.imread(os.path.join(BASE_DIR, 'train', f'{image_id}.jpg'), cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    image /= 255.0  # Normalize
    return image

if __name__ == '__main__':
    # Constants
    BASE_DIR = 'global-wheat-detection'
    # WORK_DIR = '/kaggle/working'
    BATCH_SIZE = 16

    # Set seed for numpy for reproducibility
    np.random.seed(1996)


    train_df = pd.read_csv(os.path.join(BASE_DIR, 'train.csv'))

    # Let's expand the bounding box coordinates and calculate the area of all the bboxes
    train_df[['x_min','y_min', 'width', 'height']] = pd.DataFrame([ast.literal_eval(x) for x in train_df.bbox.tolist()], index= train_df.index)
    train_df = train_df[['image_id', 'bbox', 'source', 'x_min', 'y_min', 'width', 'height']]
    train_df['area'] = train_df['width'] * train_df['height']
    train_df['x_max'] = train_df['x_min'] + train_df['width']
    train_df['y_max'] = train_df['y_min'] + train_df['height']
    train_df = train_df.drop(['bbox'], axis=1)
    train_df = train_df[['image_id', 'x_min', 'y_min', 'x_max', 'y_max', 'width', 'height', 'area', 'source']]

    # There are some buggy annotations with huge bounding boxes in the training images. Let's remove those bboxes
    train_df = train_df[train_df['area'] < 100000]
    image_ids = train_df['image_id'].unique()
    # train_df.head()


    train_dataset = WheatDataset(train_df)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=collate_fn)

    images, bboxes, areas, image_ids = next(iter(train_loader))
    aug_images, aug_bboxes, aug_areas, aug_indices = mixup(images, bboxes, areas)

    fig, ax = plt.subplots(nrows=5, ncols=3, figsize=(15, 30))
    for index in range(5):
        image_id = image_ids[index]
        image = read_image(image_id)

        get_bbox(bboxes[index], ax[index][0], color='red')
        ax[index][0].grid(False)
        ax[index][0].set_xticks([])
        ax[index][0].set_yticks([])
        ax[index][0].title.set_text('Original Image #1')
        ax[index][0].imshow(image)
        
        image_id = image_ids[aug_indices[index]]
        image = read_image(image_id)
        get_bbox(bboxes[aug_indices[index]], ax[index][1], color='red')
        ax[index][1].grid(False)
        ax[index][1].set_xticks([])
        ax[index][1].set_yticks([])
        ax[index][1].title.set_text('Original Image #2')
        ax[index][1].imshow(image)

        get_bbox(aug_bboxes[index], ax[index][2], color='red')
        ax[index][2].grid(False)
        ax[index][2].set_xticks([])
        ax[index][2].set_yticks([])
        ax[index][2].title.set_text('Augmented Image: lambda * image1 + (1 - lambda) * image2')
        ax[index][2].imshow(aug_images[index])
    plt.savefig('mixup.jpg')  # Save before plt.show(), otherwise the saved figure is blank
    plt.show()

Reference: Data Augmentation Tutorial: Basic, Cutout, Mixup | Kaggle

PyTorch Implementation

On the CIFAR-10 dataset.

All the changes happen when fetching the data: both the image and the label get mixed. The network training part is the same as usual, as the full implementation below shows.

"""
Import necessary libraries to train a network using mixup
The code is mainly developed using the PyTorch library
"""
import numpy as np
import pickle
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader


"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


"""
Create a simple CNN
"""
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        # Network consists of 4 convolutional layers followed by 2 fully-connected layers
        self.conv11 = nn.Conv2d(3, 64, 3)
        self.conv12 = nn.Conv2d(64, 64, 3)
        self.conv21 = nn.Conv2d(64, 128, 3)
        self.conv22 = nn.Conv2d(128, 128, 3)
        self.fc1 = nn.Linear(128 * 5 * 5, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.max_pool2d(x, (2, 2))
        x = F.relu(self.conv21(x))
        x = F.relu(self.conv22(x))
        x = F.max_pool2d(x, (2, 2))

        # Feature map is 5x5 here: 32 -> 30 -> 28 -> 14 -> 12 -> 10 -> 5 (kernel size 3, no padding)
        x = x.view(-1, 128 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return torch.sigmoid(x)


"""
Dataset and Dataloader creation
All data are downloaded found via Graviti Open Dataset which links to CIFAR-10 official page
The dataset implementation is where mixup take place
"""

class CIFAR_Dataset(Dataset):
    def __init__(self, data_dir, train, transform):
        self.data_dir = data_dir
        self.train = train
        self.transform = transform
        self.data = []
        self.targets = []

        # Loading all the data depending on whether the dataset is training or testing
        if self.train:
            for i in range(5):
                with open(data_dir + 'data_batch_' + str(i+1), 'rb') as f:
                    entry = pickle.load(f, encoding='latin1')
                    self.data.append(entry['data'])
                    self.targets.extend(entry['labels'])
        else:
            with open(data_dir + 'test_batch', 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])

        # Reshape and transpose into HWC layout, which torchvision's ToTensor expects
        # The original CIFAR-10 binary layout is described on its official page
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        # Create a one hot label
        label = torch.zeros(10)
        label[self.targets[idx]] = 1.

        # Transform the image by converting to tensor and normalizing it
        if self.transform:
            image = self.transform(self.data[idx])

        # For training data, perform mixup on roughly one in every five images
        if self.train and idx > 0 and idx % 5 == 0:

            # Choose another image/label pair at random
            mixup_idx = random.randint(0, len(self.data) - 1)
            mixup_label = torch.zeros(10)
            mixup_label[self.targets[mixup_idx]] = 1.
            if self.transform:
                mixup_image = self.transform(self.data[mixup_idx])

            # Select a random number from the given beta distribution
            # Mixup the images accordingly
            alpha = 0.2
            lam = np.random.beta(alpha, alpha)
            image = lam * image + (1 - lam) * mixup_image
            label = lam * label + (1 - lam) * mixup_label

        return image, label

"""
Define the hyperparameters, image transform components, and the dataset/dataloaders
"""
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

BATCH_SIZE = 64
NUM_WORKERS = 4
LEARNING_RATE = 0.0001
NUM_EPOCHS = 30


train_dataset = CIFAR_Dataset('../lian/dataset/cifar-10-batches-py/', True, transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

test_dataset = CIFAR_Dataset('../lian/dataset/cifar-10-batches-py/', False, transform)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


"""
Initialize the network, loss Adam optimizer
Torch BCE Loss does not support mixup labels (not 1 or 0), so we implement our own
"""
net = CNN().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
def bceloss(x, y):
    eps = 1e-6
    return -torch.mean(y * torch.log(x + eps) + (1 - y) * torch.log(1 - x + eps))
best_Acc = 0


"""
Training Procedure
"""
for epoch in range(NUM_EPOCHS):
    net.train()
    # Train, printing the current loss every 100 iterations
    for idx, (imgs, labels) in enumerate(train_dataloader):
        imgs = imgs.to(device)
        labels = labels.to(device)
        preds = net(imgs)
        loss = bceloss(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print("Epoch {} Iteration {}, Current Loss: {}".format(epoch, idx, loss.item()))

    # We evaluate the network after every epoch based on test set accuracy
    net.eval()
    with torch.no_grad():
        total = 0
        numCorrect = 0
        for (imgs, labels) in test_dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            preds = net(imgs)
            numCorrect += (torch.argmax(preds, dim=1) == torch.argmax(labels, dim=1)).float().sum()
            total += len(imgs)
        acc = numCorrect/total
        print("Current image classification accuracy at epoch {}: {}".format(epoch, acc))
        if acc > best_Acc:
            best_Acc = acc

"""
Printing out overall best result
"""
print("Best Result: {}".format(best_Acc))
