Pytorch中Dataset和DataLoader

1.Dataset

  Dataset定义了数据集的内容,具有确定的长度,能够用索引获取数据集中的元素。只需实现Dataset的__len__方法和__getitem__方法,就可以轻松构建自己的数据集。复杂的数据集,还要设计 DataLoader中的 collate_fn

#fasterRCNN 

def collate_fn(batch):
  return tuple(zip(*batch))




#定义dataset
class dataset_db(Dataset):
    def __init__(self, dataPath, fileList):
        self.dataPath=dataPath
#         fileList=os.listdir(self.dataPath)
        random.shuffle(fileList)
        self.fileList = fileList
        self.classnames = ['_齿轮箱3V_128k', '_齿轮箱4H_128k', '_齿轮箱5V_128k', 
                            '_齿轮箱6A_128k', '_齿轮箱7V_128k', '_发电机8V_128k',
                            '_发电机9V_128k', '_主轴1H_128k', '_主轴2H_128k']

    def __len__(self):
        return len(self.fileList)

    def __getitem__(self, idx):
        npypath = self.fileList[idx] 
        classnameindex &

你可能感兴趣的:(pytorch,python,深度学习)