一、为了捕获遥感图像不同尺度下的时空依赖关系,从而生成更好的表示来适应不同大小的对象,论文作者提出了一个用于变化检测的时空注意力神经网络STANet。
二、为了缓解双时相图像配准误差引起的误检问题,作者设计了两种自注意力模块:BAM 和 PAM(实验中 PAM 的效果最好)。训练阶段采用对比损失学习网络参数;由于遥感影像中变化样本与未变化样本的数量差距非常大,作者提出了批量平衡对比损失。另外,本篇论文还提出了一个新的大型遥感影像变化检测数据集 LEVIR-CD:与目前公开的数据集相比,它的数据量更多、分辨率更高,标注更关注建筑生长与建筑衰败两类变化。
问题1:
AssertionError: X and Y should be the same shape
解决办法:
将可视化库 visdom 的版本改为 0.1.8.8(visdom=0.1.8.8);
同时将 scipy 的版本改为 1.1.0(scipy=1.1.0):因为 1.2.0 版本的 scipy 已没有 imread,使用该版本也会报错。
问题2:
RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 6.00 GiB total capacity; 2.98 GiB already allocated; 1.78 GiB free; 2.89 MiB cached)
解决办法:
1、将图片裁剪为 256*256
2、将 batch size 从 8 降为 4
3、在命令行加 --ds 2
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 4 --load_size 256 --crop_size 256 --preprocess rotate_and_crop --ds 2
import os
import os.path as osp
import sys
from multiprocessing import Pool
import numpy as np
import cv2
from PIL import Image
import time
from shutil import get_terminal_size
# Make the parent of this script's directory importable by appending it to sys.path.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
def main():
    """Configure the sub-image extraction job and run it.

    All settings are hard-coded below. ``mode`` selects which folder set is
    used:

    * ``'single'`` - crop every image found under one input folder.
    * ``'pair'``   - crop the images of the folder configured as ``GT_folder``
      (only that one folder is processed by this function).

    Raises:
        ValueError: If ``mode`` is neither ``'single'`` nor ``'pair'``.
    """
    mode = 'pair'  # single (one input folder) | pair (the GT folder configured below)
    opt = {}
    opt['n_thread'] = 20  # number of worker processes handed to multiprocessing.Pool
    opt['compression_level'] = 3  # 3 is the default value in cv2
    # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
    # compression time. If read raw images during training, use 0 for faster IO speed.
    if mode == 'single':
        opt['input_folder'] = r'E:\A'
        # Raw string: '\d' is not a valid escape sequence in a normal string literal.
        opt['save_folder'] = r'E:\dataset_new'
        opt['crop_sz'] = 480  # the size of each sub-image
        opt['step'] = 240  # step of the sliding crop window
        opt['thres_sz'] = 48  # size threshold
        extract_signle(opt)
    elif mode == 'pair':
        GT_folder = r'E:\A\val\B'
        save_GT_folder = r'E:\dataset_new\val\B'
        crop_sz = 256  # the size of each sub-image (GT)
        step = 256  # step of the sliding crop window (GT)
        thres_sz = 256  # size threshold
        # Fail fast on a missing or image-less input folder before any output is produced.
        # The returned list is not needed here; extract_signle() scans the folder itself.
        _get_paths_from_images(GT_folder)
        print('process GT...')
        opt['input_folder'] = GT_folder
        opt['save_folder'] = save_GT_folder
        opt['crop_sz'] = crop_sz
        opt['step'] = step
        opt['thres_sz'] = thres_sz
        extract_signle(opt)
    else:
        raise ValueError('Wrong mode.')
