[PyTorch][chapter 40][CIFAR-10 数据集]

前言:

        CIFAR-10和CIFAR-100是8000万个微小图像数据集的标记子集。它们由Alex Krizhevsky、Vinod Nair和Geoffrey Hinto收集

目录:

  1.      CIFAR-10数据集简介
  2.      在线下载方式
  3.      离线下载方式


一 CIFAR-10数据集简介

   

        CIFAR-10数据集由10个类别的60000张32x32彩色图像组成,每个类别有6000张图像。有50000个训练图像和10000个测试图像。

        数据集分为五个训练批次和一个测试批次,每个批次有10000张图像。测试批次包含从每个类别中随机选择的1000幅图像。训练批包含按随机顺序排列的剩余图像,但一些训练批可能包含来自一个类的图像多于来自另一类的图像。在它们之间,训练批次正好包含每个类的5000个图像。

以下是数据集中的类,以及每个类的10张随机图像:



Here are the classes in the dataset, as well as 10 random images from each:

[PyTorch][chapter 40][CIFAR-10 数据集]_第1张图片


二  在线下载方式

    

# -*- coding: utf-8 -*-
"""
Created on Wed Jun 14 15:04:59 2023

@author: cxf
"""

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader


def download():
    
    maxIter = 2
    dataset_trans = transforms.Compose([
    transforms.ToTensor(),transforms.Resize((32,32))
    ]) 
    
    cifar = datasets.CIFAR10(root='cifar',train=True,transform= dataset_trans,download =True) #一次只加载一个
                             
    
    train_data = DataLoader(cifar, batch_size=32,shuffle=True)
    
  
    # DataLoader迭代产生训练数据提供给模型
    for i in range(maxIter):
        
        for index,(img,label) in enumerate(train_data):
            pass
                             
if __name__ == "__main__":
    
    download()

二  离线下载方式

   如果PC没安装代理,直接通过在线访问的方式会非常慢,

长时间无反应,可以通过离线方式下载

 1:进入 CIFAR-10 and CIFAR-100 datasets

2: 选择python version

[PyTorch][chapter 40][CIFAR-10 数据集]_第2张图片

 3: 下载完离线包后,解压缩到本地

[PyTorch][chapter 40][CIFAR-10 数据集]_第3张图片

 4: 把datasets 里面的

       root 路径设置成img 的路径

       download  设置成False

# -*- coding: utf-8 -*-
"""
Created on Wed Jun 14 15:04:59 2023

@author: cxf
"""

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader


def download():
    
    print("\n start")
    rootDir='./data'
    maxIter = 2
    dataset_trans = transforms.Compose([
    transforms.ToTensor(),transforms.Resize((32,32))
    ]) 
    
    cifar = datasets.CIFAR10(root=rootDir,train=True,transform= dataset_trans,download =False) #一次只加载一个
                             
    
    train_data = DataLoader(cifar, batch_size=32,shuffle=True)
    
  
    # DataLoader迭代产生训练数据提供给模型
    for i in range(maxIter):
        
        for index,(img,label) in enumerate(train_data):
            print("\n index: %d"%index, "\t label",label, "\t  img",img.shape)
            pass
                             
if __name__ == "__main__":
    
    download()

你可能感兴趣的:(数学建模)