自定义训练,使用Generator dataset迭代数据报错

mindspore 1.3  Ascend 910环境

因为任务要求,我没有使用高层的model.train()接口,而是自定义训练过程,像pytorch那样训练,使用Generator Dataset迭代数据,第一个epoch数据正常迭代,下一个epoch就会报错。在每个epoch迭代之后,我都对dataset进行了reset,是不是我reset的位置不对或是少了什么东西?

报错的截图:

自定义训练,使用Generator dataset迭代数据报错_第1张图片

根据报错的信息,你自定义的dataset 的 __len__ 函数返回值是 36,但是真实的 __next__ 返回的数据量只有35 条,这个校验错误,所以报错了。
快速验证的话,你可以把 __len__的返回值改成 35 再试下。

你可能感兴趣的:(深度学习,python,pytorch)