深度学习新手第一次复现代码的记录

研究生入学的第个学期。我接到了一个任务-- 复现一个深度学习模型。1 以及在此基础上改进得到的2。

一、

我首先找到了一个 1 模型代码(非原作者)。但是它用的框架我不太懂。为了方便,我把整个模型移植到了实验室自创框架上。花时间两周左右,时间主要用来改框架(原框架是针对实现数域计算的,但是这个模型需要能计算虚数)。框架中,数据输入,预处理,结果的可视化等。新框架命名为 Mcom

一周零三天左右,改动完成了。我在自己PC上试了一下效果,然而效果稀烂。而且要求所有的语音都是等长。我又花了两天时间,反复查找代码的问题。无果。(这个时间基本浪费了)

然而,这时我偶然找到了1的原作者代码,(一开始没找到是因为代码名字和论文中模型名字不一样)花了一天时间,把模型迁移到了Mcom上。(时间少框架已经完成,而且原作者代码习惯好,易用)发现效果很不错。而且不要求语音等长。

教训1  :一定要反复查找原作者代码,其他作者的代码不值得信任!(而且原作者代码质量一般更高) paper with code 只是第一步,应该在 git上多找一找,换几个名字试试。并且应该大胆向原作者求援,发邮件。

二、

完成复现1以后,我花了两周的时间,初步改成了 2 ,第一周编写复数 attention 模块,第二周设计各个网络层及其参数。

教训2:新手第一次设计网络可能生疏,不敢动手,要大胆的把网络结构画好,迅速写好代码多做实验,调通很重要。

三、

我迅速把代码移植到服务器,开始训练,很快发现,屡屡出现显存爆炸问题。

开始解决: 我找了两个监视显存的系统,花费了三天时间。发现每一轮过去,所用的显存都有一定上升,进行大量实验后,几乎监视了所有的变量和语句,终于发现出问题的似乎是pytorch 的自带函数。在完成训练以后没有及时释放显存。但是死活找不到到底是哪个变量出的问题。

终于,偶然机会我发现,在pc上没有显存递增问题。在服务器上,单卡也不会出现显存递增问题。我想到问题应该出在 并行计算的通信机制上。

这时候,我想到和之前正常的模型相比,1、2模型使用了新的pytorch 版本 1.4.0。(因为最初的非原作者1模型中,要求环境配置为 pytoch 1.6,但是1.6报错比较多,我改成1.4.0)我把pytorch 降回 1.2.0。问题终于解决。

教训三:实验室或者大牛的 深度学习框架,不要轻易改动自己看不懂的地方,更不要改动环境的配置。否则会出现自己意想不到的问题。

你可能感兴趣的:(自己记忆,人工智能,深度学习)