PyTorch - 数据集介绍(mnist、CIFAR10、CIFAR100)

参考自官网:torchvision.datasets

总介绍

torchvision.datasets中包含了以下数据集

  • MNIST
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10

详细介绍(以mnist手写数字集为例)

  • 数据集介绍
    60000个训练数据,10000个测试数据,每张图片大小28*28。
    单通道的黑白色图片,即(batch_size, channels, Height, Width) =(batch_size, 1, 28, 28)
  • 参数列表
    MNIST(root, train=True, transform=None, target_transform=None, download=False)
    参数说明:
    • root : processed/training.pt 和 processed/test.pt 的主目录
    • train : True = 训练集, False = 测试集
    • target_transform:一个函数,原始图片作为输入,返回一个转换后的图片
    • download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。
  • 除此以外还需要对target_transform进一步了解:
    一个函数,输入为target,输出对其的转换。
    torchvision.transforms.Compose(transforms)
    例如:
torchvision.transforms.Compose([
    torchvision.transforms.Resize(224), # 缩放图片,保持长宽比不变,最短边为224像素
    torchvision.transforms.CenterCrop(10),# 将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,在这种情况下,切出来的图片的形状是正方形。
    torchvision.transforms.ToTensor(), # 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloadTensor
    torchvision.transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化(正则化)至[-1, 1]
 ])
  • 代码使用
import torchvision

# 获取数据集
train_data = torchvision.datasets.MNIST(root='mnist', 
                                        train=True, 
                                        transform = torchvision.transforms.ToTensor(), 
                                        download=True)
test_data = torchvision.datasets.MNIST(root='mnist', 
                                       train=False, 
                                       transform = torchvision.transforms.ToTensor(), 
                                       download=True)

# 属性测试
num_sample = train_data.__len__()
print(num_sample) # 60000
item = train_data.__getitem__(0)

或者

import torchvision


# 数据集的预处理
data_tf = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5],[0.5])
    ])

data_path = r'./mnist'
# 获取数据集
train_data = torchvision.datasets.MNIST(data_path, train=True, transform = data_tf, download=True)
test_data = torchvision.datasets.MNIST(data_path, train=False, transform = data_tf, download=True)

其他数据集

  • CIFAR10

    • API
      CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
    • 介绍:
      该数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
    • 图示 PyTorch - 数据集介绍(mnist、CIFAR10、CIFAR100)_第1张图片
  • CIFAR100

    • API
      CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
    • 介绍:
      这个数据集就像CIFAR-10,除了它有100个类,每个类包含600个图像。,每类各有500个训练图像和100个测试图像。CIFAR-100中的100个类被分成20个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类)
      以下是CIFAR-100中的类别列表:
超类 类别
水生哺乳动物 海狸,海豚,水獭,海豹,鲸鱼
水族馆的鱼,比目鱼,射线,鲨鱼,鳟鱼
花卉 兰花,罂粟花,玫瑰,向日葵,郁金香
食品容器 瓶子,碗,罐子,杯子,盘子
水果和蔬菜 苹果,蘑菇,橘子,梨,甜椒
家用电器 时钟,电脑键盘,台灯,电话机,电视机
家用家具 床,椅子,沙发,桌子,衣柜
昆虫 蜜蜂,甲虫,蝴蝶,毛虫,蟑螂
大型食肉动物 熊,豹,狮子,老虎,狼
大型人造户外用品 桥,城堡,房子,路,摩天大楼
大自然的户外场景 云,森林,山,平原,海
大杂食动物和食草动物 骆驼,牛,黑猩猩,大象,袋鼠
中型哺乳动物 狐狸,豪猪,负鼠,浣熊,臭鼬
非昆虫无脊椎动物 螃蟹,龙虾,蜗牛,蜘蛛,蠕虫
宝贝,男孩,女孩,男人,女人
爬行动物 鳄鱼,恐龙,蜥蜴,蛇,乌龟
小型哺乳动物 仓鼠,老鼠,兔子,母老虎,松鼠
树木 枫树,橡树,棕榈,松树,柳树
车辆1 自行车,公共汽车,摩托车,皮卡车,火车
车辆2 割草机,火箭,有轨电车,坦克,拖拉机

你可能感兴趣的:(PyTorch)