STANet复现

论文

一、为了捕获遥感图像不同尺度下的时空依赖关系,从而生成更好的表示来适应不同大小的对象,论文作者提出了一个用于变化检测的时空注意力神经网络STANet
二、为了缓解双时相图像配准误差引起的误检问题,作者设计了两种自注意力模块:BAM 和 PAM(实验中 PAM 的效果最好)。训练阶段采用对比损失来学习网络参数;由于遥感影像中变化样本与未变化样本的数量差距非常大,作者进一步提出了批量平衡对比损失(batch-balanced contrastive loss)。另外,论文还提出了一个新的大型遥感影像变化检测数据集 LEVIR-CD:与目前公开的数据集相比,它的样本数量更多、分辨率更高,标注上更关注建筑的生长(新增)与衰败(消失)。
STANet复现_第1张图片

代码复现

问题1:

AssertionError: X and Y should be the same shape

解决办法:

将可视化工具 visdom 的版本改为 0.1.8.8;
将 scipy 的版本改为 1.1.0:因为 1.2.0 版本的 scipy 已经没有 imread,同样会报错。

问题2:

RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 6.00 GiB total capacity; 2.98 GiB already allocated; 1.78 GiB free; 2.89 MiB cached)

解决办法:

1、将图片裁剪为 256*256 的小块(裁剪代码见下文)
2、将 batch size 从 8 降为 4
3、在训练命令中加上 --ds 2

运行命令

python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 4 --load_size 256 --crop_size 256 --preprocess rotate_and_crop  --ds 2

裁剪图片代码

import os
import os.path as osp
import sys
from multiprocessing import Pool
import numpy as np
import cv2
from PIL import Image
import time
from shutil import get_terminal_size

# Make the directory one level above this script's folder importable.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))


def main():
    """Fill in the cropping options for the hard-coded mode and run the job.

    Both modes populate ``opt`` (input/output folders, crop size, sliding
    step, size threshold, thread count, PNG compression level) and then call
    ``extract_signle(opt)``.

    Raises:
        ValueError: if ``mode`` is neither 'single' nor 'pair'.
    """
    # single: crop one input folder with a 480 window and a 240 step
    # pair:   crop the GT folder with a 256 window and a 256 step
    #         (only this one folder is processed here; run again with other
    #         folder paths to crop the remaining folders of a dataset)
    mode = 'pair'
    opt = {}
    opt['n_thread'] = 20
    opt['compression_level'] = 3  # 3 is the default value in cv2
    # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
    # compression time. If read raw images during training, use 0 for faster IO speed.
    if mode == 'single':
        # Raw strings throughout: a plain 'E:\dataset_new' contains the
        # invalid escape sequence \d, which newer Python versions warn about.
        opt['input_folder'] = r'E:\A'
        opt['save_folder'] = r'E:\dataset_new'
        opt['crop_sz'] = 480  # the size of each sub-image
        opt['step'] = 240  # step of the sliding crop window
        opt['thres_sz'] = 48  # size threshold
        extract_signle(opt)

    elif mode == 'pair':
        GT_folder = r'E:\A\val\B'
        save_GT_folder = r'E:\dataset_new\val\B'
        crop_sz = 256  # the size of each sub-image (GT)
        step = 256  # step of the sliding crop window (GT)
        thres_sz = 256  # size threshold

        # Fail fast on a missing or empty input folder *before*
        # extract_signle creates the output folder; the returned list itself
        # is not needed here.
        _get_paths_from_images(GT_folder)

        print('process GT...')
        opt['input_folder'] = GT_folder
        opt['save_folder'] = save_GT_folder
        opt['crop_sz'] = crop_sz
        opt['step'] = step
        opt['thres_sz'] = thres_sz
        extract_signle(opt)

    else:
        raise ValueError('Wrong mode.')


