关于[::-1],[:, ::-1],[:, :, ::-1]的区别:
from PIL import Image
import imageio
import numpy as np
# Build a 4x4x3 array holding 0..47 so every element is uniquely identifiable.
img = np.arange(48).reshape(4, 4, 3)

# Reverse exactly one axis per view and print each result under its label.
for label, reversed_view in (
    ("no.1", img[::-1]),  # axis 0 reversed
    ("no.2", img[:, ::-1]),  # axis 1 reversed
    ("no.3", img[:, :, ::-1]),  # axis 2 reversed
):
    print(label, reversed_view)
输出长这样:
no.1 [[[36 37 38]
[39 40 41]
[42 43 44]
[45 46 47]]
[[24 25 26]
[27 28 29]
[30 31 32]
[33 34 35]]
[[12 13 14]
[15 16 17]
[18 19 20]
[21 22 23]]
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]]
no.2 [[[ 9 10 11]
[ 6 7 8]
[ 3 4 5]
[ 0 1 2]]
[[21 22 23]
[18 19 20]
[15 16 17]
[12 13 14]]
[[33 34 35]
[30 31 32]
[27 28 29]
[24 25 26]]
[[45 46 47]
[42 43 44]
[39 40 41]
[36 37 38]]]
no.3 [[[ 2 1 0]
[ 5 4 3]
[ 8 7 6]
[11 10 9]]
[[14 13 12]
[17 16 15]
[20 19 18]
[23 22 21]]
[[26 25 24]
[29 28 27]
[32 31 30]
[35 34 33]]
[[38 37 36]
[41 40 39]
[44 43 42]
[47 46 45]]]
可以看出来[::-1]指的是将4*4*3中第0维即第一个4逆序读取,体现在矩阵里就是每一个4(第二个4)*3小矩阵的整体顺序被倒置了,即第一个4*3矩阵被排到了最后;
[:, ::-1]指的是将第1维(即第二个4)逆序读取,体现在矩阵里就是每一个4*3小矩阵内部4行的顺序被倒置了(每行里3个数的顺序不变),而各个4*3矩阵之间的先后顺序没有变化;
[:, :, ::-1]指的是将第2维(即3)逆序读取,体现在矩阵里就是每一行里3个数的顺序被倒置了,而每个4*3矩阵内部各行的顺序以及各个矩阵之间的先后顺序都保持不变。
import os
import torch
import numpy as np
import scipy.misc as m
from torch.utils import data
from ptsemseg.utils import recursive_glob
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale
class cityscapesLoader(data.Dataset):
    """Dataset that yields Cityscapes images with 19-class label maps.

    https://www.cityscapes-dataset.com

    Data is derived from CityScapes, and can be downloaded from here:
    https://www.cityscapes-dataset.com/downloads/

    Many Thanks to @fvisin for the loader repo:
    https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py

    Images are looked up under ``<root>/leftImg8bit/<split>`` and label maps
    under ``<root>/gtFine/<split>``. Raw label ids listed in
    ``valid_classes`` are remapped to 0..18; ids listed in ``void_classes``
    are replaced by ``ignore_index``.
    """

    # One RGB colour per train id, in train-id order (used by decode_segmap).
    colors = [  # [  0,   0,   0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        # NOTE(review): the official Cityscapes palette lists sky as
        # (70, 130, 180) -- confirm whether the leading 0 is intentional.
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32],
    ]

    # Train id (0..18) -> display colour.
    label_colours = dict(zip(range(19), colors))

    mean_rgb = {
        "pascal": [103.939, 116.779, 123.68],
        "cityscapes": [0.0, 0.0, 0.0],
    }  # pascal mean for PSPNet and ICNet pre-trained model

    def __init__(
        self,
        root,
        split="train",
        is_transform=False,
        img_size=(512, 1024),
        augmentations=None,
        img_norm=True,
        version="cityscapes",
        test_mode=False,
    ):
        """Index the image files of one split and set up the class mapping.

        :param root: dataset root containing ``leftImg8bit`` and ``gtFine``.
        :param split: sub-folder to use (e.g. train / val / test).
        :param is_transform: if true, ``__getitem__`` applies ``transform``.
        :param img_size: ``(height, width)`` as a tuple or list, or a single
            number used for both dimensions.
        :param augmentations: optional callable ``(img, lbl) -> (img, lbl)``
            applied before ``transform``.
        :param img_norm: if true, ``transform`` divides the image by 255.
        :param version: key into ``mean_rgb`` selecting the mean to subtract.
        :param test_mode: accepted for interface compatibility; not used.
        :raises Exception: if no ``.png`` file is found for the split.
        """
        self.root = root
        self.split = split
        self.is_transform = is_transform
        self.augmentations = augmentations
        self.img_norm = img_norm
        self.n_classes = 19
        # Accept lists as well as tuples; previously a list was wrapped as
        # ([h, w], [h, w]) and broke the later resize.
        if isinstance(img_size, (tuple, list)):
            self.img_size = tuple(img_size)
        else:
            self.img_size = (img_size, img_size)
        self.mean = np.array(self.mean_rgb[version])
        self.files = {}

        # Folder holding the input images of this split.
        self.images_base = os.path.join(self.root, "leftImg8bit", self.split)
        # Folder holding the ground-truth annotations of this split.
        self.annotations_base = os.path.join(self.root, "gtFine", self.split)

        # Recursively collect every .png image of the split.
        self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png")

        # Raw label ids that are ignored.
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # Raw label ids that are kept, in train-id order.
        self.valid_classes = [
            7,
            8,
            11,
            12,
            13,
            17,
            19,
            20,
            21,
            22,
            23,
            24,
            25,
            26,
            27,
            28,
            31,
            32,
            33,
        ]
        self.class_names = [
            "unlabelled",
            "road",
            "sidewalk",
            "building",
            "wall",
            "fence",
            "pole",
            "traffic_light",
            "traffic_sign",
            "vegetation",
            "terrain",
            "sky",
            "person",
            "rider",
            "car",
            "truck",
            "bus",
            "train",
            "motorcycle",
            "bicycle",
        ]

        self.ignore_index = 250
        # Raw valid id -> train id (0..18).
        self.class_map = dict(zip(self.valid_classes, range(19)))

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        """Return the number of images found for the current split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one image and its encoded label map.

        :param index: position in the file list of the current split.
        :return: ``(img, lbl)`` after the optional augmentations and the
            optional ``transform``.
        """
        img_path = self.files[self.split][index].rstrip()
        # The label lives in the same sub-folder name under gtFine; dropping
        # the last 15 characters removes a trailing "leftImg8bit.png"
        # (assumes image file names end with that suffix).
        lbl_path = os.path.join(
            self.annotations_base,
            img_path.split(os.sep)[-2],
            os.path.basename(img_path)[:-15] + "gtFine_labelIds.png",
        )

        img = np.array(self._read_image(img_path), dtype=np.uint8)  # H x W x C

        lbl = self.encode_segmap(np.array(self._read_image(lbl_path), dtype=np.uint8))

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return img, lbl

    @staticmethod
    def _read_image(path):
        """Read an image file into a numpy array using Pillow.

        Stands in for ``scipy.misc.imread``, which was removed from SciPy;
        like that function, palette images are expanded to RGB(A) and 1-bit
        images to 8-bit grey before conversion.
        """
        from PIL import Image  # local import: only needed when reading files

        with Image.open(path) as im:
            if im.mode == "P":
                im = im.convert("RGBA" if "transparency" in im.info else "RGB")
            elif im.mode == "1":
                im = im.convert("L")
            return np.array(im)

    @staticmethod
    def _resize(arr, size, interp):
        """Resize ``arr`` to ``size = (height, width)`` using Pillow.

        Stands in for ``scipy.misc.imresize``, which was removed from SciPy.
        ``interp`` is ``"nearest"`` or ``"bilinear"``. ``arr`` must already
        have a dtype Pillow can wrap (uint8 image or float32 map).
        """
        from PIL import Image  # local import: only needed when resizing

        resample = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR}[interp]
        # Pillow expects (width, height).
        resized = Image.fromarray(arr).resize((size[1], size[0]), resample=resample)
        return np.array(resized)

    def transform(self, img, lbl):
        """Resize, normalise and convert an image / label pair to tensors.

        :param img: H x W x 3 uint8 RGB image.
        :param lbl: H x W label map already encoded by ``encode_segmap``.
        :return: ``(img, lbl)`` as a float tensor in C x H x W layout with
            channels in BGR order, and a long tensor of size ``img_size``.
        :raises ValueError: if the resized label map holds a value that is
            neither ``ignore_index`` nor below ``n_classes``.
        """
        target_size = (self.img_size[0], self.img_size[1])

        img = self._resize(img, target_size, "bilinear")  # uint8 with RGB mode
        img = img[:, :, ::-1]  # RGB -> BGR: reverse the channel axis
        img = img.astype(np.float64)
        img -= self.mean
        if self.img_norm:
            # Resize scales images from 0 to 255, thus we need
            # to divide by 255.0
            img = img.astype(float) / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        # Resize the label map to the image size with nearest-neighbour
        # sampling so that no new label values are interpolated.
        lbl = self._resize(lbl.astype(np.float32), target_size, "nearest")
        lbl = lbl.astype(int)

        # array_equal copes with a different number of classes before and
        # after resizing; `==` on arrays of different length raises on recent
        # NumPy instead of evaluating to False.
        if not np.array_equal(classes, np.unique(lbl)):
            print("WARN: resizing labels yielded fewer classes")

        if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes):
            print("after det", classes, np.unique(lbl))
            raise ValueError("Segmentation map contained invalid class values")

        # torch.from_numpy shares memory with the array; .float() / .long()
        # then convert to the dtypes returned to the caller.
        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()

        return img, lbl

    def decode_segmap(self, temp):
        """Colour a label map for display.

        :param temp: H x W array of train ids.
        :return: H x W x 3 float array with channels scaled by 1/255. Pixels
            whose value is not in ``0..n_classes-1`` keep that value (divided
            by 255) in all three channels.
        """
        r = temp.copy()
        g = temp.copy()
        b = temp.copy()
        for class_id in range(0, self.n_classes):
            selected = temp == class_id
            r[selected] = self.label_colours[class_id][0]
            g[selected] = self.label_colours[class_id][1]
            b[selected] = self.label_colours[class_id][2]

        # r, g and b are H x W planes holding the per-pixel colour component.
        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        rgb[:, :, 0] = r / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = b / 255.0
        return rgb

    def encode_segmap(self, mask):
        """Remap raw label ids to train ids, modifying ``mask`` in place.

        :param mask: array of raw label ids.
        :return: the same array with void ids set to ``ignore_index`` and
            valid ids replaced by their train id; other values are untouched.
        """
        # Mark every void class as ignored.
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        # valid_classes is ascending and each train id is smaller than its
        # raw id, so a value written here is never remapped again later.
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask
if __name__ == "__main__":
    # Manual visual check: show each batch as rows of (image, coloured label
    # map). Type "ex" at the prompt to stop; anything else shows the next batch.
    import matplotlib.pyplot as plt

    augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5)])

    local_path = "/datasets01/cityscapes/112817/"
    dst = cityscapesLoader(local_path, is_transform=True, augmentations=augmentations)
    bs = 4
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    # Each sample is an (image, label map) pair, so a batch unpacks the same way.
    for imgs, labels in trainloader:
        imgs = imgs.numpy()[:, ::-1, :, :]  # NCHW: reverse channels, BGR -> RGB
        imgs = np.transpose(imgs, [0, 2, 3, 1])  # NCHW -> NHWC for imshow
        label_maps = labels.numpy()
        # The last batch may hold fewer than `bs` samples, so size the grid by
        # the actual batch; squeeze=False keeps axarr 2-D even for one row.
        n_shown = imgs.shape[0]
        _, axarr = plt.subplots(n_shown, 2, squeeze=False)
        for j in range(n_shown):
            axarr[j][0].imshow(imgs[j])
            axarr[j][1].imshow(dst.decode_segmap(label_maps[j]))
        plt.show()
        a = input()
        if a == "ex":
            break
        else:
            plt.close()