因有一部分实验用到STANet网络,在网上找到相应的代码,花了大概一周一步步跳入坑、填坑的过程,苦于将其跑通,遂记录如下心得,希望能够帮助有需要的小伙伴避开“雷区”!
文章源于:
代码源于:
大致看了一下该篇文章,网上已有很多解读的博客,这里不做过多介绍。简而言之,该文章利用自注意力机制模块(BAM,以及由多个BAM集成的PAM模块)对遥感影像进行特征提取与训练:通过对比两张不同时期的遥感图像,以深度学习的方法训练模型,最终能够“自动比对”找出同一区域在不同时间的变化情况。下图是STANet文章的截图。
文章能够显著检测出遥感影像中变化的建筑物,可以应用于违章建筑拓展监测、乡村扶贫振兴和生态移民居住保障的风貌变化程度。
相关的数据集包括:train(训练集)、val(每训练一轮epoch之后紧接着进行验证的验证集),以及test(训练结束之后,用保存的model进行测试的测试集)。(PS:文章代码的测试部分称为val,python val.py 就是测试,而不是验证。)
每一个数据集中包括:
-----------| A:前一段时间的遥感图像(1024 * 1024);
-----------| B:后一段时间的相同区域的遥感图像(1024*1024) ;
-----------| label:标注好的两幅遥感图像之间存在的变化;因为数据中只考虑一个类别(建筑物)的变化情况,所以以二值图形式(黑白)进行展示(1024*1024)。
!命名一定要一致!
不然可能在测试的时候,会出现:AssertionError: X and Y should be the same shape
import os
import os.path as osp
import sys
from multiprocessing import Pool
import numpy as np
import cv2
from PIL import Image
import time
from shutil import get_terminal_size
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
def main():
    """Configure one cropping run and hand it to ``extract_signle``.

    The run is configured by editing the values in this function. ``mode``
    selects one of two presets:

    * ``'single'`` - DIV2K example folders, 480-pixel tiles with a step of 240.
    * ``'pair'`` - the LEVIR-CD ``test/B`` folder, 256-pixel tiles with a step
      of 256 (no overlap). Despite its name, this preset processes only the
      one folder configured below.

    Raises:
        ValueError: if ``mode`` is neither ``'single'`` nor ``'pair'``.
    """
    mode = 'pair'  # single (one input folder) | pair (extract corresponding GT and LR pairs)

    # Options shared by both presets.
    # compression_level is used as CV_IMWRITE_PNG_COMPRESSION (0 to 9): a higher
    # value means a smaller size and a longer compression time; 3 is the default
    # value in cv2. If raw images are read during training, use 0 for faster IO.
    opt = {
        'n_thread': 20,
        'compression_level': 3,
    }

    if mode == 'single':
        opt.update({
            'input_folder': './data/DIV2K/DIV2K_train_HR',
            'save_folder': './data/DIV2K/DIV2K800_sub',
            'crop_sz': 480,  # the size of each sub-image
            'step': 240,  # step of the sliding crop window
            'thres_sz': 48,  # size threshold
        })
        extract_signle(opt)
        return

    if mode != 'pair':
        raise ValueError('Wrong mode.')

    source_dir = '/home/cug210/data/Lover/code/STANet-master/LEVIR-CD/test/B'
    target_dir = '/home/cug210/data/Lover/code/STANet-master/LEVIR-CD/test_256/B'

    # The returned list is not needed here; the call is kept because it checks
    # the input folder (and lists its files) before any output folder is created.
    _get_paths_from_images(source_dir)
    print('process GT...')

    opt.update({
        'input_folder': source_dir,
        'save_folder': target_dir,
        'crop_sz': 256,  # the size of each sub-image (GT)
        'step': 256,  # step of the sliding crop window (GT)
        'thres_sz': 256,  # size threshold
    })
    extract_signle(opt)
