最近在用pytorch跑实验,有如下操作需要用到: (pytorch版本为0.3.1)
class SpatialFilter(nn.Module):
def __init__(self,mode=True,sf_rate=0.8):
'''
给定特征图求hot_map
'''
super(SpatialFilter,self).__init__()
self.sf_rate=sf_rate
self.mode=mode
def forward(self,x):
b,c,h,w=x.size()
if self.mode:
#print("====",self.sf_rate)
hot_map=torch.mean(x,dim=1).view(b,1,h*w).
map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1) #hot_map的中位数
hot_map=hot_map.view(b,1,h,w)
hot_map=torch.gt(hot_map,map_med*self.sf_rate).float()
del(map_med)
else:
#print("++++")
hot_map=Variable(torch.ones(b,1,h,w),requires_grad=False)
return hot_map
发现在训练的时候显存会不断增加,情况就是每隔一个epoch显存会增加30M左右,在一个epoch之内显存不会增加。刚开始我以为是我的训练部分写的有问题,后来发现不用torch.median()而是用 F.AdaptiveAvgPool2d()就不会有这个问题,于是我就去看了pytorch的中文文档,发现pytorch中文文档中torch.median()函数下有这么一句话: 注意: 这个函数还没有在torch.cuda.Tensor
中定义. 所以问题就很清晰了,这个运算应该是先把Tensor转到cpu上再把它挪回到gpu上,应该是由于内部转换机制的不健全导致了gpu上的显存没有及时释放,只需要人为的把操作转到cpu上再把tensor转到gpu上,并自己delete就可解决这个问题。
代码如下:
class SpatialFilter(nn.Module):
def __init__(self,mode=True,sf_rate=0.8):
'''
给定特征图求hot_map
'''
super(SpatialFilter,self).__init__()
self.sf_rate=sf_rate
self.mode=mode
def forward(self,x):
b,c,h,w=x.size()
if self.mode:
#print("====",self.sf_rate)
hot_map=torch.mean(x,dim=1).view(b,1,h*w).cpu()
if isinstance(x.data,torch.cuda.FloatTensor):
map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1).cuda() #hot_map的中位数
hot_map=hot_map.view(b,1,h,w).cuda()
else:
map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1) #hot_map的中位数
hot_map=hot_map.view(b,1,h,w)
hot_map=torch.gt(hot_map,map_med*self.sf_rate).float()
del(map_med)
else:
#print("++++")
hot_map=Variable(torch.ones(b,1,h,w),requires_grad=False)
return hot_map
这样改动之后就可以在cpu上和gpu上同时跑了。
相似的可能导致这种问题的操作还有:torch.mode() 求众数函数
希望能帮助大家以后跳过这个坑。