def extract_signle(opt):
    """Crop every image found under ``opt['input_folder']`` using a process pool.

    Args:
        opt (dict): options. This function reads 'input_folder',
            'save_folder' and 'n_thread'; the whole dict is also passed to
            ``worker`` for every image.

    The save folder is created if it does not exist. If it already exists the
    function prints a message and terminates the program with exit status 1.
    Images whose worker raised an exception are listed after the pool has
    finished instead of being dropped silently.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print('mkdir [{:s}] ...'.format(save_folder))
    else:
        print('Folder [{:s}] already exists. Exit...'.format(save_folder))
        sys.exit(1)
    img_list = _get_paths_from_images(input_folder)

    pbar = ProgressBar(len(img_list))
    failures = []  # (path, exception) for every task that raised

    def update(arg):
        pbar.update(arg)

    def make_error_handler(img_path):
        # Bind the path now; the error callback itself only receives the exception.
        def on_error(exc):
            failures.append((img_path, exc))
            # Still advance the bar so the completed count reaches the total.
            pbar.update('Failed {:s} ...'.format(osp.basename(img_path)))
        return on_error

    pool = Pool(opt['n_thread'])
    for path in img_list:
        # Without error_callback an exception raised inside worker would be
        # stored in the discarded AsyncResult and never surface.
        pool.apply_async(worker, args=(path, opt), callback=update,
                         error_callback=make_error_handler(path))
    pool.close()
    pool.join()

    # Report failures only after the pool is done so the messages do not get
    # erased by the progress bar's cursor movements.
    for failed_path, exc in failures:
        print('Failed to process [{:s}]: {!r}'.format(failed_path, exc))
    print('All subprocesses done.')


def worker(path, opt):
    """Cut one image into square sub-images and write each to its own file.

    The crop window of side ``opt['crop_sz']`` slides over the image with
    stride ``opt['step']``. If the strip left over at the bottom/right edge is
    larger than ``opt['thres_sz']``, one extra window aligned to that edge is
    added. An image smaller than the crop size along either axis yields no
    crops.

    Args:
        path (str): path of the image to read.
        opt (dict): options; reads 'crop_sz', 'step', 'thres_sz',
            'save_folder' and 'compression_level'.

    Returns:
        str: a status message containing the image file name.

    Raises:
        OSError: if the image cannot be read.
        ValueError: if the decoded array is neither 2-D nor 3-D.
    """
    crop_sz = opt['crop_sz']
    step = opt['step']
    thres_sz = opt['thres_sz']
    img_name = osp.basename(path)
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError('Failed to read image - {}'.format(path))

    n_channels = len(img.shape)  # number of array dimensions (2 or 3 expected)
    if n_channels not in (2, 3):
        raise ValueError('Wrong image shape - {}'.format(n_channels))
    h, w = img.shape[:2]

    # Top-left offsets of the crop windows along each axis. The arrays are
    # empty when the image is smaller than the crop size, so guard the [-1].
    h_space = np.arange(0, h - crop_sz + 1, step)
    if h_space.size and h - (h_space[-1] + crop_sz) > thres_sz:
        h_space = np.append(h_space, h - crop_sz)
    w_space = np.arange(0, w - crop_sz + 1, step)
    if w_space.size and w - (w_space[-1] + crop_sz) > thres_sz:
        w_space = np.append(w_space, w - crop_sz)

    stem, ext = osp.splitext(img_name)
    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            # Slicing the first two axes works for both 2-D and 3-D arrays.
            crop_img = np.ascontiguousarray(img[x:x + crop_sz, y:y + crop_sz])
            # Give every crop a unique, numbered name. The previous
            # replace('.png', '.png'.format(index)) was a no-op, so all crops
            # of one image were written to the same file and overwrote each
            # other.
            out_name = '{}_s{:03d}{}'.format(stem, index, ext)
            cv2.imwrite(
                osp.join(opt['save_folder'], out_name), crop_img,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    return 'Processing {:s} ...'.format(img_name)


class ProgressBar(object):
    '''Console progress indicator that rewrites its own output on each update.

    Adapted from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
    '''

    def __init__(self, task_num=0, bar_width=50, start=True):
        '''Store the task count, clamp the bar width and optionally start.

        Args:
            task_num: total number of tasks; a value <= 0 selects the plain
                counter display instead of a bar.
            bar_width: requested bar width in characters.
            start: when true, print the initial state and start the timer.
        '''
        self.task_num = task_num
        # The bar is never wider than what the terminal allows.
        self.bar_width = min(bar_width, self._get_max_bar_width())
        self.completed = 0
        if start:
            self.start()

    def _get_max_bar_width(self):
        '''Return the widest bar allowed by the terminal, but at least 10.'''
        columns = get_terminal_size()[0]
        widest = min(int(columns * 0.6), columns - 50)
        if widest >= 10:
            return widest
        print('terminal width is too small ({}), please consider widen the terminal for better '
              'progressbar visualization'.format(columns))
        return 10

    def start(self):
        '''Print the initial display and record the start time.'''
        if self.task_num > 0:
            header = '[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
                ' ' * self.bar_width, self.task_num, 'Start...')
        else:
            header = 'completed: 0, elapsed: 0s'
        sys.stdout.write(header)
        sys.stdout.flush()
        self.start_time = time.time()

    def update(self, msg='In progress...'):
        '''Count one more finished task and redraw the display.

        Args:
            msg: text shown on the line below the bar (bar mode only).
        '''
        self.completed += 1
        # The tiny offset keeps the rate computation away from a zero divisor.
        elapsed = time.time() - self.start_time + 1e-9
        rate = self.completed / elapsed
        out = sys.stdout
        if self.task_num > 0:
            fraction = self.completed / float(self.task_num)
            remaining = int(elapsed * (1 - fraction) / fraction + 0.5)
            filled = int(self.bar_width * fraction)
            bar = '>' * filled + '-' * (self.bar_width - filled)
            out.write('\033[2F')  # cursor up 2 lines
            out.write('\033[J')  # erase what the previous draw left behind
            out.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
                bar, self.completed, self.task_num, rate, int(elapsed + 0.5), remaining, msg))
        else:
            out.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
                self.completed, int(elapsed + 0.5), rate))
        out.flush()


# ###################
# ### Data Utils ####
# ###################
IMG_EXTENSIONS = [
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if *filename* ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-sensitive: only the exact spellings listed in
    ``IMG_EXTENSIONS`` match.
    """
    # str.endswith accepts a tuple of suffixes and checks them in one call.
    return filename.endswith(tuple(IMG_EXTENSIONS))


def _get_paths_from_images(path):
    """Collect the paths of all image files below the directory *path*.

    The directory tree is walked recursively; directories and the file names
    inside each directory are visited in sorted order. Every file name seen
    is printed, and those accepted by ``is_image_file`` are collected.

    Returns:
        list: the collected image paths.

    Raises:
        AssertionError: if *path* is not a directory or no image was found.
    """
    assert osp.isdir(path), '{:s} is not a valid directory'.format(path)
    found = []
    for root, _subdirs, names in sorted(os.walk(path)):
        for name in sorted(names):
            print("..fname is:", name)

            if not is_image_file(name):
                continue
            found.append(os.path.join(root, name))
    assert found, '{:s} has no valid image file'.format(path)
    return found


# Run the cropping job only when this file is executed as a script. The guard
# also keeps main() from running again when the module is merely imported
# (e.g. by multiprocessing worker processes on platforms that start a fresh
# interpreter for each worker).
if __name__ == '__main__':
    main()

你可能感兴趣的:(目标检测)