STANet 基于时空自注意力的遥感图像变化检测模型

STANet 基于时空自注意力的遥感图像变化检测模型

检测数据集LEVIR-CD

环境配置

//要求:
windows or Linux
Python 3.6+
CPU or NVIDIA GPU
CUDA 9.0+
PyTorch > 1.0
visdom==0.1.8.1
dominate
  • 我们采用Linux + cuda11.0 + python=3.8,相关安装命令如下:
//创建Python虚拟环境
conda create -n STANet_3.8 python==3.8    

//激活环境
conda activate STANet_3.8
    
//安装pytorch环境
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
//查看环境安装情况
pip install ipython
#(STANet_3.8) :~$ ipython
#Python 3.8.0 (default, Nov  6 2019, 21:49:08) 
#Type 'copyright', 'credits' or 'license' for more information
#IPython 7.21.0 -- An enhanced Interactive Python. Type '?' for help.
    
#In [1]:  import torch
    
#In [2]: torch.cuda.is_available()
#Out[2]: True
//出现以上内容表明环境已经安装好啦。    
    
//安装visdom
pip install visdom==0.1.8.1 
//安装dominate
pip install dominate

我们的环境应该是配置好啦,接着下载数据集让我们准备训练吧!

代码数据集下载

代码下载链接:https://github.com/justchenhao/STANet

数据集下载:https://justchenhao.github.io/LEVIR/

网络结构图:

img

baseline运行代码

#一共200个epoch。我们对前100个epoch保持相同的学习速率,并在剩余的100个epoch中将其线性衰减为0。
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDF0 --lr 0.001 --model CDF0 --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
#基础训练 val.py内容
if __name__ == '__main__':
    opt = TestOptions().parse()   # get training options
    opt = make_val_opt(opt)
    opt.phase = 'val'
    opt.dataroot = './LEVIR-CD/test'
    opt.dataset_mode = 'changedetection'
    opt.n_class = 2

    #opt.SA_mode = 'PAM'
    opt.arch = 'mynet3'
    opt.model = 'CDF0'
    opt.name = 'LEVIR-CDF0'
    opt.results_dir = './results/'
    opt.epoch = '160_F1_1_0.77180'  //训练后的权重文件
    opt.num_test = np.inf

BAM运行代码

python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop

PAM运行代码

#增加金字塔时空注意力模块PAM
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFAp0 --lr 0.001 --model CDFA --SA_mode PAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop

测试

# 按照下面的例子修改val.py即可
if __name__ == '__main__':
    opt = TestOptions().parse()   # get training options
    opt = make_val_opt(opt)
    opt.phase = 'test'
    opt.dataroot = 'path-to-LEVIR-CD-test' # data root 
    opt.dataset_mode = 'changedetection'
    opt.n_class = 2
    opt.SA_mode = 'PAM' # BAM | PAM 
    opt.arch = 'mynet3'
    opt.model = 'CDFA' # model type
    opt.name = 'LEVIR-CDFAp0' # project name
    opt.results_dir = './results/' # save predicted images 
    opt.epoch = 'best-epoch-in-val' # which epoch to test
    opt.num_test = np.inf
    val(opt)

你可能感兴趣的:(环境配置,pytorch,深度学习,python)