pytorch loss.backword() 时间太长

直接原因是:数据在进入模型之前没有进行深拷贝

深层原因大概是:如果不进行深拷贝,在梯度反向传播过程中,要寻找原始数据的地址,这个过程非常耗时间。(直接等号是前拷贝,是将新的变量指向原来变量的地址)

解决办法:

tensor_a = tensor_b.clone().detach()

或者用deepcopy也行。

位置呢,就放到数据进入模型之前就可以。大概如下:

data = loader.get_batch('train')

data_copy = data.clone().detach()

optimizer.zero_grad()
out,loss = model(data_copy)
loss.backward()
optimizer.step()

如果data是tensor构成的字典或者list,遍历处理里面的每一项即可。

效果展示:

加之前

pytorch loss.backword() 时间太长_第1张图片

加之后

pytorch loss.backword() 时间太长_第2张图片

 效果十分显著

 

你可能感兴趣的:(pytorch)