torchvision数据集下载位置修改

torchvision数据集下载位置修改

记录一下,方便后边查阅。

1.torchvision基础介绍

torchvision是pytorch的一个图形库,它服务于深度学习Pytorch框架,主要用来构建计算机视觉模型。
下面是torchvision的构成[1]

	1.torchvision.datasets:一些加载数据的函数及常用的数据集接口;
	2.torchvision.models:包含常用的模型结构,例如AlexNet,VGG,ResNet等;
	3.torchvision.transforms:常用的一些图片变换,例如图片裁剪、选择等;
	4.torchvison.utils:其他一些有用的方法

2.torchvision常用数据集

torchvision数据集下载位置修改_第1张图片
官方的数据集用法:

torchvision数据集下载位置修改_第2张图片
torchvision下的常用数据集:
torchvision数据集下载位置修改_第3张图片

3.数据集用法介绍及root路径解释

代码示例:

def get_train_dataset():
    return dataset.FashionMNIST(
        root='./data',
        train=True,
        download=True,
        transform=getTransforms()
其中,
root:表示数据集下载保存位置
train:表示下载的数据集是不是训练集,True表示训练集,False表示测试集
download:表示数据集是否需要下载
transform:表示图片变换的一系列操作

root路径详解:

root='/':表示根目录下,如果你的代码保存在D盘,就是下载到D盘根目录下
root='./':表示当前文件夹下
root='':效果等同于'./'
root='./data':表示在当前文件夹下的data(如果没有,则会新建一个)文件夹下保存数据集
root='data':效果等同于'./data'

4.本文训练LeNet的数据处理代码

代码出自文献[2]:

FashionMNIST.py

import torch.utils.data
from torchvision import datasets as dataset
from torchvision import transforms


def getTransforms():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3018,))]
    )
    return transform


def get_train_dataset():
    return dataset.FashionMNIST(
        root='',
        train=True,
        download=True,
        transform=getTransforms()
    )


def get_test_dataset():
    return dataset.FashionMNIST(
        root='',
        train=False,
        download=True,
        transform=getTransforms()
    )


def get_train_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset=get_train_dataset(),
        batch_size=batch_size,
        shuffle=shuffle
    )


def get_test_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset=get_test_dataset(),
        batch_size=batch_size,
        shuffle=shuffle
    )

MNIST.py

import torch.utils.data
from torchvision import datasets as dataset
from torchvision import transforms


def getTransforms():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3018,))]
    )
    return transform


def get_train_dataset():
    return dataset.MNIST(
        root='LeNetTest',
        train=True,
        download=True,
        transform=getTransforms()
    )


def get_test_dataset():
    return dataset.MNIST(
        root='LeNetTest',
        train=False,
        download=True,
        transform=getTransforms()
    )


def get_train_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset=get_train_dataset(),
        batch_size=batch_size,
        shuffle=shuffle
    )


def get_test_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset=get_test_dataset(),
        batch_size=batch_size,
        shuffle=shuffle
    )

第一次写,可能写得不是很好,希望大家多多包涵!

文献

[1]:https://wenku.baidu.com/view/21bbc06bf4ec4afe04a1b0717fd5360cba1a8df6.html
[2]:https://blog.csdn.net/weixin_38878828/article/details/125614377

你可能感兴趣的:(Pytorch,torchvision数据处理,深度学习,人工智能,pytorch,python)