Reducing GPU memory usage in PyTorch

Running out of GPU memory

RuntimeError: Caught RuntimeError in replica 0 on device 0.

RuntimeError: CUDA out of memory. Tried to allocate 4.01 GiB (GPU 0; 11.91 GiB total capacity; 4.89 GiB already allocated; 2.38 GiB free; 4.02 GiB cached)
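When this error appears, it helps to read off how much memory live tensors actually hold versus how much PyTorch's caching allocator has reserved (the "cached" figure in the message above). A minimal sketch using PyTorch's built-in counters; device index 0 is an assumption:

import torch

dev = torch.device("cuda:0")
# Memory currently occupied by live tensors.
print(torch.cuda.memory_allocated(dev) / 1024**3, "GiB allocated")
# Memory held by the caching allocator; it is not returned to the driver
# automatically. (Called memory_cached() in older PyTorch releases.)
print(torch.cuda.memory_reserved(dev) / 1024**3, "GiB reserved")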

transpose uses a large amount of GPU memory

transpose() and permute() change a tensor's layout, and .contiguous() is usually needed afterwards to make the memory contiguous again. Presumably to speed up the computation, PyTorch implements this in a memory-hungry way: it allocates a fresh block of GPU memory and copies the values over, temporarily doubling the tensor's footprint. Rewriting these calls can save a large amount of memory:

Split the tensor into several smaller batches and transpose them one at a time:

b, c, h, w = out.shape
# Swap dims 1 and 2: b x c x h x w -> b x h x c x w, without the large
# temporary that the one-line version allocates:
# out = out.permute(0, 2, 1, 3).contiguous()
# Writing chunks back in place only works when c == h, because the
# transposed chunk must fit the original slot.
assert c == h
batchsize = h  # chunk width along the last dim; must divide w
for ind_b in range(b):
	for ind_r in range(w // batchsize):
		tmp = out[ind_b, :, :, ind_r*batchsize:(ind_r+1)*batchsize].transpose(0, 1).contiguous()
		out[ind_b, :, :, ind_r*batchsize:(ind_r+1)*batchsize] = tmp.reshape(h, c, batchsize)
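To confirm the chunked loop really avoids the big temporary, you can watch the allocator's peak counter around both variants. A minimal, self-contained sketch; the tensor sizes are made up for illustration:

import torch

out = torch.rand(2, 64, 64, 256, device="cuda")  # c == h, as the in-place trick requires

torch.cuda.reset_peak_memory_stats()
ref = out.permute(0, 2, 1, 3).contiguous()  # allocates a second full-size tensor
print("peak MiB:", torch.cuda.max_memory_allocated() / 1024**2)
# Run the same measurement around the chunked loop above; the peak should
# drop by roughly the size of out.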

Element-wise multiplication

An element-wise product allocates new memory for its result; overwriting the original data with the result saves that allocation:

# corr:     b, c, h, w
# image_uf: b, c, h, w
# out = corr * image_uf
for ch in range(image_uf.shape[1]):  # loop over the channel dim to keep temporaries small
	image_uf[:, ch, :, :] = image_uf[:, ch, :, :] * corr[:, ch, :, :]
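PyTorch's built-in in-place operators reach the same result without the explicit channel loop, writing straight into image_uf with no full-size temporary; a minimal sketch:

# In-place multiply with a built-in op: writes into image_uf directly,
# with no Python loop and no full-size temporary.
image_uf.mul_(corr)
# image_uf now holds corr * image_uf, i.e. what out was above.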

Note that PyTorch's .sum() does not produce exactly the same result as adding the slices up yourself (the floating-point accumulation order differs):

# out = (corr * image_uf).sum(1)
for ch in range(image_uf.shape[1]):
	image_uf[:, ch, :, :] = image_uf[:, ch, :, :] * corr[:, ch, :, :]
out = image_uf.sum(1)  # sum over channels once, after the in-place multiplies

Do not use:

# out = (corr * image_uf).sum(1)
out = torch.zeros(b, h, w, device=image_uf.device)
for ch in range(image_uf.shape[1]):
	out = out + image_uf[:, ch, :, :] * corr[:, ch, :, :]  # rebinds out and allocates a fresh tensor every iteration
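To see the difference on concrete data, the two variants can be compared on a random tensor. A self-contained sketch; the sizes are made up:

import torch

b, c, h, w = 2, 16, 8, 8
corr = torch.rand(b, c, h, w)
image_uf = torch.rand(b, c, h, w)

ref = (corr * image_uf).sum(1)

acc = torch.zeros(b, h, w)
for ch in range(c):
	acc = acc + image_uf[:, ch, :, :] * corr[:, ch, :, :]

# Mathematically identical, but the accumulation order differs, so the two
# may disagree at floating-point precision.
print(torch.allclose(ref, acc), (ref - acc).abs().max().item())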

Note

These tricks can only be used in test code: the in-place overwrites break autograd, so if gradients need to flow you will have to find another way to compute them.
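A common way to keep these in-place tricks confined to test code is to run evaluation under torch.no_grad(), so no autograd graph is recorded; a minimal sketch reusing the names from above:

import torch

with torch.no_grad():  # no autograd graph is built inside this block
	image_uf.mul_(corr)  # the in-place overwrite is safe here
	out = image_uf.sum(1)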
