解决pytorch在训练时由于设置了验证集导致out of memory(同样可用于测试时减少显存占用)

问题描述:


最近一直在使用pytorch, 由于深度学习的网络往往需要设置验证集来验证模型是否稳定.

我一直再做一个关于医学影像分割的课题,为了查看自己的模型是否稳定,于是设置了验证集.

但是在运行的过程中,当程序执行到 validatioon时,显存立即上升,我可怜的显卡只有8GB显存,瞬间爆炸.

怎么办呢?实验得做呀.于是找了不少方法,比如设置各个网络变量requires_grad=False,但是并不管用,显存依然爆炸.

后来百度了一番,终于解决了显存爆炸的问题.

解决方案:


假设训练程序是这样的:

for train_data, train_label in  train_dataloader:

    do 

           trainning

then

for valid_data,valid_label in valid_dataloader:

    do 

            validtion

当程序执行到validation时,显存忽然上升,几乎是之前的两倍.


只需要这样改:

for train_data, train_label in train_dataloader:

        do

            trainning


then

with torch.no_grad():

    for valid_data,valid_label in valid_dataloader:

            do

                validtion

当程序执行到validation时,显存将不再上升.问题得到解决.真的是非常简单.

你可能感兴趣的:(解决pytorch在训练时由于设置了验证集导致out of memory(同样可用于测试时减少显存占用))