PyTorch Learning Notes 1 - Reading Datasets

Preface

Time to get into data science. Lately I've been thinking about combining AI with security; I'm still fairly lost, but a way forward should turn up eventually, so I'll set aside the question of what learning this is for and just dive into the technology.
Like TensorFlow, PyTorch is standard equipment for data science, so it is worth learning properly.

This blog mainly follows the PyTorch tutorial videos by the Bilibili uploader 我是土堆,
along with the official PyTorch documentation and source code.

Structure of the torch package

List the members of the torch package:
dir(torch)

['AVG',
 'AggregationType',
 'AnyType',
 'Argument',
 'ArgumentSpec',
 'BFloat16Storage',
 'BFloat16Tensor',
 'BenchmarkConfig',
 'BenchmarkExecutionStats',
 'Block',
 'BoolStorage',
 'BoolTensor',
 'BoolType',
 'BufferDict',
 'ByteStorage',
 'ByteTensor',
 'CONV_BN_FUSION',
 'CallStack',
 'Capsule',
 'CharStorage',
 'CharTensor',
 'ClassType',
 'Code',
 'CompilationUnit',
 'CompleteArgumentSpec',
 'ComplexDoubleStorage',
 'ComplexFloatStorage',
 'ComplexType',
 'ConcreteModuleType',
 'ConcreteModuleTypeBuilder',
 'CudaBFloat16StorageBase',
 'CudaBoolStorageBase',
 'CudaByteStorageBase',
 'CudaCharStorageBase',
 'CudaComplexDoubleStorageBase',
 'CudaComplexFloatStorageBase',
 'CudaDoubleStorageBase',
 'CudaFloatStorageBase',
 'CudaHalfStorageBase',
 'CudaIntStorageBase',
 'CudaLongStorageBase',
 'CudaShortStorageBase',
 ...]

List the members of torch.cuda:
dir(torch.cuda)

['Any',
 'BFloat16Storage',
 'BFloat16Tensor',
 'BoolStorage',
 'BoolTensor',
 'ByteStorage',
 'ByteTensor',
 'CharStorage',
 'CharTensor',
 'ComplexDoubleStorage',
 'ComplexFloatStorage',
 'CudaError',
 'DeferredCudaCallError',
 'Device',
 'Dict',
 'DoubleStorage',
 'DoubleTensor',
 'Event',
 'FloatStorage',
 'FloatTensor',
 'HalfStorage',
 'HalfTensor',
 'IntStorage',
 'IntTensor',
 'List',
 'LongStorage',
 'LongTensor',
 'Optional',
 'ShortStorage',
 'ShortTensor',
 'Stream',
 'Tuple',
 'Union',
 '_CudaBase',
 '_CudaDeviceProperties',
 '_Graph',
 '_StorageBase',
 '__annotations__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '_check_capability',
 '_check_cubins',
 '_cudart',
 '_device',
 '_device_t',
 '_dummy_type',
 '_get_device_index',
 '_initialization_lock',
 '_initialized',
 '_is_in_bad_fork',
 '_lazy_call',
 '_lazy_init',
 '_lazy_new',
 '_queued_calls',
 '_sleep',
 '_tls',
 '_utils',
 'amp',
 'caching_allocator_alloc',
 'caching_allocator_delete',
 'can_device_access_peer',
 'check_error',
 'collections',
 'contextlib',
 'cudaStatus',
 'cudart',
 'current_blas_handle',
 'current_device',
 'current_stream',
 'default_generators',
 'default_stream',
 'device',
 'device_count',
 'device_of',
 'empty_cache',
 'get_arch_list',
 'get_device_capability',
 'get_device_name',
 'get_device_properties',
 'get_gencode_flags',
 'get_rng_state',
 'get_rng_state_all',
 'has_half',
 'has_magma',
 'init',
 'initial_seed',
 'ipc_collect',
 'is_available',
 'is_initialized',
 'list_gpu_processes',
 'manual_seed',
 'manual_seed_all',
 'max_memory_allocated',
 'max_memory_cached',
 'max_memory_reserved',
 'memory',
 'memory_allocated',
 'memory_cached',
 'memory_reserved',
 'memory_snapshot',
 'memory_stats',
 'memory_stats_as_nested_dict',
 'memory_summary',
 'nccl',
 'nvtx',
 'os',
 'profiler',
 'random',
 'reset_accumulated_memory_stats',
 'reset_max_memory_allocated',
 'reset_max_memory_cached',
 'reset_peak_memory_stats',
 'seed',
 'seed_all',
 'set_device',
 'set_per_process_memory_fraction',
 'set_rng_state',
 'set_rng_state_all',
 'sparse',
 'stream',
 'streams',
 'synchronize',
 'threading',
 'torch',
 'traceback',
 'warnings']

View the help text for a member:
help(torch.cuda.is_available)

Help on function is_available in module torch.cuda:
is_available() -> bool
    Returns a bool indicating if CUDA is currently available.

This tells us that torch.cuda.is_available() is the function for checking whether CUDA is available.

In short, dir() and help() are quick, convenient ways to explore a package's members and read their documentation.

Reading datasets

The dataset is the bees-and-ants image archive (hymenoptera_data) provided by the uploader.

Dataset

Subclass Dataset from torch.utils.data:

import os
from PIL import Image
from torch.utils.data import Dataset


class DatasetUsage(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # Each label lives in its own folder, e.g. train/ants
        self.path = os.path.join(self.root_dir, self.label_dir)
        # List of image file names under that folder
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        # The folder name doubles as the label
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)


root_dir = "dataset/hymenoptera_data/train"
ant_label = "ants"
bee_label = "bees"
ant_dataset = DatasetUsage(root_dir, ant_label)
bee_dataset = DatasetUsage(root_dir, bee_label)
dataset_sum = ant_dataset + bee_dataset

print(dataset_sum)

(1) Paths are relative; root_dir and label_dir are joined with os.path.join.
(2) To read a custom dataset, subclass Dataset and implement the three special methods __init__, __getitem__, and __len__.
(3) __len__ returns the number of data files in the given folder.
(4) __getitem__ must return a sample and its label (see the sketch below).
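
Indexing a dataset dispatches to __getitem__, and len() dispatches to __len__. A minimal sanity-check sketch, assuming the hymenoptera_data layout above:

# Fetch the first ant sample; __getitem__ returns (PIL image, label string)
img, label = ant_dataset[0]
print(label)  # "ants"
img.show()    # open the image in the default viewer

# len() counts the files in the label folder
print(len(ant_dataset), len(bee_dataset))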

As for the + used above, the __add__ operator overload is implemented in the parent class Dataset:

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

ConcatDataset stores the datasets passed to it in a list, which is how the concatenation of dataset objects is implemented:

class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
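
So the concatenated dataset behaves as a single sequence: __len__ is the sum of the parts, and __getitem__ uses bisect on cumulative_sizes to route an index to the right underlying dataset. A small sketch of this behavior (the actual counts depend on the local copy of the data):

# Length is the sum of the two datasets
assert len(dataset_sum) == len(ant_dataset) + len(bee_dataset)

# The first index past the ant images lands in the bee dataset
img, label = dataset_sum[len(ant_dataset)]
print(label)  # "bees"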

DataLoader

First, the official documentation: https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 
num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, 
multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

Commonly used parameters:
dataset: the target dataset
batch_size: how many samples are drawn at a time, e.g. batch_size=4 groups 4 samples into one batch
shuffle: whether to draw samples in random order; shuffle=True shuffles the data each epoch
num_workers: the number of worker processes used to load the data
drop_last: whether to drop the final batch when it has fewer than batch_size samples (see the sketch below)
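
To see batch_size and drop_last in action without downloading anything, here is a minimal sketch using a made-up 10-sample TensorDataset (the dataset and sizes are for illustration only):

import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 samples, each a length-3 vector, with labels 0..9
toy = TensorDataset(torch.randn(10, 3), torch.arange(10))

# batch_size=4 over 10 samples -> batches of 4, 4, 2
loader = DataLoader(toy, batch_size=4, shuffle=False, drop_last=False)
print([len(targets) for _, targets in loader])  # [4, 4, 2]

# drop_last=True discards the incomplete final batch
loader = DataLoader(toy, batch_size=4, shuffle=False, drop_last=True)
print([len(targets) for _, targets in loader])  # [4, 4]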

Here the CIFAR10 dataset is used to demonstrate DataLoader:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


# train=False selects the test split (matching the variable name); download=True fetches CIFAR10 on first run
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

img, target = test_data[0]
print(img.shape)  # torch.Size([3, 32, 32]) after ToTensor
print(target)     # the class index, an int

writer = SummaryWriter("DataLoader_example")
step = 0
for data in test_loader:
    imgs, targets = data  # imgs: [N, 3, 32, 32], where N = 64 except possibly the last batch
    writer.add_images("test data", imgs, step)
    step += 1
writer.close()
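
The logged image grids can then be inspected by pointing TensorBoard at the log directory used above:

tensorboard --logdir=DataLoader_example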

Result:
(Figure: batches of CIFAR10 images displayed in TensorBoard)

References

https://www.bilibili.com/video/BV1hE411t7RN?p=1
