pytorch显存越来越多的一个潜在原因-- 这个函数还没有在torch.cuda.Tensor中定义

最近在用pytorch跑实验,有如下操作需要用到: (pytorch版本为0.3.1)

class SpatialFilter(nn.Module):
    def __init__(self,mode=True,sf_rate=0.8):
        '''
        给定特征图求hot_map
        '''
        super(SpatialFilter,self).__init__()
        self.sf_rate=sf_rate
        self.mode=mode
    def forward(self,x):
        b,c,h,w=x.size()
        
        if self.mode:
            #print("====",self.sf_rate)
            hot_map=torch.mean(x,dim=1).view(b,1,h*w).
            map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1)  #hot_map的中位数
            hot_map=hot_map.view(b,1,h,w)
            hot_map=torch.gt(hot_map,map_med*self.sf_rate).float()
            del(map_med)        
        else:
            #print("++++")
            hot_map=Variable(torch.ones(b,1,h,w),requires_grad=False)
        return hot_map

发现在训练的时候显存会不断增加,情况就是每隔一个epoch显存会增加30M左右,在一个epoch之内显存不会增加。刚开始我以为是我的训练部分写的有问题,后来发现不用torch.median()而是用 F.AdaptiveAvgPool2d()就不会有这个问题,于是我就去看了pytorch的中文文档,发现pytorch中文文档中torch.median()函数下有这么一句话:   注意: 这个函数还没有在torch.cuda.Tensor中定义. 所以问题就很清晰了,这个运算应该是先把Tensor转到cpu上再把它挪回到gpu上,应该是由于内部转换机制的不健全导致了gpu上的显存没有及时释放,只需要人为的把操作转到cpu上再把tensor转到gpu上,并自己delete就可解决这个问题。

代码如下:

class SpatialFilter(nn.Module):
	def __init__(self,mode=True,sf_rate=0.8):
		'''
		给定特征图求hot_map
		'''
		super(SpatialFilter,self).__init__()
		self.sf_rate=sf_rate
		self.mode=mode
	def forward(self,x):
		b,c,h,w=x.size()
		
		if self.mode:
			#print("====",self.sf_rate)
			hot_map=torch.mean(x,dim=1).view(b,1,h*w).cpu()
			if isinstance(x.data,torch.cuda.FloatTensor):
				map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1).cuda()  #hot_map的中位数
				hot_map=hot_map.view(b,1,h,w).cuda()
			else:
				map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1)  #hot_map的中位数
				hot_map=hot_map.view(b,1,h,w)
			hot_map=torch.gt(hot_map,map_med*self.sf_rate).float()
			del(map_med)		
		else:
			#print("++++")
			hot_map=Variable(torch.ones(b,1,h,w),requires_grad=False)
		return hot_map

这样改动之后就可以在cpu上和gpu上同时跑了。

相似的可能导致这种问题的操作还有:torch.mode() 求众数函数

希望能帮助大家以后跳过这个坑。

你可能感兴趣的:(pytorch,机器学习,深度学习,python)