PyTorch Lightning DataModules and a SimCLR Source Code Walkthrough

Official documentation
Project repository

PyTorch Lightning DataModules

A DataModule decouples the data from the model, so you can focus on the model itself instead of on data handling.
To define a custom DataModule, subclass LightningDataModule and implement the following methods:

def __init__(self): # typically sets data_dir (the data directory), defines transforms, and sets a default self.dims for convenient reuse later

def prepare_data(self): # download the data only; do not assign state or otherwise operate on the data here

def setup(self, stage): # load the previously downloaded data and split it into train/val/test; stage is 'fit', 'test', or None — 'fit' sets up the train/val splits, 'test' sets up the test set, and None sets up both

def train_dataloader(self): # return the training DataLoader

def val_dataloader(self): # return the validation DataLoader

def test_dataloader(self): # return the test DataLoader

A custom DataModule for CIFAR-10:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10

BATCH_SIZE = 256  # assumed value; the original snippet leaves BATCH_SIZE undefined


class CIFAR10DataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE)
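
With the DataModule defined, training is just a matter of handing it to a Trainer, which drives the hooks in order: prepare_data() once, then setup() per stage. A minimal usage sketch, where LitModel is a hypothetical placeholder for any LightningModule (it is not defined in this post):

from pytorch_lightning import Trainer

dm = CIFAR10DataModule(data_dir="./")
model = LitModel()  # hypothetical placeholder for your LightningModule

trainer = Trainer(max_epochs=5)
trainer.fit(model, datamodule=dm)   # runs prepare_data(), setup("fit"), then the train/val loops
trainer.test(model, datamodule=dm)  # runs setup("test"), then the test loop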

SimCLR Source Code Walkthrough

The project implements several self-supervised learning methods at once, and defines a common vision base class:

import os
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Union

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split


class VisionDataModule(LightningDataModule): # subclasses LightningDataModule
    EXTRA_ARGS: dict = {}
    name: str = ""
    #: Dataset class to use
    dataset_cls: type # the Dataset class used to download and split the data; for CIFAR-10 below this is CIFAR10 — loading your own dataset requires a matching implementation
    #: A tuple describing the shape of the data
    dims: tuple

    def __init__(
            self,
            data_dir: Optional[str] = None,
            val_split: Union[int, float] = 0.2, # int: number of validation samples; float: fraction of the training set held out for validation
            num_workers: int = 0,
            normalize: bool = False,
            batch_size: int = 32,
            seed: int = 42,
            shuffle: bool = True,
            pin_memory: bool = True,
            drop_last: bool = False,
            *args: Any,
            **kwargs: Any,
    ) -> None:
        """
        Args:
            data_dir: Where to save/load the data
            val_split: Percent (float) or number (int) of samples to use for the validation split
            num_workers: How many workers to use for loading data
            normalize: If true applies image normalize
            batch_size: How many samples per batch to load
            seed: Random seed to be used for train/val/test splits
            shuffle: If true shuffles the train data every epoch
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                        returning them
            drop_last: If true drops the last incomplete batch
        """

        super().__init__(*args, **kwargs)

        self.data_dir = data_dir if data_dir is not None else os.getcwd()
        self.val_split = val_split
        self.num_workers = num_workers
        self.normalize = normalize
        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last

    def prepare_data(self, *args: Any, **kwargs: Any) -> None: # download the data
        """Saves files to data_dir."""
        self.dataset_cls(self.data_dir, train=True, download=True)
        self.dataset_cls(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Creates train, val, and test dataset."""
        if stage == "fit" or stage is None: # 分配训练集数据
            train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
            val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms

            dataset_train = self.dataset_cls(self.data_dir, train=True, transform=train_transforms, **self.EXTRA_ARGS)
            dataset_val = self.dataset_cls(self.data_dir, train=True, transform=val_transforms, **self.EXTRA_ARGS)

            # Split into train and validation sets
            self.dataset_train = self._split_dataset(dataset_train)
            self.dataset_val = self._split_dataset(dataset_val, train=False)

        if stage == "test" or stage is None: # 分配测试数据
            test_transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
            self.dataset_test = self.dataset_cls(
                self.data_dir, train=False, transform=test_transforms, **self.EXTRA_ARGS
            )

    def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset:
        """Splits the dataset into train and validation set."""
        len_dataset = len(dataset)  # type: ignore[arg-type]
        splits = self._get_splits(len_dataset) # lengths of the train and validation splits
        dataset_train, dataset_val = random_split(dataset, splits, generator=torch.Generator().manual_seed(self.seed))

        if train:
            return dataset_train
        return dataset_val

    def _get_splits(self, len_dataset: int) -> List[int]:
        """Computes split lengths for train and validation set."""
        if isinstance(self.val_split, int):
            train_len = len_dataset - self.val_split # length of the training split
            splits = [train_len, self.val_split] # [train length, val length]
        elif isinstance(self.val_split, float):
            val_len = int(self.val_split * len_dataset)
            train_len = len_dataset - val_len
            splits = [train_len, val_len]
        else:
            raise ValueError(f"Unsupported type {type(self.val_split)}")

        return splits

    @abstractmethod
    def default_transforms(self) -> Callable: # implemented by subclasses
        """Default transform for the dataset."""

    # dataloaders
    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
        """The train dataloader."""
        return self._data_loader(self.dataset_train, shuffle=self.shuffle)

    def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """The val dataloader."""
        return self._data_loader(self.dataset_val)

    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """The test dataloader."""
        return self._data_loader(self.dataset_test)

    def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )
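
To reuse this base class for another dataset, a subclass only has to set dataset_cls and dims and implement default_transforms; downloading, splitting, and the dataloaders are all inherited. A minimal sketch, using torchvision's FashionMNIST purely as an illustration:

from typing import Callable

from torchvision import transforms as transform_lib
from torchvision.datasets import FashionMNIST


class FashionMNISTDataModule(VisionDataModule):
    name = "fashion_mnist"
    dataset_cls = FashionMNIST  # must accept (root, train=..., download=..., transform=...)
    dims = (1, 28, 28)

    def default_transforms(self) -> Callable:
        # used whenever train/val/test transforms are not set explicitly
        return transform_lib.Compose([transform_lib.ToTensor()])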

CIFAR10DataModule is then implemented by subclassing VisionDataModule:

from typing import Any, Callable, Optional, Sequence, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as transform_lib
    from torchvision.datasets import CIFAR10
else:  # pragma: no cover
    warn_missing_pkg("torchvision")
    CIFAR10 = None


class CIFAR10DataModule(VisionDataModule):
    """
    .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/
        Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png
        :width: 400
        :alt: CIFAR-10

    Specs:
        - 10 classes (6,000 images per class)
        - Each image is (3 x 32 x 32)

    Standard CIFAR10, train, val, test splits and transforms

    Transforms::

        cifar10_transforms = transform_lib.Compose([
            transform_lib.ToTensor(),
            transform_lib.Normalize(
                mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
            )
        ])

    Example::

        from pl_bolts.datamodules import CIFAR10DataModule

        dm = CIFAR10DataModule(PATH)
        model = LitModel()

        Trainer().fit(model, datamodule=dm)

    Or you can set your own transforms

    Example::

        dm.train_transforms = ...
        dm.test_transforms = ...
        dm.val_transforms  = ...
    """

    name = "cifar10"
    dataset_cls = CIFAR10 # handles downloading and loading the data
    dims = (3, 32, 32)

    def __init__(
        self,
        data_dir: Optional[str] = None,
        val_split: Union[int, float] = 0.2,
        num_workers: int = 0,
        normalize: bool = False,
        batch_size: int = 32,
        seed: int = 42,
        shuffle: bool = True,
        pin_memory: bool = True,
        drop_last: bool = False, # if True, drop the last incomplete batch
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            data_dir: Where to save/load the data
            val_split: Percent (float) or number (int) of samples to use for the validation split
            num_workers: How many workers to use for loading data
            normalize: If true applies image normalize
            batch_size: How many samples per batch to load
            seed: Random seed to be used for train/val/test splits
            shuffle: If true shuffles the train data every epoch
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                        returning them
            drop_last: If true drops the last incomplete batch
        """
        super().__init__(  # type: ignore[misc]
            data_dir=data_dir,
            val_split=val_split,
            num_workers=num_workers,
            normalize=normalize,
            batch_size=batch_size,
            seed=seed,
            shuffle=shuffle,
            pin_memory=pin_memory,
            drop_last=drop_last,
            *args,
            **kwargs,
        )

    @property
    def num_samples(self) -> int:
        train_len, _ = self._get_splits(len_dataset=50_000)
        return train_len

    @property
    def num_classes(self) -> int:
        """
        Return:
            10
        """
        return 10

    def default_transforms(self) -> Callable: # overrides the abstract base-class hook
        if self.normalize:
            cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
        else:
            cf10_transforms = transform_lib.Compose([transform_lib.ToTensor()])

        return cf10_transforms
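
num_samples simply reruns _get_splits on the 50,000-image CIFAR-10 training set, so its value follows from val_split. A quick check of the arithmetic, assuming the defaults above:

dm = CIFAR10DataModule(data_dir="datasets")  # val_split=0.2 by default
print(dm.num_samples)  # 50000 - int(0.2 * 50000) = 40000

dm_fixed = CIFAR10DataModule(data_dir="datasets", val_split=5000)
print(dm_fixed.num_samples)  # 50000 - 5000 = 45000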

The CIFAR10 used above comes straight from torchvision, but the project also implements the dataset class itself; it is a useful reference if you need to load a dataset that torchvision does not ship.

import logging
import os
import pickle
import tarfile
import urllib.request
from abc import ABC
from typing import Callable, Optional, Sequence, Tuple
from urllib.error import HTTPError

import torch
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset

from pl_bolts.utils import _PIL_AVAILABLE


class LightDataset(ABC, Dataset):

    data: Tensor
    targets: Tensor
    normalize: tuple
    dir_path: str
    cache_folder_name: str
    DATASET_NAME = "light"

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        return os.path.join(self.dir_path, self.DATASET_NAME, self.cache_folder_name)

    @staticmethod
    def _prepare_subset(
        full_data: Tensor,
        full_targets: Tensor,
        num_samples: int,
        labels: Sequence,
    ) -> Tuple[Tensor, Tensor]:
        """Prepare a subset of a common dataset."""
        classes = {d: 0 for d in labels}
        indexes = []
        for idx, target in enumerate(full_targets):
            label = target.item()
            if classes.get(label, float("inf")) >= num_samples:
                continue
            indexes.append(idx)
            classes[label] += 1
            if all(classes[k] >= num_samples for k in classes):
                break
        data = full_data[indexes]
        targets = full_targets[indexes]
        return data, targets

    def _download_from_url(self, base_url: str, data_folder: str, file_name: str):
        url = os.path.join(base_url, file_name)
        logging.info(f"Downloading {url}")
        fpath = os.path.join(data_folder, file_name)
        try:
            urllib.request.urlretrieve(url, fpath)
        except HTTPError as err:
            raise RuntimeError(f"Failed download from {url}") from err
        
class CIFAR10(LightDataset):
    """Customized `CIFAR10 `_ dataset for testing Pytorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/

    Args:
        data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt``
            and  ``CIFAR10/processed/test.pt`` exist.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download: If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Examples:

        >>> from torchvision import transforms
        >>> from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
        >>> cf10_transforms = transforms.Compose([transforms.ToTensor(), cifar10_normalization()])
        >>> dataset = CIFAR10(download=True, transform=cf10_transforms, data_dir="datasets")
        >>> len(dataset)
        50000
        >>> torch.bincount(dataset.targets)
        tensor([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000])
        >>> data, label = dataset[0]
        >>> data.shape
        torch.Size([3, 32, 32])
        >>> label
        6

    Labels::

        airplane: 0
        automobile: 1
        bird: 2
        cat: 3
        deer: 4
        dog: 5
        frog: 6
        horse: 7
        ship: 8
        truck: 9
    """

    BASE_URL = "https://www.cs.toronto.edu/~kriz/" # download location
    FILE_NAME = "cifar-10-python.tar.gz" # archive to fetch
    cache_folder_name = "complete" # where the extracted .pt files are cached
    TRAIN_FILE_NAME = "training.pt" # training set
    TEST_FILE_NAME = "test.pt" # test set
    DATASET_NAME = "CIFAR10" # root folder name
    labels = set(range(10))
    relabel = False

    def __init__(
        self, data_dir: str = ".", train: bool = True, transform: Optional[Callable] = None, download: bool = True
    ):
        super().__init__()
        self.dir_path = data_dir
        self.train = train  # training set or test set
        self.transform = transform

        if not _PIL_AVAILABLE:
            raise ImportError("You want to use PIL.Image for loading but it is not installed yet.")

        os.makedirs(self.cached_folder_path, exist_ok=True)
        self.prepare_data(download)

        if not self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)):
            raise RuntimeError("Dataset not found.")

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        img = self.data[idx].reshape(3, 32, 32)
        target = int(self.targets[idx])

        if self.transform is not None:
            img = img.numpy().transpose((1, 2, 0))  # convert to HWC
            img = self.transform(Image.fromarray(img))
        if self.relabel:
            target = list(self.labels).index(target)
        return img, target

    @classmethod
    def _check_exists(cls, data_folder: str, file_names: Sequence[str]) -> bool:
        if isinstance(file_names, str):
            file_names = [file_names]
        return all(os.path.isfile(os.path.join(data_folder, fname)) for fname in file_names)

    def _unpickle(self, path_folder: str, file_name: str) -> Tuple[Tensor, Tensor]:
        with open(os.path.join(path_folder, file_name), "rb") as fo:
            pkl = pickle.load(fo, encoding="bytes")
        return torch.tensor(pkl[b"data"]), torch.tensor(pkl[b"labels"])

    def _extract_archive_save_torch(self, download_path):
        # extract the archive
        with tarfile.open(os.path.join(download_path, self.FILE_NAME), "r:gz") as tar:
            tar.extractall(path=download_path)
        # this is internal path in the archive
        path_content = os.path.join(download_path, "cifar-10-batches-py")

        # load Test and save as PT
        torch.save(
            self._unpickle(path_content, "test_batch"), os.path.join(self.cached_folder_path, self.TEST_FILE_NAME)
        )
        # load Train and save as PT
        data, labels = [], []
        for i in range(5):
            fname = f"data_batch_{i + 1}"
            _data, _labels = self._unpickle(path_content, fname)
            data.append(_data)
            labels.append(_labels)
        # stash all to one
        data = torch.cat(data, dim=0)
        labels = torch.cat(labels, dim=0)
        # and save as PT
        torch.save((data, labels), os.path.join(self.cached_folder_path, self.TRAIN_FILE_NAME))

    def prepare_data(self, download: bool):
        if self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)):
            return

        base_path = os.path.join(self.dir_path, self.DATASET_NAME)
        if download:
            self.download(base_path)
        self._extract_archive_save_torch(base_path)

    def download(self, data_folder: str) -> None:
        """Download the data if it doesn't exist in cached_folder_path already."""
        if self._check_exists(data_folder, self.FILE_NAME):
            return
        self._download_from_url(self.BASE_URL, data_folder, self.FILE_NAME)
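
The doctest in the class docstring already demonstrates basic usage; wiring the dataset into a DataLoader looks the same as with the torchvision version. A minimal sketch:

from torch.utils.data import DataLoader
from torchvision import transforms

dataset = CIFAR10(data_dir="datasets", train=True, download=True,
                  transform=transforms.Compose([transforms.ToTensor()]))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([32, 3, 32, 32])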

