STANet 代码复现--新手遇到的问题解决办法

STANet网络复现–可运行/持续更新

文章目录

    • STANet网络复现--可运行/持续更新
  • 前言
  • 一、out of memory?
  • 1.ds参数设置
  • 2.数据集分割
  • 二、原因分析
    • 1.导致oom的原因
    • 2.ds参数是什么
  • 总结


前言

STANet作为入门的第一篇研究文章,github上已有开源代码,也有优秀的博主对文章内容做了讲解,但是,新手实在是太菜了。下面对遇到的几个问题做总结,并给出成功运行代码的方法,特别是BAM/PAM内存溢出问题的解决。

一、out of memory?

RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 6.00 GiB total capacity; 2.98 GiB already allocated; 1.78 GiB free; 2.89 MiB cached)
如果你感到上述报错非常熟悉,那你就看对文章了。
分析可知主要是attention矩阵维度过大,导致了内存溢出,目前找到的解决办法有两种:

1.ds参数设置

不是batch size,这里的是代码注意力机制默认的参数ds,代码默认为1,想要跑通需要设置为2/4/8。
具体的bam运行代码为:

python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop  --ds 4

具体的pam运行代码为:

python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFAp0 --lr 0.001 --model CDFA --SA_mode PAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop  --ds 4

若显存仍旧不足,请同时降低bz和ds(笔者同设置为2则可以高精度跑通)

2.数据集分割

这里你肯定会疑惑?代码运行的时候不是做了数据集分割?
在验证过程笔者发现有时候数据集可能没有被分割,所以保险起见进行了分割以后直接参与运行
分割数据集的代码(1024×1024分割为16张64*64的图片):

from PIL import Image
import os

def image_crop(data_dir,save_dir):
     if not os.path.exists(save_dir):

              os.mkdir(save_dir)

     path=os.path.join(data_dir)

     img_list=os.listdir(path)

     for img in img_list:

            a=0

           # if img.endswith('.png') or img.endswith('.jpg' ):
            if img.endswith('.png'):
                img_name = path+'/' +img

                im=Image.open(img_name)

#这里是将[1024, 1024]裁剪为[64,64]

                for i in range(4):

                    for j in range(4):

                        x=i*256

                        y=j*256
                        region=im.crop((x,y,x+256,y+256))

                        region.save(save_dir+'/'+img.split('-')[0]+'_'+str(a)+'.png')
                        a+=1
data_dir=r"E:/this desk/STANet-master/LEVIR-CD/val/label"#原数据集位置
save_dir=r"E:/this desk/STANet-master/LEVIR-CD1/val/label"#新数据集位置
image_crop(data_dir,save_dir)

二、原因分析

1.导致oom的原因

我们来看网络的主要结构,如下:
STANet 代码复现--新手遇到的问题解决办法_第1张图片
在softmax之前和之后,分别两次 O ( N 2 ) {\rm{O}}\left( {{N^2}} \right) O(N2)级别的矩阵乘法,自然吃掉了极大部分的显存。

2.ds参数是什么

代码中:ds表示控制输入下采样(从灰色到蓝橙绿)
代码如下(示例):

class _PAMBlock(nn.Module):
    '''
    The basic implementation for self-attention block/non-local block
    Input/Output:
        N * C  *  H  *  (2*W)
    Parameters:
        in_channels       : the dimension of the input feature map
        key_channels      : the dimension after the key/query transform
        value_channels    : the dimension after the value transform
        scale             : choose the scale to partition the input feature maps
        ds                : downsampling scale
自我注意块/非局部块的基本实现
输入/输出:
N * C * H * (2*W)
参数:
in_channels:输入特征映射的维度
key_channels:键/查询转换之后的维度
value_c、channels:值转换之后的维度
比例尺:选择比例尺对输入的特征图进行划分
ds:下采样尺度
    '''
    # 进行定义
    def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
        super(_PAMBlock, self).__init__()
        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)#

总结

雪儿妹妹的求学路第一篇到此结束,谢谢大家支持,互相学习~~

你可能感兴趣的:(C++),cv论文学习,深度学习,pytorch,神经网络,代码规范)