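# NOTE: this code is a trap for the unwary. In the original author's code, unless
# the normalization command-line flag (--img_norm / --no-img_norm) is changed from
# its default, the image ends up being normalized twice.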
import os
import torch
import argparse
import numpy as np
import scipy.misc as misc
from ptsemseg.models import get_model
from ptsemseg.loader import get_loader
from ptsemseg.utils import convert_state_dict
# Optional dependency: pydensecrf is only used by the DenseCRF post-processing
# branch of test(), so a failed import is reported instead of aborting the script.
try:
    import pydensecrf.densecrf as dcrf
except ImportError:
    # Only import failures are tolerated here; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit raised while importing.
    print(
        "Failed to import pydensecrf, "
        "CRF post-processing will not work"
    )
def test(args):
    """Run a trained segmentation model on one image and save the colour mask.

    The architecture name is derived from the checkpoint file name (the text
    before the first underscore). The input image is resized, mean-subtracted
    and optionally scaled, pushed through the model, and the per-pixel argmax
    is decoded with the dataset loader's ``decode_segmap`` and written to
    ``args.out_path``. When ``args.dcrf`` is set, a DenseCRF-refined mask is
    additionally written next to it with a ``_drf.png`` suffix.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model_path``, ``dataset``, ``img_path``, ``out_path``,
            ``img_norm`` and ``dcrf``.

    Raises:
        ImportError: If ``args.dcrf`` is set but pydensecrf is not available.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fail fast: without this check a missing pydensecrf only surfaces as a
    # NameError after the whole forward pass has already been computed.
    if args.dcrf and globals().get("dcrf") is None:
        raise ImportError("pydensecrf is required for --dcrf but could not be imported")

    # The checkpoint file name is expected to look like "<arch>_<...>.pkl".
    # partition() keeps the whole name when there is no underscore, whereas
    # slicing with find() == -1 would silently drop the last character.
    model_file_name = os.path.split(args.model_path)[1]
    model_name = model_file_name.partition("_")[0]
    needs_odd_input = model_name in ["pspnet", "icnet", "icnetBN"]

    # Setup image
    print("Read Input Image from : {}".format(args.img_path))
    img = misc.imread(args.img_path)

    data_loader = get_loader(args.dataset)
    loader = data_loader(root=None, is_transform=True, img_norm=args.img_norm, test_mode=True)
    n_classes = loader.n_classes  # number of classes of the selected dataset

    # Copy of the input at the loader's image size; only used as the RGB
    # guide image for the DenseCRF pairwise term below.
    resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp="bicubic")

    # Everything but the last (channel) axis, i.e. (H, W).
    orig_size = img.shape[:-1]
    if needs_odd_input:
        # uint8 with RGB mode, resize width and height which are odd numbers
        img = misc.imresize(img, (orig_size[0] // 2 * 2 + 1, orig_size[1] // 2 * 2 + 1))
    else:
        # All other architectures take the loader's fixed input size.
        img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))

    # Reverse the order of the channel (last) axis.
    img = img[:, :, ::-1]
    img = img.astype(np.float64)
    # NOTE(review): mean subtraction and the optional /255 scaling are both
    # applied here; this presumably has to mirror the loader's training-time
    # transform -- confirm before changing either step.
    img -= loader.mean
    if args.img_norm:
        img = img.astype(float) / 255.0

    # HWC -> CHW, then add a leading batch axis of size 1.
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()

    # Setup Model
    model_dict = {"arch": model_name}
    model = get_model(model_dict, n_classes, version=args.dataset)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host,
    # which the device fallback above explicitly allows for.
    checkpoint = torch.load(args.model_path, map_location=device)
    state = convert_state_dict(checkpoint["model_state"])
    model.load_state_dict(state)
    model.eval()  # evaluation mode for single-image inference
    model.to(device)

    images = img.to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(images)  # indexed as (batch, class, H, W) below

    if args.dcrf:
        unary = outputs.cpu().numpy()
        unary = np.squeeze(unary, 0)
        # NOTE(review): -log() is only meaningful if the model emits
        # probabilities; if it emits raw logits this yields NaNs -- confirm
        # against the model definitions.
        unary = -np.log(unary)
        unary = unary.transpose(2, 1, 0)
        w, h, _ = unary.shape
        # Flatten to (n_classes, w * h) as a contiguous array for the CRF.
        unary = unary.transpose(2, 0, 1).reshape(loader.n_classes, -1)
        unary = np.ascontiguousarray(unary)

        resized_img = np.ascontiguousarray(resized_img)

        d = dcrf.DenseCRF2D(w, h, loader.n_classes)
        d.setUnaryEnergy(unary)
        d.addPairwiseBilateral(sxy=5, srgb=3, rgbim=resized_img, compat=1)

        q = d.inference(50)
        mask = np.argmax(q, axis=0).reshape(w, h).transpose(1, 0)
        decoded_crf = loader.decode_segmap(np.array(mask, dtype=np.uint8))
        # splitext() copes with extensions of any length (or none at all),
        # unlike chopping a fixed four characters off the end.
        dcrf_path = os.path.splitext(args.out_path)[0] + "_drf.png"
        misc.imsave(dcrf_path, decoded_crf)
        print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path))

    # max(1) returns (values, indices) along the class axis; [1] keeps the
    # winning class index per pixel, and squeeze drops the batch axis.
    pred = np.squeeze(outputs.max(1)[1].cpu().numpy(), axis=0)
    if needs_odd_input:
        pred = pred.astype(np.float32)
        # float32 with F mode, resize back to orig_size
        pred = misc.imresize(pred, orig_size, "nearest", mode="F")

    decoded = loader.decode_segmap(pred)  # colour-coded mask image
    print("Classes found: ", np.unique(pred))
    misc.imsave(args.out_path, decoded)
    print("Segmentation Mask Saved at: {}".format(args.out_path))
def build_arg_parser():
    """Build the command-line parser for the single-image test script.

    Returns:
        argparse.ArgumentParser: Parser exposing ``--model_path``,
        ``--dataset``, ``--img_norm`` / ``--no-img_norm``,
        ``--dcrf`` / ``--no-dcrf``, ``--img_path`` and ``--out_path``.
    """
    parser = argparse.ArgumentParser(description="Params")
    parser.add_argument(
        "--model_path",
        nargs="?",
        type=str,
        default="fcn8s_pascal_1_26.pkl",
        help="Path to the saved model",
    )
    parser.add_argument(
        "--dataset",
        nargs="?",
        type=str,
        default="pascal",
        help="Dataset to use ['pascal, camvid, ade20k etc']",
    )

    # Paired on/off switches write to the same destination; the default is
    # set once after both flags are registered.
    parser.add_argument(
        "--img_norm",
        dest="img_norm",
        action="store_true",
        help="Enable input image scales normalization [0, 1] "
        "| True by default",
    )
    parser.add_argument(
        "--no-img_norm",
        dest="img_norm",
        action="store_false",
        help="Disable input image scales normalization [0, 1] "
        "| True by default",
    )
    parser.set_defaults(img_norm=True)

    parser.add_argument(
        "--dcrf",
        dest="dcrf",
        action="store_true",
        help="Enable DenseCRF based post-processing "
        "| False by default",
    )
    parser.add_argument(
        "--no-dcrf",
        dest="dcrf",
        action="store_false",
        help="Disable DenseCRF based post-processing "
        "| False by default",
    )
    parser.set_defaults(dcrf=False)

    parser.add_argument(
        "--img_path", nargs="?", type=str, default=None, help="Path of the input image"
    )
    parser.add_argument(
        "--out_path", nargs="?", type=str, default=None, help="Path of the output segmap"
    )
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()
    # Both paths default to None; without this check the script would only
    # fail with an opaque traceback (for --out_path, after the full inference).
    if args.img_path is None or args.out_path is None:
        parser.error("--img_path and --out_path must both be provided")
    test(args)