本章节主要介绍如何使用torch.utils.data 中的Dataset和Dataloader来构建数据集, 重点要看使用细节
class WeiBoDataset(Dataset):
pass
注意 : 我们一般会在初始化的时候就加载进数据, 读取数据函数需要自定义
class WeiBoDataset(Dataset):
def __init__(self, data_path):
# 读取数据
self.label, self.data = self.read_data(data_path)
class WeiBoDataset(Dataset):
def __init__(self, data_path):
# 读取数据
self.label, self.data = self.read_data(data_path)
def __len__(self):
"""
这个必须要设置, getitem中的index就是根据这个来设置的
:return:
"""
return len(self.data)
def __getitem__(self, index):
label = 1
# features = [str(i) for i in range(10)]
features = np.array([i for i in range(10)])
return label, features
weibo_dataset=WeiBoDataset("../../datasets/weibo_test_data.csv)
dataloader=DataLoader(weibo_dataset,batch_size=1024,shuffle=True)
for i, batch in enumerate(dataloader):
# batch : [label, features] 组成
print(type(batch[0]), type(batch[1]))
注意 features的元素类型是str, 那么可以看到下面的输出结果中 label 是 tensor, features 是 list类型的
def __getitem__(self, index):
label = 1
# 转换为 ndarray 会报错
# features = np.array([str(i) for i in range(10)])
features = [str(i) for i in range(10)]
return label, features
下面将feature中的数据元素换成了int类型的, 并且对将list转换为ndarray, 这样在获取batch时数据会自动转换为tensor , 但是这里需要注意的是, 上面的数据是不能用np.array()的, 这是因为 batch 必须包含 tensors, numpy arrays, numbers, dicts or lists 这几种类型, 其他的都会报错, 具体可以查看
def __getitem__(self, index):
label = 1
features = np.array([i for i in range(10)])
return label, features