pytorch学习笔记——Dataset与Dataloader以及迭代器与迭代对象

流程:

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# 构造dataset
train = Mydataset(xy_train)
test = Mydataset(xy_test)
# 训练时,一般打乱数据;但测试时不打乱;batch_size自己设置
# 构造dataloader
train_loader = DataLoader(dataset=train,batch_size=32,shuffle=True)
test_loader = DataLoader(dataset=train,batch_size=32,shuffle=False)

# 自写dataset的类,继承torch.utils.data.Dataset
class Mydataset(Dataset):
	# 构造函数
	def __init__(self,xy_data):
		# axis=1,表示是按列名删除(列名就是刚刚的第一行)
		# 因为原数据集中有字符串等,为了等下转tensor方便,直接把这些数据删掉
		# 我们本节的目的是为了学习DataLoader的使用而不是利用数据进行后续分析
		self.x_data = xy_data.drop(['PassengerId','Survived','Name', 'Sex', 'Ticket', 'Cabin', 'Embarked'], axis=1)
		self.y_data = xy_data[['Survived']]
		# 这里的x_data、y_data要处理成tensor格式
		# 能转成tensor格式的数据有int、float、bool
		# 这部分转tensor好像有点累赘,欢迎补充交流
		self.x_train = np.array(self.x_train)
        self.y_train = np.array(self.y_train)
		# self.x_train = self.x_train.astype(float)
        #self.y_train = self.y_train.astype(float)
        self.x_train = torch.Tensor(self.x_train.astype(float))
        self.y_train = torch.Tensor(self.y_train.astype(float))

		# shape[0]是从纵向角度看,代表行数
		self.len = self.xy_data.shape[0]
	
	# 按索引取出对应元素
	def __getitem__(self,index):
		return self.x_data[index],self.y_data[index]
	
	# 
	def __len__(self):
		return self.len

# 从Dataloader里读出数据和标签
for epoch in range(28):
    print("epoch:",epoch)
    for i,data in enumerate(train_loader):
        inputs, labels = data# inputs=images
        

dataloader通过for…in…读取,当用emunate时返回index和batch
在这里插入图片描述
batch是长度为2的列表,第一个元素是输入即图片,第二个元素是标签。(这里的batch返回是什么取决于dataset的__getitem()__,所以有再训练和验证时返回图片和标签而在测试时的返回是图片和其他(可以是图片名字等)。
输入是形状为(4,3,640,640)的张量,这里4为batch_size,640为图片尺寸,3为通道数;
标签是
在这里插入图片描述
输出是长度为4的列表,每个元素是batch里每张图对应的标签。

在这里插入图片描述
输入经过网络的输出为长度为3的列表,每个元素为每一个特征图的输出,这里以yolox为例,输出是20,40,80的特征图
在这里插入图片描述
Pytorch(三):Dataset和Dataloader的理解
pytorch学习笔记(一):Dataset和DataLoader
迭代器与迭代对象:
含有__iter__()为可迭代对象,含有__iter__()和__next__()为迭代器,iter()返回迭代器,只有迭代器才能调用内置函数__next__(),所有可迭代对象都可以用for…in…访问。

__iter__方法的作用是让对象可以用for … in循环遍历,getitem( )方法是让对象可以通过“实例名[index]”的方式访问实例中的元素。
python可迭代对象
python中可迭代对象是什么

你可能感兴趣的:(pytorch学习笔记,git,深度学习,pytorch)