pytorch.utils.data.DataLoader的使用

# 导入必要的模块
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms as trans
from torch.utils.tensorboard import SummaryWriter as sw

# 用来演示的数据集
'''
    root:存放路径
    train:使用训练集(True)还是测试集(False)
    transform:转换器,可以用Compose设置多个转换,这里只用了一个ToTensor
    target_transform:对target进行转换 还没用过,先不设置
    download:是否下载数据集到root路径下,(我这里已经下载过了,所以选了False)
'''
dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=trans.Compose([
    trans.ToTensor()
]), download=False)

# 新建一个DataLoader实例
'''
    dataset:数据源
    shuffle:是否随机抓取数据
    batch_size:一次抓取的数据量
    drop_last:是否舍弃余数,比如数据集中一共有100个数据,一次抓取3个,抓取33次后,还剩下一个数据,True舍弃最后这个数据,False不舍弃
    num_workers:应该是与多线程有关的参数,windows里有时候抓取失败要设置为0
'''
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=64, drop_last=False,num_workers=0)

# 将dataloader中的数据写入到tensorboard
writer = sw('logs_loader')

step = 0
for data in dataloader:
    imgs,targets = data
    print(imgs.shape,targets)
    # 写入图片,这里是多个图片,用的是add_images
    writer.add_images(tag='loader_test', img_tensor=imgs, global_step=step)
    step += 1

writer.close()

你可能感兴趣的:(深度学习基础,pytorch,深度学习,人工智能)