def extract_signle(opt):
    """Crop every image under ``opt['input_folder']`` using a process pool.

    Args:
        opt (dict): Job options. This function reads ``'input_folder'``,
            ``'save_folder'`` and ``'n_thread'``; the whole dict is also
            forwarded unchanged to ``worker`` for every image.

    Exits the process with status 1 when ``opt['save_folder']`` already
    exists. Images whose worker raised are listed on stderr after the pool
    has finished instead of being dropped silently.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']

    if osp.exists(save_folder):
        print('Folder [{:s}] already exists. Exit...'.format(save_folder))
        sys.exit(1)

    # Scan the input before creating the output folder, so that an invalid
    # input does not leave an empty save folder that blocks the next run.
    img_list = _get_paths_from_images(input_folder)

    os.makedirs(save_folder)
    print('mkdir [{:s}] ...'.format(save_folder))

    pbar = ProgressBar(len(img_list))
    failures = []

    def update(arg):
        pbar.update(arg)

    def make_error_handler(img_path):
        # Bind the path per task; error_callback itself only receives the exception.
        def on_error(exc):
            failures.append((img_path, exc))
            pbar.update('Failed {:s} ...'.format(osp.basename(img_path)))
        return on_error

    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(worker, args=(path, opt), callback=update,
                         error_callback=make_error_handler(path))
    pool.close()
    pool.join()
    print('All subprocesses done.')

    # Report failures only now, so the messages do not interleave with the
    # cursor-moving progress bar output.
    for img_path, exc in failures:
        sys.stderr.write('Failed to process [{:s}]: {!r}\n'.format(img_path, exc))
def worker(path, opt):
    """Cut one image into square crops and write every crop to its own file.

    Crops are taken with a sliding window of ``opt['crop_sz']`` pixels moved by
    ``opt['step']`` pixels. Each crop is saved in ``opt['save_folder']`` as
    ``<name>_sNNN<ext>``, where ``NNN`` is the 1-based crop index, so crops of
    the same image no longer overwrite each other.

    Args:
        path (str): Path of the image to read.
        opt (dict): Reads ``'crop_sz'``, ``'step'``, ``'thres_sz'``,
            ``'save_folder'`` and ``'compression_level'``.

    Returns:
        str: A status message (used as the progress bar message by the caller).

    Raises:
        IOError: If the image cannot be read or a crop cannot be written.
        ValueError: If the decoded array is neither 2-D nor 3-D.
    """
    crop_sz = opt['crop_sz']
    step = opt['step']
    thres_sz = opt['thres_sz']
    img_name = osp.basename(path)

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Cannot read image - {}'.format(path))

    n_dims = img.ndim
    if n_dims not in (2, 3):
        raise ValueError('Wrong image shape - {}'.format(n_dims))
    h, w = img.shape[:2]

    if h < crop_sz or w < crop_sz:
        # No window fits; the offset computation would index an empty array.
        return 'Skipping {:s} (smaller than crop size) ...'.format(img_name)

    h_space = _crop_offsets(h, crop_sz, step, thres_sz)
    w_space = _crop_offsets(w, crop_sz, step, thres_sz)

    stem, ext = osp.splitext(img_name)
    write_params = [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]
    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            # The ellipsis keeps a trailing channel axis, if there is one.
            crop_img = np.ascontiguousarray(img[x:x + crop_sz, y:y + crop_sz, ...])
            out_path = osp.join(opt['save_folder'],
                                '{}_s{:03d}{}'.format(stem, index, ext))
            if not cv2.imwrite(out_path, crop_img, write_params):
                raise IOError('Cannot write image - {}'.format(out_path))
    return 'Processing {:s} ...'.format(img_name)


def _crop_offsets(length, crop_sz, step, thres_sz):
    """Return the window start offsets along one axis; needs length >= crop_sz."""
    offsets = np.arange(0, length - crop_sz + 1, step)
    # Add one more window flush with the far edge when the strip left
    # uncovered by the regular windows is larger than thres_sz.
    if length - (offsets[-1] + crop_sz) > thres_sz:
        offsets = np.append(offsets, length - crop_sz)
    return offsets
class ProgressBar(object):
    """Console progress indicator that writes to ``sys.stdout``.

    With a positive ``task_num`` it draws a two-line display (bar line plus
    message line) that is redrawn on every update; otherwise it only prints a
    running counter.

    Adapted from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
    """

    def __init__(self, task_num=0, bar_width=50, start=True):
        """Store the settings and, unless ``start`` is false, begin timing.

        Args:
            task_num: Total number of tasks; a value that is not positive
                selects the counter-only display.
            bar_width: Requested bar width, capped by the terminal-derived limit.
            start: Whether to call ``start()`` immediately.
        """
        self.task_num = task_num
        # Cap the requested width at what the terminal can show.
        self.bar_width = min(bar_width, self._get_max_bar_width())
        self.completed = 0
        if start:
            self.start()

    def _get_max_bar_width(self):
        """Return the widest bar allowed for the current terminal (at least 10)."""
        columns = get_terminal_size()[0]
        limit = min(int(columns * 0.6), columns - 50)
        if limit >= 10:
            return limit
        print('terminal width is too small ({}), please consider widen the terminal for better '
              'progressbar visualization'.format(columns))
        return 10

    def start(self):
        """Print the initial display and record the start time."""
        if self.task_num > 0:
            header = '[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
                ' ' * self.bar_width, self.task_num, 'Start...')
        else:
            header = 'completed: 0, elapsed: 0s'
        sys.stdout.write(header)
        sys.stdout.flush()
        self.start_time = time.time()

    def update(self, msg='In progress...'):
        """Count one finished task and redraw the display.

        Args:
            msg: Text shown on the line below the bar (bar display only).
        """
        self.completed += 1
        # The small epsilon avoids a division by zero right after start().
        elapsed = time.time() - self.start_time + 1e-9
        fps = self.completed / elapsed
        elapsed_s = int(elapsed + 0.5)

        if not self.task_num > 0:
            sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
                self.completed, elapsed_s, fps))
            sys.stdout.flush()
            return

        percentage = self.completed / float(self.task_num)
        eta = int(elapsed * (1 - percentage) / percentage + 0.5)
        filled = int(self.bar_width * percentage)
        bar_chars = '>' * filled + '-' * (self.bar_width - filled)
        sys.stdout.write('\033[2F')  # cursor up 2 lines
        sys.stdout.write('\033[J')  # clean the output (remove extra chars since last display)
        sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
            bar_chars, self.completed, self.task_num, fps, elapsed_s, eta, msg))
        sys.stdout.flush()
# ###################
# ### Data Utils ####
# ###################
# File-name suffixes accepted as images; matching is case-sensitive, so both
# the lower-case and upper-case spelling of each suffix are listed.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of the ``IMG_EXTENSIONS`` suffixes."""
    # str.endswith accepts a tuple of candidate suffixes and tests them in one call.
    return filename.endswith(tuple(IMG_EXTENSIONS))
def _get_paths_from_images(path):
    """Walk ``path`` recursively and return the paths of all image files in it.

    Directories are visited in sorted order and file names are sorted within
    each directory. Every file name encountered is printed. A file counts as
    an image when ``is_image_file`` accepts its name.

    Raises:
        AssertionError: If ``path`` is not a directory or holds no image file.
    """
    assert osp.isdir(path), '{:s} is not a valid directory'.format(path)
    found = []
    for root, _dirs, names in sorted(os.walk(path)):
        for name in sorted(names):
            print("..fname is:", name)
            if not is_image_file(name):
                continue
            found.append(os.path.join(root, name))
    assert found, '{:s} has no valid image file'.format(path)
    return found
# Run the cropping job only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()