def extract_signle(opt):
    """Crop every image under ``opt['input_folder']`` with a process pool.

    Args:
        opt (dict): run options. ``input_folder``, ``save_folder`` and
            ``n_thread`` are read here; the whole dict is also forwarded to
            ``worker`` for each image.

    Side effects:
        Creates ``save_folder``. If ``save_folder`` already exists, a message
        is printed and the process exits with status 1.

    A failure inside ``worker`` does not abort the run: it is recorded, shown
    in the progress bar, and listed once all tasks have finished.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']

    if osp.exists(save_folder):
        print('Folder [{:s}] already exists. Exit...'.format(save_folder))
        sys.exit(1)
    os.makedirs(save_folder)
    print('mkdir [{:s}] ...'.format(save_folder))

    img_list = _get_paths_from_images(input_folder)
    pbar = ProgressBar(len(img_list))
    failures = []

    def update(arg):
        pbar.update(arg)

    def on_error(path, exc):
        # Without an error callback, apply_async drops worker exceptions
        # silently and the progress bar never reaches its total.
        failures.append((path, exc))
        pbar.update('Failed {:s} ...'.format(osp.basename(path)))

    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(
            worker,
            args=(path, opt),
            callback=update,
            # Bind ``path`` as a default so each callback reports its own image.
            error_callback=lambda exc, path=path: on_error(path, exc))
    pool.close()
    pool.join()

    print('All subprocesses done.')
    for path, exc in failures:
        print('Failed to process [{:s}]: {!r}'.format(path, exc))
def worker(path, opt):
    """Cut one image into square tiles and write each tile as a PNG file.

    Tiles are taken with a sliding window of ``opt['crop_sz']`` pixels moved
    by ``opt['step']`` pixels along both axes. Each tile is saved in
    ``opt['save_folder']`` as ``<image stem>_sNNN.png``, where ``NNN`` is a
    1-based counter running row by row.

    Args:
        path (str): path of the image to read.
        opt (dict): options; reads ``crop_sz``, ``step``, ``thres_sz``,
            ``save_folder`` and ``compression_level``.

    Returns:
        str: a ``'Processing <name> ...'`` message for the progress display.

    Raises:
        IOError: if the image cannot be read.
        ValueError: if the decoded array is neither 2-D nor 3-D, or if the
            image is smaller than ``crop_sz`` along either axis.
    """
    crop_sz = opt['crop_sz']
    step = opt['step']
    thres_sz = opt['thres_sz']
    img_name = osp.basename(path)

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Failed to read image - {}'.format(path))

    n_dims = len(img.shape)
    if n_dims not in (2, 3):
        raise ValueError('Wrong image shape - {}'.format(n_dims))
    h, w = img.shape[:2]

    h_space = _crop_offsets(h, crop_sz, step, thres_sz)
    w_space = _crop_offsets(w, crop_sz, step, thres_sz)
    if h_space.size == 0 or w_space.size == 0:
        raise ValueError('Image {} ({}x{}) is smaller than the crop size {}'.format(
            img_name, h, w, crop_sz))

    # Build the name from the stem: replacing the literal '.png' left names
    # with any other extension unchanged, so every tile overwrote one file.
    stem = osp.splitext(img_name)[0]

    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            # Slicing the first two axes works for 2-D and 3-D arrays alike.
            crop_img = np.ascontiguousarray(img[x:x + crop_sz, y:y + crop_sz])
            cv2.imwrite(
                osp.join(opt['save_folder'], '{}_s{:03d}.png'.format(stem, index)),
                crop_img,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    return 'Processing {:s} ...'.format(img_name)


def _crop_offsets(length, crop_sz, step, thres_sz):
    """Return the window start offsets along one axis of ``length`` pixels.

    Regular offsets are ``0, step, 2 * step, ...`` while a full window fits.
    If the strip left after the last window is larger than ``thres_sz``, one
    extra offset aligned to the end of the axis is appended. The result is
    empty when ``length`` is smaller than ``crop_sz``.
    """
    offsets = np.arange(0, length - crop_sz + 1, step)
    if offsets.size and length - (offsets[-1] + crop_sz) > thres_sz:
        offsets = np.append(offsets, length - crop_sz)
    return offsets
class ProgressBar(object):
    '''Console progress bar written to ``sys.stdout``.

    modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py

    With ``task_num > 0`` a bar with completed/total, rate, elapsed time and
    ETA is redrawn on every ``update``; otherwise only a running counter with
    elapsed time and rate is written.
    '''

    def __init__(self, task_num=0, bar_width=50, start=True):
        self.task_num = task_num
        # The requested width is capped by what the terminal allows.
        self.bar_width = min(bar_width, self._get_max_bar_width())
        self.completed = 0
        if start:
            self.start()

    def _get_max_bar_width(self):
        '''Return the widest bar allowed for the current terminal (at least 10).'''
        columns = get_terminal_size()[0]
        limit = min(int(columns * 0.6), columns - 50)
        if limit >= 10:
            return limit
        print('terminal width is too small ({}), please consider widen the terminal for better '
              'progressbar visualization'.format(columns))
        return 10

    def start(self):
        '''Write the initial display and record the start time.'''
        if self.task_num > 0:
            banner = '[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
                ' ' * self.bar_width, self.task_num, 'Start...')
        else:
            banner = 'completed: 0, elapsed: 0s'
        sys.stdout.write(banner)
        sys.stdout.flush()
        self.start_time = time.time()

    def update(self, msg='In progress...'):
        '''Count one more finished task and redraw the display.

        ``msg`` is shown on the line below the bar; it is not used when
        ``task_num`` is not positive.
        '''
        self.completed += 1
        # The small constant keeps the rate computation away from a zero division.
        elapsed = time.time() - self.start_time + 1e-9
        fps = self.completed / elapsed
        elapsed_s = int(elapsed + 0.5)

        if self.task_num <= 0:
            sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
                self.completed, elapsed_s, fps))
            sys.stdout.flush()
            return

        percentage = self.completed / float(self.task_num)
        eta = int(elapsed * (1 - percentage) / percentage + 0.5)
        filled = int(self.bar_width * percentage)
        bar_chars = '>' * filled + '-' * (self.bar_width - filled)
        status = '[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
            bar_chars, self.completed, self.task_num, fps, elapsed_s, eta, msg)

        sys.stdout.write('\033[2F')  # cursor up 2 lines
        sys.stdout.write('\033[J')  # clean the output (remove extra chars since last display)
        sys.stdout.write(status)
        sys.stdout.flush()
# ###################
# ### Data Utils ####
# ###################
# File-name suffixes treated as images. Matching is case-sensitive, which is
# why each extension is listed in lower and upper case.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
def _get_paths_from_images(path):
    """Get the image path list from an image folder, searched recursively.

    Every file name encountered is printed. Directories are visited in sorted
    order and the file names inside each directory are sorted as well.

    Args:
        path (str): folder to walk.

    Returns:
        list[str]: paths of the files accepted by ``is_image_file``.

    Raises:
        AssertionError: if ``path`` is not a directory, or if no image file
            is found under it.
    """
    # The checks raise AssertionError explicitly instead of using ``assert``:
    # the exception type and messages stay the same, but the validation is no
    # longer stripped when Python runs with -O.
    if not osp.isdir(path):
        raise AssertionError('{:s} is not a valid directory'.format(path))

    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            print("..fname is:", fname)
            if is_image_file(fname):
                images.append(os.path.join(dirpath, fname))

    if not images:
        raise AssertionError('{:s} has no valid image file'.format(path))
    return images
# Start the cropping job only when this file is executed as a script. The
# guard also matters for the multiprocessing pool: on platforms where child
# processes re-import this module, it keeps them from calling main() again.
if __name__ == '__main__':
    main()