pytorch加载不同size的文件(.npy, .wav, .jpg)进行padding

本文介绍如何用 PyTorch 加载不同 size 的文件,然后对同一个 batch 内的数据进行 zero-padding,使它们补齐到相同长度。

下面以加载不同 size 的 "XXX.npy" 文件为例。

第一步:重写 Dataset,代码如下:

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class train_dataset(Dataset):
    """Dataset over ``.npy`` files whose first dimension may differ.

    Labels are parsed from each file name: the text after ``"label"`` (with
    the 4-character extension removed) is split on single spaces, the first
    token is skipped and the remaining tokens are converted to floats.
    Adapt the parsing in ``__getitem__`` to your own naming scheme.
    """

    def __init__(self, train_path):
        """Collect every ``.npy`` file found under ``train_path``."""
        # The original wrote ``super(train_dataset, self)`` without calling
        # ``__init__()``, so the base class was never initialised.
        super().__init__()
        self.all_list = find_files(train_path, ext="npy")
        self.length = len(self.all_list)

    def __getitem__(self, index):
        """Return ``(array, labels, size)`` for the file at ``index``.

        ``array`` is the loaded file content, ``labels`` is a 1-D float
        array parsed from the file name and ``size`` is ``array.shape[0]``,
        which is what later padding needs to find the longest sample.
        """
        npy_path = self.all_list[index]
        x = np.load(npy_path)

        # Keep only what follows the last "Ses" so that directory names in
        # the path cannot interfere with the "label" split below.
        tail = npy_path.split("Ses")[-1]
        # Drop the ".npy" suffix, then tokenise the label section.
        fields = tail.split("label")[1][:-4].strip().split(" ")
        # fields[0] is skipped (presumably a category tag - confirm against
        # your data); the remaining tokens are the numeric labels.
        labels = np.array([float(token) for token in fields[1:]])

        return x, labels, x.shape[0]

    def __len__(self):
        """Return the number of files found at construction time."""
        return self.length

第二步:加载数据

def my_collate(batch):
    """Regroup a batch of ``(data, target, index)`` samples into three lists.

    Nothing is stacked or padded here; the samples are only transposed into
    ``[data_list, target_list, index_list]`` so that items of different
    sizes can travel through the ``DataLoader`` unchanged.
    """
    data, target, index = [], [], []
    for sample in batch:
        data.append(sample[0])
        target.append(sample[1])
        index.append(sample[2])
    return [data, target, index]
# Number of samples drawn per mini-batch.
batch_size = 28
train_set = train_dataset(train_path)
# A custom collate_fn is passed so that samples are handed back as plain
# lists instead of being stacked into a single tensor.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=my_collate,
)

第三步:读取数据并进行 zero-padding:

# datas, labels and indexs each arrive as a plain Python list (one entry per
# sample in the batch); print them to inspect their contents.
for datas, labels, indexs in train_loader:
    # Zero-pad every sample along its first dimension up to the longest
    # sample in the batch (i.e. max(indexs)) and stack the results into a
    # single tensor of shape (batch, max_len, ...). pad_sequence takes the
    # trailing dimensions from the data itself, so the feature width is no
    # longer hard-coded to 512. Adapt as needed; the same padding could
    # also be performed inside my_collate.
    train_da = torch.nn.utils.rnn.pad_sequence(
        [torch.as_tensor(da) for da in datas], batch_first=True
    )

    # Stack the per-sample label vectors into one float32 tensor with the
    # batch as the first dimension.
    train_label = torch.stack(
        [torch.as_tensor(label) for label in labels]
    ).float()

你可能感兴趣的:(python,语音识别,深度学习,情感分析,神经网络)