pytorch学习笔记--trochvision.datasets和DataLoader的使用

trochvision.datasets和DataLoader的使用

  • 一、datasets
  • 二、DataLoader
  • 补充:datasets类的代码

本文为学习笔记,感谢PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

一、datasets

datasets工具在trochvision中

import torchvision
from torchvision import transforms as tf
from tensorboardX import SummaryWriter

train_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=True,download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=False,download=True)

print(train_dataset[0]) #(, 6)
print(train_dataset[1])#同上,返回一张图和标签组成的元组
print(train_dataset.classes) #查看分类类型,此数据集共10类

writer = SummaryWriter('logs\\2')

#可视化十张图
for i in range(10):
	img ,label = train_dataset[i]
	writer.add_image('10train_img',img,i)
writer.close()

参数:
CIFAR10:是数据集的名字
root=’./dataset’:保存路径
transform=tf.ToTensor():对图片的转变方法
train=True:训练or测试数据
download=True:是否检测下载

二、DataLoader

from torch.utils.data import DataLoader
from torchvision import transforms as tf
import torchvision

test_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=False,download=True)

#参数batch_size是取数据集中的一个批量进行打包输出,test_iter中的每个元素都是64张图的合并
test_iter = DataLoader(dataset=test_dataset,batch_size=64,shuffle=True,num_workers=0,drop_last=True)

DataLoader中参数batch_size是取数据集中的一个批量进行打包输出,test_iter中的每个元素都是64张图的合并

参数:
dataset:读取的数据集
batch_size :批量大小
shuffle :序列的所有元素随机排序
num_worker :进程数
drop_last :是否丢弃尾部不足batch_size的数据

补充:datasets类的代码

#Dataset类的代码
from torch.utils.data import Dataset
from PIL import Image
import os
# F:\python_project\deep_learning\train\ants_image\0013035.jpg
class MyDate(Dataset):
	def __init__(self,root_dir,label_dir):
		self.root_dir = root_dir
		self.label_dir = label_dir
		self.path = os.path.join(self.root_dir,self.label_dir)
		self.img_path = os.listdir(self.path)
	def __getitem__(self, idx):
		img_name = self.img_path[idx]
		img_value_path = os.path.join(self.root_dir,self.label_dir,img_name)
		img = Image.open(img_value_path)
		label = self.label_dir
		return img,label
	def __len__(self):
		return len(self.img_path)

root_dir = 'F:\python_project\deep_learning\\train'
label_dir = 'ants_image'
ant = MyDate(root_dir,label_dir)
img,label = ant.__getitem__(1)
img.show()
print(label)

你可能感兴趣的:(深度学习,pytorch,深度学习,机器学习)