不多说,直接上源码
我做的是语种分类的项目,所以直接上了,里面有些介绍。
还是先简要介绍,继承torch.data.dataset,然后重写init、len和getitem方法。
代码如下:
import os
import torch
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import numpy.fft as fft
import cv2
import torchvision.transforms as transforms
from torch.utils import data
#librosa 简介,音频处理库
# 音频读取函数load()
# 重采样函数resample()
# 短时傅里叶变换stft()
# 幅度转换函数amplitude_to_db()
# 频率转换函数hz_to_mel()
# 频谱显示函数specshow()
# 波形显示函数waveplot()
class MyDataset(data.Dataset):
def __init__(self, Path, second=1, transform=None, target_transform=None): #初始化一些需要传入的参数
super(MyDataset,self).__init__()
self.Path=Path
self.classes = {"en":"0", "ru":"1", "yue":"2", "zh":"3"}
melimgs = []
for root1, dirs, files in sorted(os.walk(Path)):
if root1!=Path:
for files in sorted(os.listdir(root1)):
if files.split(".")[-1] in ["wav","mp3"]:
melimgs.append((os.path.join(root1,files),os.path.split(root1)[1]))#,self.classes[os.path.split(root1)[1]]))
self.second=second
self.melimgs=melimgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
wavpath, label = self.melimgs[index]
clip,sample_rate=librosa.load(wavpath,sr=None)
melspec = librosa.feature.melspectrogram(clip,sample_rate,n_fft=1024,hop_length=512,n_mels=128)
logmelspec= librosa.power_to_db(melspec)
if self.second !=1 :
if logmelspec.shape[1]>91*self.second:
logmelspec = logmelspec[np.newaxis,:,:91*self.second]
if self.transform is not None:
logmelspec = self.transform(logmelspec) #是否进行transform
return logmelspec,label #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
else:
logmelspec = logmelspec[np.newaxis,:,:91*self.second]
if self.transform is not None:
logmelspec = self.transform(logmelspec) #是否进行transform
return logmelspec,label #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
return len(self.melimgs)
if __name__ == '__main__':
train_data=MyDataset(Path="data",second=1, transform=transforms.ToTensor())
#然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关
train_loader = data.DataLoader(dataset=train_data, batch_size=5, shuffle=True)
for batch_idx, (data, target) in enumerate(train_loader):
print(batch_idx, (data, target))
就是这样简单,这只是加载数据集的一种方法,其实只要你能讲数据送到网络里就行不管你采用什么方法。得到结果才是王道。
torch的这个文件包含了一些关于数据集处理的类:
class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。
class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。
class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。
class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=
torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。
class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 iter 方-法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。
class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。
class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。
class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。