第一章deeplabv3+源码之慢慢解析 根目录(1)main.py–get_argparser函数
第一章deeplabv3+源码之慢慢解析 根目录(2)main.py–get_dataset函数
第一章deeplabv3+源码之慢慢解析 根目录(3)main.py–validate函数
第一章deeplabv3+源码之慢慢解析 根目录(4)main.py–main函数
第一章deeplabv3+源码之慢慢解析 根目录(5)predict.py–get_argparser函数和main函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(1)voc.py–voc_cmap函数和download_extract函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(2)voc.py–VOCSegmentation类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(3)cityscapes.py–Cityscapes类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(4)utils.py–6个小函数
第三章deeplabv3+源码之慢慢解析 metrics文件夹stream_metrics.py–StreamSegMetrics类和AverageMeter类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a1)hrnetv2.py–4个函数和可执行代码
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a2)hrnetv2.py–Bottleneck类和BasicBlock类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a3)hrnetv2.py–StageModule类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a4)hrnetv2.py–HRNet类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b1)mobilenetv2.py–2个类和2个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b2)mobilenetv2.py–MobileNetV2类和mobilenet_v2函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c1)resnet.py–2个基础函数,BasicBlock类和Bottleneck类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c2)resnet.py–ResNet类和10个不同结构的调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d1)xception.py–SeparableConv2d类和Block类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d2)xception.py–Xception类和xception函数
第四章deeplabv3+源码之慢慢解析 network文件夹(2)_deeplab.py–ASPP相关的4个类和1个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(3)_deeplab.py–DeepLabV3类,DeepLabHeadV3Plus类和DeepLabHead类
第四章deeplabv3+源码之慢慢解析 network文件夹(4)modeling.py–5个私有函数(4个骨干网,1个模型载入)
第四章deeplabv3+源码之慢慢解析 network文件夹(5)modeling.py–12个调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(6)utils.py–_SimpleSegmentationModel类和IntermediateLayerGetter类
第五章deeplabv3+源码之慢慢解析 utils文件夹(1)ext_transforms.py.py–[17个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(2)loss.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(3)scheduler.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(4)utils.py–[1个类,4个函数]
第五章deeplabv3+源码之慢慢解析 utils文件夹(5)visualizer.py–[1个类]
总结
from torch.utils.data import dataset #因直接预测,需要使用数据部分
#以下同main.py
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np
#以下是数据部分所需
from torch.utils import data
from datasets import VOCSegmentation, Cityscapes, cityscapes
from torchvision import transforms as T
from metrics import StreamSegMetrics
#以下是神经网络所需
import torch
import torch.nn as nn
#以下是可视化和图片操作所需
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from glob import glob
提示:看过main.py部分的get_argparser函数,再看下面的内容。是不是有种似曾相识的感觉?代码学习要多积累,切记切记。
def get_argparser():
parser = argparse.ArgumentParser()
# Datset Options
#input参数,此处输入预测文件夹的路径,建议即使是单一预测图像也放在文件夹中,养成良好的路径管理习惯。如本代码测试数据选用了samples文件夹下的图像,结果图像放在自建的test_results文件夹。
parser.add_argument("--input", type=str, required=True,
help="path to a single image or image directory")
parser.add_argument("--dataset", type=str, default='voc',
choices=['voc', 'cityscapes'], help='Name of training set')#同main.py的get_argparser函数第10行。
# Deeplab Options
available_models = sorted(name for name in network.modeling.__dict__ if name.islower() and \
not (name.startswith("__") or name.startswith('_')) and callable(
network.modeling.__dict__[name])
)#同main.py的get_argparser函数第19行。
parser.add_argument("--model", type=str, default='deeplabv3plus_mobilenet',
choices=available_models, help='model name')#同main.py的get_argparser函数第26行。
parser.add_argument("--separable_conv", action='store_true', default=False,
help="apply separable conv to decoder and aspp")#同main.py的get_argparser函数第30行。
parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])#同main.py的get_argparser函数第32行。
# Train Options
parser.add_argument("--save_val_results_to", default=None,
help="save segmentation results to the specified dir")#此处开启了验证保存。这里输入的是保存结果的目录。
parser.add_argument("--crop_val", action='store_true', default=False,
help='crop validation (default: False)')#同main.py的get_argparser函数第50行。
parser.add_argument("--val_batch_size", type=int, default=4,
help='batch size for validation (default: 4)')#同main.py的get_argparser函数第54行。
parser.add_argument("--crop_size", type=int, default=513)#同main.py的get_argparser函数第57行。
parser.add_argument("--ckpt", default=None, type=str,
help="resume from checkpoint")#同main.py的get_argparser函数第58行。
parser.add_argument("--gpu_id", type=str, default='0',
help="GPU ID")#同main.py的get_argparser函数第65行。
return parser
提示:同样,对照main.py部分的main函数,更容易理解此处的代码。
def main():
opts = get_argparser().parse_args() #同main.py
if opts.dataset.lower() == 'voc': #数据集选择
opts.num_classes = 21
decode_fn = VOCSegmentation.decode_target #使用解码后的数据
elif opts.dataset.lower() == 'cityscapes':
opts.num_classes = 19
decode_fn = Cityscapes.decode_target
os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id #GPU选择
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: %s" % device)
# Setup dataloader
image_files = []
if os.path.isdir(opts.input): #从get_argparser函数中的input参数获得路径,逐一添加图像
for ext in ['png', 'jpeg', 'jpg', 'JPEG']:
files = glob(os.path.join(opts.input, '**/*.%s'%(ext)), recursive=True)
if len(files)>0:
image_files.extend(files)
elif os.path.isfile(opts.input):
image_files.append(opts.input)
# Set up model (all models are 'constructed at network.modeling)
model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride) #同main.py中main函数第37行。
if opts.separable_conv and 'plus' in opts.model: #同main.py中main函数第38行。
network.convert_to_separable_conv(model.classifier) #同main.py中main函数第39行。
utils.set_bn_momentum(model.backbone, momentum=0.01) #同main.py中main函数第40行。
if opts.ckpt is not None and os.path.isfile(opts.ckpt): #以下同main.py中main函数第81-86行。(是81-98行的简化)
# https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["model_state"])
model = nn.DataParallel(model)
model.to(device)
print("Resume model from %s" % opts.ckpt)
del checkpoint
else:
print("[!] Retrain")
model = nn.DataParallel(model)
model.to(device)
#denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # denormalization for ori images
if opts.crop_val: #裁剪验证,对数据进行尺寸变化
transform = T.Compose([
T.Resize(opts.crop_size),
T.CenterCrop(opts.crop_size),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
else:
transform = T.Compose([
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
if opts.save_val_results_to is not None: #新建验证结果保存文件夹
os.makedirs(opts.save_val_results_to, exist_ok=True)
with torch.no_grad():
model = model.eval() #预测,不训练,关闭梯度更新。此部分的理解,可以参考main.py中validate函数20-28行。
for img_path in tqdm(image_files):
ext = os.path.basename(img_path).split('.')[-1]
img_name = os.path.basename(img_path)[:-len(ext)-1]
img = Image.open(img_path).convert('RGB')
img = transform(img).unsqueeze(0) # To tensor of NCHW
#对上一句的补充,unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度,比如原本有个三行的数据(3),unsqueeze(0)后就会在0的位置加了一维就变成一行三列(1,3)。
img = img.to(device)
pred = model(img).max(1)[1].cpu().numpy()[0] # HW #可以参考main.py中validate函数27行。
colorized_preds = decode_fn(pred).astype('uint8') #使用上一句得到的图像索引(pred),在解码目标(本main函数第5行)中得到对应的图像。
colorized_preds = Image.fromarray(colorized_preds)
if opts.save_val_results_to:
colorized_preds.save(os.path.join(opts.save_val_results_to, img_name+'.png')) #将验证结果(即得到的解码图像)保存到前面指定的文件夹。
Tips
解析参数函数get_argparser函数,对比main.py中的get_argparser函数,每一条都标出了注释,可逐一对理解,不同的也做出了解释,相对简单。
本文的main函数中可借鉴main.py中的validate和main函数进行学习,相对容易理解。
根目录下的两个代码已解析完毕。按显示的顺序,下一个介绍datasets文件夹。