PyTorch 数据加载性能对比

传统方式（DataLoader 直接读取原始数据集）耗时约 10 s，而读取预先保存的 .dat 文件仅需约 0.6 s。


import os

import time
import torch
import random
from common.coco_dataset import COCODataset

def gen_data(batch_size, data_path, target_path, save=False):
    """Iterate the dataset through a DataLoader and time every batch.

    This is the "traditional" path of the benchmark: each batch is produced on
    the fly by ``COCODataset`` (constructed with ``(352, 352)`` -- presumably
    the target image size -- ``is_training=False`` and ``is_scene=True``).

    Args:
        batch_size: samples per DataLoader batch. The last incomplete batch is
            dropped (``drop_last=True``).
        data_path: root directory passed to ``COCODataset``.
        target_path: directory for cached ``<step>.dat`` files. It is created
            if missing, even when ``save`` is False.
        save: when True, write each collated batch dict to
            ``os.path.join(target_path, f"{step}.dat")`` with ``torch.save``.
            Defaults to False, which only times the loader.
    """
    os.makedirs(target_path, exist_ok=True)
    dataset = COCODataset(data_path, (352, 352), is_training=False, is_scene=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
    )
    start = time.time()
    for step, samples in enumerate(dataloader):
        # Elapsed time since the previous batch was handed out, i.e. the cost
        # of loading + collating this batch.
        print("time", samples["image"].size(0), time.time() - start)
        if save:
            torch.save(samples, os.path.join(target_path, f"{step}.dat"))
        # Reset after saving so disk writes are not counted as loading time.
        start = time.time()
        print(step)


def cat_100(target_path, batch_size=100, device="cuda"):
    """Benchmark reloading cached ``.dat`` files and concatenating their images.

    Loads ``0.dat`` .. ``<n-1>.dat`` from ``target_path`` in random order,
    where ``n`` is the number of ``.dat`` files in the directory. Each file's
    image and label are moved to ``device`` immediately after loading. After
    every ``batch_size`` files, the collected images are concatenated along
    dim 0, and the concatenated size and elapsed time are printed. Files left
    over after the last full group are loaded, but they are never
    concatenated or timed.

    Args:
        target_path: directory holding files named ``<index>.dat``. Each file
            is a ``torch.save``-d dict with ``"image"``, ``"label"`` and
            ``"img_path"`` keys. A trailing path separator is optional.
        batch_size: number of files concatenated per timed group (>= 1).
        device: target of ``Tensor.to``. The default ``"cuda"`` matches the
            former ``.cuda()`` calls; pass ``"cpu"`` to run without a GPU.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Count only cache files; any other entry would make us look for an index
    # that was never written.
    n_files = sum(1 for name in os.listdir(target_path) if name.endswith(".dat"))
    order = list(range(n_files))
    random.shuffle(order)

    images, labels = [], []
    start = time.time()
    for i, idx in enumerate(order):
        # NOTE: torch.load unpickles data; only use it on trusted cache files.
        samples = torch.load(os.path.join(target_path, f"{idx}.dat"))
        images.append(samples["image"].to(device))
        # Labels are transferred too so the timing covers the full copy cost.
        labels.append(samples["label"].to(device))
        if i % batch_size == batch_size - 1:
            batch = torch.cat(images, 0)
            print("time", batch.size(0), time.time() - start)
            images, labels = [], []
            start = time.time()

if __name__ == '__main__':
    # Only takes effect if set before CUDA is first initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = '3'
    batch_size = 320
    # target_path = 'd:/test_1000/'
    # Raw string: in a normal literal '\i' is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+); the resulting path is unchanged.
    target_path = r'd:\img_2/'
    data_path = r'D:\dataset\origin_all_datas\_2train'
    # Time the DataLoader over the raw dataset.
    gen_data(batch_size, data_path, target_path)
    # Time reloading cached .dat files from target_path instead.
    # cat_100(target_path, batch_size)

下面这个版本读取数据也比较快：batch_size 为 320 时耗时约 450 ms。

 

def cat_100(target_path, batch_size=100, device="cuda"):
    """Benchmark reloading cached ``.dat`` files, deferring the device copy.

    Loads ``0.dat`` .. ``<n-1>.dat`` from ``target_path`` in random order,
    where ``n`` is the number of ``.dat`` files in the directory. Images stay
    where ``torch.load`` put them until ``batch_size`` files have been read.
    Only then is each image moved to ``device`` and all of them concatenated
    along dim 0, and the concatenated size and elapsed time are printed.
    Labels are collected but never moved. Files left over after the last full
    group are loaded, but they are never concatenated or timed.

    NOTE(review): this redefines ``cat_100``; if it shares a module with
    another ``cat_100``, whichever definition comes last wins.

    Args:
        target_path: directory holding files named ``<index>.dat``. Each file
            is a ``torch.save``-d dict with ``"image"``, ``"label"`` and
            ``"img_path"`` keys. A trailing path separator is optional.
        batch_size: number of files concatenated per timed group (>= 1).
        device: target of ``Tensor.to``. The default ``"cuda"`` matches the
            former ``.cuda()`` call; pass ``"cpu"`` to run without a GPU.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Count only cache files; any other entry would make us look for an index
    # that was never written.
    n_files = sum(1 for name in os.listdir(target_path) if name.endswith(".dat"))
    order = list(range(n_files))
    random.shuffle(order)

    images, labels = [], []
    start = time.time()
    # 1-based counter: a group is complete when count is a multiple of batch_size.
    for count, idx in enumerate(order, start=1):
        # NOTE: torch.load unpickles data; only use it on trusted cache files.
        samples = torch.load(os.path.join(target_path, f"{idx}.dat"))
        images.append(samples["image"])
        labels.append(samples["label"])
        if count % batch_size:
            continue  # group not complete yet
        batch = torch.cat([image.to(device) for image in images], 0)
        print("time", batch.size(0), time.time() - start)
        images, labels = [], []
        start = time.time()

你可能感兴趣的:(torch)