Time to get started with data science. Lately I've been thinking about combining AI with security; I'm still fairly lost, but there should always be a way forward, so for now I'll stop worrying about what learning this is for and just get on with the technology.
Like TensorFlow, PyTorch is standard equipment for data science, so it is worth learning properly.
This blog mainly follows the PyTorch tutorial videos by the Bilibili uploader 我是土堆,
as well as the official PyTorch documentation and source code.
List the members of the torch package:
dir(torch)
['AVG',
'AggregationType',
'AnyType',
'Argument',
'ArgumentSpec',
'BFloat16Storage',
'BFloat16Tensor',
'BenchmarkConfig',
'BenchmarkExecutionStats',
'Block',
'BoolStorage',
'BoolTensor',
'BoolType',
'BufferDict',
'ByteStorage',
'ByteTensor',
'CONV_BN_FUSION',
'CallStack',
'Capsule',
'CharStorage',
'CharTensor',
'ClassType',
'Code',
'CompilationUnit',
'CompleteArgumentSpec',
'ComplexDoubleStorage',
'ComplexFloatStorage',
'ComplexType',
'ConcreteModuleType',
'ConcreteModuleTypeBuilder',
'CudaBFloat16StorageBase',
'CudaBoolStorageBase',
'CudaByteStorageBase',
'CudaCharStorageBase',
'CudaComplexDoubleStorageBase',
'CudaComplexFloatStorageBase',
'CudaDoubleStorageBase',
'CudaFloatStorageBase',
'CudaHalfStorageBase',
'CudaIntStorageBase',
'CudaLongStorageBase',
'CudaShortStorageBase',
...]
List the members of torch.cuda:
dir(torch.cuda)
['Any',
'BFloat16Storage',
'BFloat16Tensor',
'BoolStorage',
'BoolTensor',
'ByteStorage',
'ByteTensor',
'CharStorage',
'CharTensor',
'ComplexDoubleStorage',
'ComplexFloatStorage',
'CudaError',
'DeferredCudaCallError',
'Device',
'Dict',
'DoubleStorage',
'DoubleTensor',
'Event',
'FloatStorage',
'FloatTensor',
'HalfStorage',
'HalfTensor',
'IntStorage',
'IntTensor',
'List',
'LongStorage',
'LongTensor',
'Optional',
'ShortStorage',
'ShortTensor',
'Stream',
'Tuple',
'Union',
'_CudaBase',
'_CudaDeviceProperties',
'_Graph',
'_StorageBase',
'__annotations__',
'__builtins__',
'__cached__',
'__doc__',
'__file__',
'__loader__',
'__name__',
'__package__',
'__path__',
'__spec__',
'_check_capability',
'_check_cubins',
'_cudart',
'_device',
'_device_t',
'_dummy_type',
'_get_device_index',
'_initialization_lock',
'_initialized',
'_is_in_bad_fork',
'_lazy_call',
'_lazy_init',
'_lazy_new',
'_queued_calls',
'_sleep',
'_tls',
'_utils',
'amp',
'caching_allocator_alloc',
'caching_allocator_delete',
'can_device_access_peer',
'check_error',
'collections',
'contextlib',
'cudaStatus',
'cudart',
'current_blas_handle',
'current_device',
'current_stream',
'default_generators',
'default_stream',
'device',
'device_count',
'device_of',
'empty_cache',
'get_arch_list',
'get_device_capability',
'get_device_name',
'get_device_properties',
'get_gencode_flags',
'get_rng_state',
'get_rng_state_all',
'has_half',
'has_magma',
'init',
'initial_seed',
'ipc_collect',
'is_available',
'is_initialized',
'list_gpu_processes',
'manual_seed',
'manual_seed_all',
'max_memory_allocated',
'max_memory_cached',
'max_memory_reserved',
'memory',
'memory_allocated',
'memory_cached',
'memory_reserved',
'memory_snapshot',
'memory_stats',
'memory_stats_as_nested_dict',
'memory_summary',
'nccl',
'nvtx',
'os',
'profiler',
'random',
'reset_accumulated_memory_stats',
'reset_max_memory_allocated',
'reset_max_memory_cached',
'reset_peak_memory_stats',
'seed',
'seed_all',
'set_device',
'set_per_process_memory_fraction',
'set_rng_state',
'set_rng_state_all',
'sparse',
'stream',
'streams',
'synchronize',
'threading',
'torch',
'traceback',
'warnings']
View the help text for a member:
help(torch.cuda.is_available)
Help on function is_available in module torch.cuda:
is_available() -> bool
Returns a bool indicating if CUDA is currently available.
This tells us that torch.cuda.is_available() is the function for checking whether CUDA is available.
As shown above, dir() and help() are quick and convenient ways to explore a package's members and read the documentation for its functions.
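A common pattern built on top of this check, as a minimal sketch (it works whether or not a GPU is present):
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.ones(2, 3, device=device)  # tensor created directly on the chosen device
print(device, x.device)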
The dataset is the bee-and-ant image package provided by the uploader.
We subclass the Dataset class from torch.utils.data:
import os
from PIL import Image
from torch.utils.data import Dataset
class DatasetUsage(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # all image files of one class live under root_dir/label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir  # the folder name doubles as the label
        return img, label

    def __len__(self):
        return len(self.img_path)
root_dir = "dataset/hymenoptera_data/train"
ant_label = "ants"
bee_label = "bees"
ant_dataset = DatasetUsage(root_dir, ant_label)
bee_dataset = DatasetUsage(root_dir, bee_label)
dataset_sum = ant_dataset + bee_dataset
print(dataset_sum)
(1) Relative paths are used (root_dir), and os.path.join joins the path components.
(2) To read a dataset, subclass the parent class Dataset and implement the three special methods __init__, __getitem__ and __len__ to obtain the subclass.
(3) __len__ returns the number of data files under the given folder.
(4) __getitem__ must return the data sample together with its label (see the usage sketch below).
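A minimal usage sketch (assuming the ant/bee folders above are in place): indexing a dataset calls __getitem__, and len() calls __len__:
img, label = ant_dataset[0]   # __getitem__ returns a PIL image and its label string
print(label, img.size)        # "ants", plus the (width, height) of that image
print(len(ant_dataset))       # __len__: number of files under train/ants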
As for overloading the addition operator, the parent class Dataset implements __add__:
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
    return ConcatDataset([self, other])
ConcatDataset stores the datasets passed to it in a list and thereby concatenates several dataset objects into one:
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
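To make the bookkeeping concrete, a small sketch against dataset_sum from above (the printed sizes depend on the local copy of the images, so the numbers in the comments are only illustrative):
print(len(dataset_sum))                 # len(ant_dataset) + len(bee_dataset)
print(dataset_sum.cumulative_sizes)     # running totals from cumsum(), e.g. [124, 245]

# Indices below cumulative_sizes[0] fall into ant_dataset; the rest are
# shifted and looked up in bee_dataset via bisect_right.
img, label = dataset_sum[0]                  # first ant image, label "ants"
img, label = dataset_sum[len(ant_dataset)]   # first bee image, label "bees"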
First, a look at the official documentation: https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
Commonly used parameters:
dataset: the target dataset
batch_size: how many samples are fetched at a time; e.g. batch_size=4 groups 4 samples into one batch
shuffle: whether samples are fetched in random order; shuffle=True shuffles the order
num_workers: the number of worker processes used to load the data
drop_last: whether to drop the final batch when it is smaller than batch_size (see the sketch below)
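As a quick sanity check on how batch_size and drop_last interact, a tiny sketch with made-up numbers (not tied to any real dataset):
dataset_len = 10
batch_size = 4
full_batches, leftover = divmod(dataset_len, batch_size)  # 2 full batches, 2 leftover samples
# drop_last=False -> the loader yields 3 batches (the last one holds only 2 samples)
# drop_last=True  -> the loader yields 2 batches (the incomplete batch is discarded)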
Here the CIFAR10 dataset is used to demonstrate the DataLoader:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)  # download=True fetches the dataset if it is not already under ./dataset
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("DataLoader_example")
step = 0
for data in test_loader:
    imgs, targets = data
    writer.add_images("test data", imgs, step)
    step += 1
writer.close()
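The logged image grids can then be inspected by pointing TensorBoard at the same log directory and opening the address it prints (http://localhost:6006 by default):
tensorboard --logdir=DataLoader_example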
Video reference: https://www.bilibili.com/video/BV1hE411t7RN?p=1