pytorch中加载数据时常用的def __len__(self):和def __getitem__(self, index)

研究了一天,总算有点眉目,先上代码

    def __len__(self):
        return len(self.ques)  # 返回dataset的长度

    def __getitem__(self, index):
        questions = self.ques[index]
        skill = self.skill[index]
        answers = self.ans[index]
        onehot = self.onehot(questions, skill, answers) # 后续编码,不管他
        return onehot

我的理解是,假如你的dataset处理完后是[n,50],50是步长

__getitem__做的事情就是返回第index个样本的具体数据:
return dataset[index]
而这个index是随机返回一个属于[0,n-1]的数
__len__作用是得到长度n

你可能感兴趣的:(读研之路,pytorch,python,人工智能)