`本文的全套代码下载点击此处:https://download.csdn.net/download/weixin_54546190/85539252人名币二分类完整Python代码(包含数据集)
在介绍数据读取机制(DataLoader与Dataset)之前,我们先学习一下人名币二分类实验,有助于后面的理解。
机器学习模型训练步骤 |
数据处理的过程 |
Sampler的作用:生成索引(index)也就是样本的序号。
DataSet的作用:根据索引读取相应的图片和标签。
Dataloader构建了一个迭代的数据装载器,训练时每次for循环,每次iteration都会从Dataloader中获取一个batch_size大小的数据进行操作。
Epoch / Iteration / Batchsize之间的关系 |
Epoch:
所有训练样本都已输入到模型中,称为一个EpochIteration:
一批样本输入到模型中,称之为一个IterationBatchsize:
批大小,决定一个Epoch有多少个Iteration- 例:样本总数:80, Batchsize:8 时
1 Epoch = 10 Iteration
- 例:样本总数:87, Batchsize:8 时
当drop_last = True时:1 Epoch = 10
当drop_last = False时:1 Epoch = 11 Iteration,但是最后一个Iteration只有7个数据。
torch.utils.data.Dataset是用来定义数据从哪读取和如何读取的问题。
Dataset是一个抽象类
:
在人民币二分类之前,我们先了解三个问题,带着这三个问题去学习人名币二分类实验代码。如下图所示:
以训练集0.8,验证集0.1,测试集0.1的比例对总人名币数据集进行划分。通过下述1_split_dataset.py文件运行可得子文件-- -- rmb_split,具体划分之后的文件分布我详细卸载代码下方。
# -*- coding: utf-8 -*-
"""
# @file name : 1_split_dataset.py
# @date : 2022-06-1 10:08:00
# @brief : 将数据集划分为训练集,验证集,测试集
"""
import os
import random
import shutil
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(new_dir)
if __name__ == '__main__':
dataset_dir = os.path.abspath(os.path.join(BASE_DIR, "RMB_data"))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
test_dir = os.path.join(split_dir, "test")
if not os.path.exists(dataset_dir):
raise Exception("\n{} 不存在,请下载 02-01-数据-RMB_data.rar 放到\n{} 下,并解压即可".format(
dataset_dir, os.path.dirname(dataset_dir)))
train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1
for root, dirs, files in os.walk(dataset_dir):
for sub_dir in dirs:
imgs = os.listdir(os.path.join(root, sub_dir))
imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
random.shuffle(imgs)
img_count = len(imgs)
train_point = int(img_count * train_pct)
valid_point = int(img_count * (train_pct + valid_pct))
if img_count == 0:
print("{}目录下,无图片,请检查".format(os.path.join(root, sub_dir)))
import sys
sys.exit(0)
for i in range(img_count):
if i < train_point:
out_dir = os.path.join(train_dir, sub_dir)
elif i < valid_point:
out_dir = os.path.join(valid_dir, sub_dir)
else:
out_dir = os.path.join(test_dir, sub_dir)
makedir(out_dir)
target_path = os.path.join(out_dir, imgs[i])
src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
shutil.copy(src_path, target_path)
print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
img_count-valid_point))
print("已在 {} 创建划分好的数据\n".format(out_dir))
lesson-06
– – RMB_data
– – 1
– – × × × \times \times \times ×××.jpg
– – 100
– – × × × \times \times \times ×××.jpg
– – rmb_split
– – test
– – 1
– – 100
– – train
– – 1
– – 100
– – valid
– – 1
– – 100
– – test_data
– – 100
– – 100.jpg
– --tools
– – common_tools.py
– – dcgan.py
– – my_dataset.py
– – unet.py
– --model
– – lenet.py
– – 1_split_dataset.py
– – 2_train_lenet.py
我们首先看一下完整代码和代码训练效果以及结果。
"""
# @file name : train_lenet.py
# @author : 源仔
# @date : 2022-06-1 10:08:00
# @brief : 人民币分类模型训练
"""
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
path_lenet = os.path.abspath(os.path.join(BASE_DIR,"model", "lenet.py"))
path_tools = os.path.abspath(os.path.join(BASE_DIR,"tools", "common_tools.py"))
assert os.path.exists(path_lenet), "{}不存在,请将lenet.py文件放到 {}".format(path_lenet, os.path.dirname(path_lenet))
assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools))
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
set_seed() # 设置随机种子
rmb_label = {"1": 0, "100": 1}
# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
# ============================ step 1/5 数据 ============================
# ============================ 读取数据在硬盘中的地址 ============================
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
if not os.path.exists(split_dir):
raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
# ============================ 这里的均值和方差 ============================
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
# =====================对训练数据和验证数据进行预处理=========================
train_transform = transforms.Compose([
transforms.Resize((32, 32)), # 图片的大小缩放到(w,h)=(32,32)
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.ToTensor(), # 把图片格式转化为tensor形式
transforms.Normalize(norm_mean, norm_std), # 归一化
])
valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
# ============================ step 2/5 模型 ============================
net = LeNet(classes=2)
net.initialize_weights()
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss() # 选择损失函数
# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略
# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()
for epoch in range(MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
net.train()
for i, data in enumerate(train_loader):
# forward
inputs, labels = data
outputs = net(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
scheduler.step() # 更新学习率
# validate the model
if (epoch+1) % val_interval == 0:
correct_val = 0.
total_val = 0.
loss_val = 0.
net.eval()
with torch.no_grad():
for j, data in enumerate(valid_loader):
inputs, labels = data
outputs = net(inputs)
loss = criterion(outputs, labels)
_, predicted = torch.max(outputs.data, 1)
total_val += labels.size(0)
correct_val += (predicted == labels).squeeze().sum().numpy()
loss_val += loss.item()
loss_val_epoch = loss_val / len(valid_loader)
valid_curve.append(loss_val_epoch)
print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval - 1 # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve
plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
# ============================ inference ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")
test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)
for i, data in enumerate(valid_loader):
# forward
inputs, labels = data
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
rmb = 1 if predicted.numpy()[0] == 0 else 100
print("模型获得{}元".format(rmb))
OUT:
Training:Epoch[000/010] Iteration[010/010] Loss: 0.6326 Acc:61.25%
Valid: Epoch[000/010] Iteration[002/002] Loss: 0.4373 Acc:85.00%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.3222 Acc:88.12%
Valid: Epoch[001/010] Iteration[002/002] Loss: 0.0250 Acc:100.00%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.0559 Acc:98.12%
Valid: Epoch[002/010] Iteration[002/002] Loss: 0.0003 Acc:100.00%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.0287 Acc:99.38%
Valid: Epoch[003/010] Iteration[002/002] Loss: 0.0001 Acc:100.00%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.3408 Acc:92.50%
Valid: Epoch[004/010] Iteration[002/002] Loss: 0.0186 Acc:100.00%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.0386 Acc:98.75%
Valid: Epoch[005/010] Iteration[002/002] Loss: 0.0165 Acc:100.00%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.0145 Acc:100.00%
Valid: Epoch[006/010] Iteration[002/002] Loss: 0.0004 Acc:100.00%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0136 Acc:99.38%
Valid: Epoch[007/010] Iteration[002/002] Loss: 0.0002 Acc:100.00%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0072 Acc:100.00%
Valid: Epoch[008/010] Iteration[002/002] Loss: 0.0005 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0039 Acc:100.00%
Valid: Epoch[009/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
模型获得100元
关于模型中输入数据的预处理(如裁剪,旋转,颜色扰动等)、step 2/5 模型 、step 3/5 损失函数、step 4/5 模型、step 5/5 训练四个内容,在后面的博客会陆续更新,一口吃不成胖子,我们要按部就班的去学习,本张只讲step 1/5 训练数据的导入.
# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
- data_dir=train_dir:str, 数据集所在路径。
- train_transform:对训练数据进行预处理,代码如下:
train_transform: 数据预处理代码如下
train_transform = transforms.Compose([
transforms.Resize((32, 32)), # 图片的大小缩放到(w,h)=(32,32)
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.ToTensor(), # 把图片格式转化为tensor形式
transforms.Normalize(norm_mean, norm_std), # 将数据转换为正太分布,使模型更容易收敛。
])
RMBDataset类 |
我们从train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
开始, 这一句话里面的核心就是RMBDataset
,这个是我们自己写的一个类,继承了上面的抽象类Dataset
,并且重写了__getitem__()
方法, 这个类的目的就是传入数据的路径,和预处理部分(看参数),然后给我们返回数据,下面看它是怎么实现的(Pycharm里面按住Ctrl+B键,或按住Ctrl,然后点击这个RMBDataset位置就进入当前的RMBDataset类中):
class RMBDataset(Dataset):
def __init__(self, data_dir, transform=None):
"""
RMB:面额分类任务的Dataset
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"1": 0, "100": 1}
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform
def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.data_info)
这一部分解释来自Python中__init__的用法和理解,在此感谢这位博主。
在Python中定义类经常会用到__init__函数(方法),首先需要理解的是,两个下划线开头的函数是声明该属性为私有,不能在类的外部被使用或访问。而__init__函数(方法)支持带参数类的初始化,也可为声明该类的属性(类中的变量)。__init__函数(方法)的第一个参数必须为self,后续参数为自己定义。
从文字理解比较困难,通过下面的例子能非常容易理解这个概念:
例如我们定义一个Box类,有width, height, depth三个属性,以及计算体积的方法:
# -*- coding utf-8 -*-
#Created by Lu Zhan
class Box:
def setDimension(self, width, height, depth):
self.width = width
self.height = height
self.depth = depth
def getVolume(self):
return self.width * self.height * self.depth
b = Box()
b.setDimension(10, 20, 30)
print(b.getVolume())
我们在Box类中定义了setDimension方法去设定该Box的属性,这样过于繁琐,而用__init__()这个特殊的方法就可以方便地自己对类的属性进行定义,init()方法又被称为构造器(constructor)。
#!/usr/bin/python
# -*- coding utf-8 -*-
#Created by Lu Zhan
class Box:
#def setDimension(self, width, height, depth):
# self.width = width
# self.height = height
# self.depth = depth
def __init__(self, width, height, depth):
self.width = width
self.height = height
self.depth = depth
def getVolume(self):
return self.width * self.height * self.depth
b = Box(10, 20, 30)
print(b.getVolume())
上述代码重点在
__getitem__
这里,那我们先要大概搞懂__getitem__
是起什么作用的。w我们下先看一段代码去了解__getitem__
。
class Animal:
def __init__(self, animal_list, age):
self.animals_name = animal_list
self.animals_age = age
def __getitem__(self, index):
return self.animals_name[index]
animals = Animal(["dog", "cat", "fish"], [1, 2, 3])
for i in animals:
print(i)
OUT:
dog
cat
fish
- 那这个
def __getitem__(self, index):
是在哪里执行的呢 ! 那我们继续Debug
试试。
可以从上面动态图注意到,每当Debug
到for
循环时,下一步会进入def __getitem__(self, index):
获取对应的index
,那我怎么知道是去获取对应的index
呢,那是因为你们仔细看上面的动态图过程。我从中截取了3
张图片,如下图。
获取到index
,之后def __getitem__(self, index):
会执行return
函数,返回return self.animals_name[0](等于dog)、self.animals_name[1](等于cat)、self.animals_name[2](等于fish),那么返回出来的就是我们for
循环遍历中的i
,到这里我相信大家因该都懂__getitem__
的用法了吧。
下面我们在返回去看class RMBDataset(Dataset):这段代码,为了方便大家查看,在此在下方复制了一遍。
class RMBDataset(Dataset):
def __init__(self, data_dir, transform=None):
"""
RMB:面额分类任务的Dataset
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"1": 0, "100": 1}
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform
def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.data_info)
1、由上面的讲解,我们会先执行__init__初始化参数。
self.label_name
:数据集的标签(label
),也就是数据集的类别。data_info
:存储所有图片路径和标签,在DataLoader中通过index读取样本。self.transform
:就是数据预处理。
问了能大家能更好了理解,我debug
了两个循环的参数给大家看看(上述 RMBDataset
类中所有变量参数都在下图中展示出来了),如下图所示:
最后我们看一下经过RMBDataset类后的train_data和valid_data包含什么?如下图所示:
# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
DataLoader
这个类,接收的参数就是上面的RMBDataset
,我们知道这个是返回一个样本的张量和标签,然后又跟了一个BATCH_SIZE
, 看到这个,你心里应该有数了,这个不就是说一个batch
里面有多少个样本吗? 如果有了一个batch的样本数量,有了样本总数,就能得到总共有多少个batch
了。 后面的shuffle
,这个是说我取图片的时候,把顺序打乱一下,不是重点。 那么你是不是又好奇点东西了, 这个DataLoader
在干啥事情呢? 其实它在干这样的事情,我们只要指定了Batch_SIZE
, 比如指定10
个, 我们总共是有100
个训练样本,那么就可以计算出批数是10
, 那么DataLoader
就把样本分成10
批顺序打乱的数据,每一个Batch_size
里面有10
个样本且都是张量
和标签
的形式。
由于DataLoader源码太长,我们大概知道DataLoader是起什么作用的就好。下面我们直接从训练的部分看,像中间的模型,损失函数,优化器不是重点,所以这里先不放上来:
for epoch in range(MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
net.train()
for i, data in enumerate(train_loader):
# forward
inputs, labels = data
outputs = net(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
上面就是训练部分的核心了,这个比较好理解, 两层循环,外循环表示的迭代
Epoch
,也就是全部的训练样本喂入模型一次, 内循环表示的批次的循环,每一个Epoch
中,都是一批批的喂入, 那么数据读取具体使用的重点就是for i, data in enumerate(train_loader)
这句话了, 所以我们Debug
看看这个函数究竟是怎么去得到数据的?
Step-into
进入后这样就会看到,程序跳转到了DataLoader
的__iter__(self)
这个方法,毕竟这是个迭代的过程, 但是简单的瞄一眼这个函数,就会发现就一个判断,说的啥呢? 原来在说是用单进程还是用多进程读取机制进行处理, 关于读取数据啥也没干。 所以这个也不是重点, 我们使用stepover
进行下一步,然后在stepinto
进入单进程(_SingleProcessDataLoaderIter
)的这个机制里面。
在stepinto
进入单进程(_SingleProcessDataLoaderIter
)的这个机制后,比较重要的一个方法就是__next__(self)
, 上面不是说RMBDataset
函数是能返回一个样本和标签吗? 这里的这个next, (如下图所示)看其字面意义就知道这个是获取下一个样本和标签,重要的两行代码就是红色线条(index\data)
这两行,self.__next__index()
获取下一个样本的index
, 然后self.dataset_fetcher.fetch(index)
根据index
去获取下一个样本, 那么是怎么做到的? 继续调试:将光标放到__next__index()
这一行,然后点击下面的run to cursor
图表,就会跳到这一行,然后stepinto
。
操作如下图所示:
step-into
进入,这里是返回了一个return next(self.sampler_iter)
, 所以重点应该是这个东西,我们继续stepinto
step-into
进入后,这里发现进入了sampler.py
, 这里面重要的就是这个__iter__(self)
, 这个方法正是一次次的去采样我们的数据的索引,然后够了一个batch_size
了就返回了。 那这一次取到的哪些样本的索引呢? 我们可以跳出这个函数,回去看看(连续两次跳出函数,回到dataloader.py
):
然后
stepover
到data
这一行, 这个意思就是说,index
这一样代码执行完毕,我们可以看到最下面取到的index
(可以和上上张图片,没执行这个函数的时候对比一下),我们的batch_size
设置的16
, 所以通过上面的sampler.py
获得了16
个样本的索引。
这样,我们就有了一个批次的
index
, 那么就好说了,根据index
取不就完事了, 所以第二行代码data = self.dataset_fetcher.fetch(index)
就是取数据去了,重点就是这里的dataset_fetcher.fetch
方法, 我们继续debug
看看它是怎么取数据的。
这样进入了
fetch.py
, 然后核心是这里的fetch
方法,这里面会发现调用了self.dataset[idx]
去获取数据, 那么我们再步入一步,就看到了奇迹:
在
Run to cursor
到return self.collate_fn(data)
这是已经取完了一个批次, 然后进入self.collate_fn(data)
进行整合,就得到了我们一个批次的data,最终我们返回来。
由下面两张图我们可以知道,
data
中含有:
好了, 上面就是
DataLoader
读取数据的过程了,可能代码调试的过程确实比较乱,或许看不大懂,所以我们基于那三个问题梳理一遍逻辑,把逻辑关系看懂就好了, 并且最后用灵魂画笔来个流程图再进行梳理。 还记得我们的三个问题吗?
- 读哪些数据? 这个我们是根据
Sampler
出的index
决定的- 从哪读数据? 这个是
Dataset
的data_dir
设置数据的路径,然后去读- 怎么读数据? 这个是
Dataset
的getitem
方法,可以帮助我们获取一个样本
我们知道,DataLoader读取数据的过程比较麻烦,用到了四五个.py文件的跳转,所以梳理这个逻辑关系最好的方式就是流程图:
通过这个流程图,把DataLoader
读取数据的流程梳理了一遍,具体细节不懂没有关系,但是这个逻辑关系应该要把握住,这样才能把握宏观过程,也能够清晰的看出DataLoader
和Dataset
的关系。 根据前面介绍,DataLoader
的作用就是构建一个数据装载器, 根据我们提供的batch_size
的大小, 将数据样本分成一个个的batch
去训练模型,而这个分的过程中需要把数据取到,这个就是借助Dataset
的getitem
方法。
这样也就清楚了,如果我们想使用Pytorch读取数据的话,首先应该自己写一个
MyDatase
,这个要继承Dataset
类并且实现里面的__getitem__
方法,在这里面告诉机器怎么去读数据。 当然这里还有个细节,就是还要覆盖里面的__len__
方法,这个是告诉机器一共有多少个样本数据。 要不然机器没法去根据batch_size
的个数去确定有多少批数据。这个写起来也很简单,返回总的样本的个数即可。
def __len__(self):
return len(self.data_info)
这样, 机器就可以根据Dataset去硬盘中读取数据,接下来就是用DataLoader构建一个可迭代的数据装载器,传入如何读取数据的机制Dataset,传入batch_size, 就可以返回一批批的数据了。 当然这个装载器具体使用是在模型训练的时候。 当然,由于DataLoader是一个可迭代对象,当我们构建完毕之后,也可以简单的看下里面的数据到底长什么样, 大致代码是:
# 查看一个batch_size的数据
for x, y in train_loader:
print(x, y)
break
好了,上面就是
Pytorch
读取机制DataLoader
和Dataset
的原理部分了。
人民币二分类的数据模块里面,除了数据读取机制
DataLoader
,还涉及了一个图像的预处理模块transforms
, 是对图像进行预处理的,在下一篇博客中,我会详解讲解Pytorch
中对数据预处理的常用方法,再搞定这个细节,人民币二分类任务的数据模块就全部结束了哈。
本文的全套代码下载点击此处
:https://download.csdn.net/download/weixin_54546190/85539252人名币二分类完整Python代码(包含数据集
感谢大家观看,如果觉得博主写的不错的话,别忘了点个赞啊。
1、数据加载Dataset和DataLoader的使用
2、Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)