RuntimeError: Caught RuntimeError in replica 0 on device 0.
RuntimeError: CUDA out of memory. Tried to allocate 4.01 GiB (GPU 0; 11.91 GiB total capacity; 4.89 GiB already allocated; 2.38 GiB free; 4.02 GiB cached)
transpose() and permute() usually change a tensor's memory layout, so .contiguous() is needed afterwards to make the memory contiguous again. Presumably to speed up computation, PyTorch implements this by allocating a fresh block of GPU memory and copying the data into it, which is expensive in memory. Rewriting these calls can therefore save a large amount of GPU memory.
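A quick way to see the extra allocation, as a minimal sketch (shapes are illustrative):

import torch

x = torch.randn(64, 128, 32, 32)
y = x.permute(0, 2, 1, 3)             # just a view, no new memory yet
print(y.is_contiguous())              # False
z = y.contiguous()                    # materializes a full-size copy here
print(z.data_ptr() == x.data_ptr())   # False: new storage was allocated

The tricks below avoid materializing such full-size copies.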
Split the tensor into multiple batches and transpose each one separately:
b, c, h, w = out.shape
# Transpose dims (1, 2): b x c x h x w -> b x h x c x w, chunk by chunk,
# instead of the original full-copy version:
# out = out.reshape(b, c, h, w).permute(0, 2, 1, 3).contiguous()
# Writing each transposed chunk back into the same tensor only works
# when c == h; w is assumed to be divisible by batchsize.
batchsize = h
for ind_b in range(b):
    for ind_r in range(w // batchsize):
        # .contiguous() makes tmp a copy, so the write-back is safe
        tmp = out[ind_b, :, :, ind_r*batchsize:(ind_r+1)*batchsize].transpose(0, 1).contiguous()
        out[ind_b, :, :, ind_r*batchsize:(ind_r+1)*batchsize] = tmp
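A self-contained version of this trick, checked against the reference permute (a sketch with small, illustrative shapes; c == h is required for the in-place write-back):

import torch

b, c, h, w = 2, 8, 8, 32
out = torch.randn(b, c, h, w)
reference = out.permute(0, 2, 1, 3).contiguous()

batchsize = h
for ind_b in range(b):
    for ind_r in range(w // batchsize):
        sl = slice(ind_r * batchsize, (ind_r + 1) * batchsize)
        out[ind_b, :, :, sl] = out[ind_b, :, :, sl].transpose(0, 1).contiguous()

assert torch.equal(out, reference)

On GPU, wrapping each variant between torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() confirms how much lower the chunked version's peak allocation is.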
Element-wise multiplication also allocates new memory for its result; overwriting one of the operands with the result saves that allocation:
# corr:     b, c, h, w
# image_uf: b, c, h, w
# Instead of allocating a new tensor with: out = corr * image_uf
# overwrite image_uf channel by channel (the loop runs over dim 1):
for ch in range(image_uf.shape[1]):
    image_uf[:, ch, :, :] = image_uf[:, ch, :, :] * corr[:, ch, :, :]
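For this case PyTorch also has a built-in in-place multiply that needs no Python loop; a minimal sketch with illustrative shapes:

import torch

b, c, h, w = 2, 4, 16, 16
corr = torch.randn(b, c, h, w)
image_uf = torch.randn(b, c, h, w)

# In-place multiply: overwrites image_uf's storage directly, so no
# new tensor is allocated for the product.
image_uf.mul_(corr)   # equivalent to: image_uf *= corr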
Note that torch's .sum() and accumulating the terms with direct addition do not give the same result here (see the incorrect version further below). The working pattern is to multiply in place first and then reduce with .sum(2):
# Instead of: out = (corr * image_uf).sum(2)
for ch in range(image_uf.shape[1]):
    image_uf[:, ch, :, :] = image_uf[:, ch, :, :] * corr[:, ch, :, :]
out = image_uf.sum(2)
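A self-contained check that this in-place variant matches the original expression (illustrative shapes):

import torch

b, c, h, w = 2, 4, 8, 8
corr = torch.randn(b, c, h, w)
image_uf = torch.randn(b, c, h, w)
reference = (corr * image_uf).sum(2)   # allocates the full-size product

for ch in range(image_uf.shape[1]):
    image_uf[:, ch, :, :] = image_uf[:, ch, :, :] * corr[:, ch, :, :]
out = image_uf.sum(2)                  # only the (b, c, w) result is new

assert torch.allclose(out, reference)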
The following must not be used: its loop accumulates over dim 1, so it computes (corr * image_uf).sum(1) with shape b x h x w, which is a different reduction from .sum(2):
# out = (corr * image_uf).sum(2)
out = torch.zeros(b, h, w)
for ch in range(image_uf.shape[1]):
    out = out + image_uf[:, ch, :, :] * corr[:, ch, :, :]
Also note that all of these in-place tricks are only safe in test/inference code: overwriting tensors that autograd saved for the backward pass breaks gradient computation, so if gradients must flow you have to find another way to compute them.
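When gradients are needed, one standard memory-for-compute trade (not from the original post) is gradient checkpointing: the wrapped block's intermediate activations are not stored for backward but recomputed during it, and autograd stays correct. A minimal sketch:

import torch
from torch.utils.checkpoint import checkpoint

def weighted_sum(corr, image_uf):
    # Ordinary out-of-place computation; intermediates inside this
    # function are recomputed during backward instead of being stored.
    return (corr * image_uf).sum(2)

b, c, h, w = 2, 4, 8, 8
corr = torch.randn(b, c, h, w, requires_grad=True)
image_uf = torch.randn(b, c, h, w, requires_grad=True)

# use_reentrant=False requires a reasonably recent PyTorch (>= 1.11)
out = checkpoint(weighted_sum, corr, image_uf, use_reentrant=False)
out.sum().backward()   # gradients reach corr and image_uf

The savings grow with how much the wrapped block would otherwise store for backward; for a single multiply-and-sum they are modest, so checkpointing pays off most around larger sub-networks.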