torch.utils.data.Dataset/TensorDataset

torch.utils.data.Dataset:

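import numpy as np
import torch


class MyData(torch.utils.data.Dataset):
    """Map-style dataset pairing each sample in dt with its label in lb."""

    def __init__(self, dt, lb):
        self.dt = dt  # samples
        self.lb = lb  # labels, one per sample

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.dt)

    def __getitem__(self, index):
        # Return (label, sample); the sample is converted to a NumPy array.
        return self.lb[index], np.array(self.dt[index])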

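The subclass overrides two methods: __len__(), which returns the number of samples, and __getitem__(), which returns the sample at a given index.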
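A minimal sketch of how a dataset like this might be used with a DataLoader. The class is repeated so the snippet runs on its own; the toy features and labels, the batch size, and the variable names are assumptions made for illustration, not part of the original post.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class MyData(Dataset):
    def __init__(self, dt, lb):
        self.dt = dt
        self.lb = lb

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, index):
        # Same (label, sample) order as the class above.
        return self.lb[index], np.array(self.dt[index])


# Toy data (assumed for illustration): four samples with two features each.
features = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
labels = [0, 1, 0, 1]

dataset = MyData(features, labels)

# DataLoader calls __len__ to size the epoch and __getitem__ to fetch samples,
# then the default collate function stacks each batch into tensors.
loader = DataLoader(dataset, batch_size=2, shuffle=False)
batches = list(loader)
first_labels, first_data = batches[0]

print(len(dataset), len(batches))
print(first_labels, first_data.shape)
```

Because __getitem__ returns the label first, each batch also unpacks as (labels, data).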

torch.utils.data.TensorDataset:

class TensorDataset(Dataset):
    """Dataset wrapping tensors.
    Each sample will be retrieved by indexing tensors along the first dimension.
    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
 
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
 
    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)
 
    def __len__(self):
        return self.tensors[0].size(0)

As the source code above shows, TensorDataset is a subclass of Dataset.
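A short sketch of TensorDataset in use, to make the indexing behavior concrete. The tensors x and y, their shapes, and the batch size are made-up values for illustration only.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

# Assumed toy tensors: six samples with two features each, plus six labels.
x = torch.arange(12, dtype=torch.float32).reshape(6, 2)
y = torch.tensor([0, 1, 0, 1, 0, 1])

# All tensors must share the same size along the first dimension.
ds = TensorDataset(x, y)

# Indexing returns a tuple with one slice per wrapped tensor.
sample_x, sample_y = ds[0]

# The dataset plugs into DataLoader like any other Dataset subclass.
loader = DataLoader(ds, batch_size=4)
batch_shapes = [tuple(bx.shape) for bx, by in loader]

print(len(ds), isinstance(ds, Dataset))
print(sample_x, sample_y, batch_shapes)
```

When the data already lives in tensors with a shared first dimension, this avoids writing a custom Dataset subclass.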
