下面是全部的代码:
import os
import torch
import numpy as np
import scipy.misc as m
from PIL import Image
from torch.utils import data
from dataloaders.utils import recursive_glob, decode_segmap
from mypath import Path
class CityscapesSegmentation(data.Dataset):
    """Cityscapes semantic-segmentation dataset.

    Images are looked up under ``<root>/leftImg8bit/<split>`` and the matching
    ``*gtFine_labelIds.png`` annotations under ``<root>/gtFine/<split>``.
    Each item is a dict ``{'image': PIL image, 'label': PIL image}`` (or
    whatever ``transform`` returns for that dict). In the label, the 19 ids in
    ``valid_classes`` are re-numbered 0..18 and every id in ``void_classes``
    becomes ``ignore_index`` (255).
    """

    def __init__(self, root=Path.db_root_dir('cityscapes'), split="train", transform=None):
        """Index the images of one split.

        Args:
            root: dataset root directory.
            split: sub-directory name of the split to load.
            transform: optional callable applied to each sample dict.

        Raises:
            Exception: if no image is found for the split.
        """
        self.root = root
        self.split = split
        self.transform = transform
        self.n_classes = 19
        self.ignore_index = 255

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine', self.split)

        # Raw label ids that are not evaluated (16 entries).
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # Raw label ids that are kept (19 entries).
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        # 'unlabelled' followed by the 19 kept class names (20 entries).
        self.class_names = [
            'unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
            'motorcycle', 'bicycle',
        ]
        # raw label id -> contiguous training id (0..18)
        self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))

        self.files = {}
        found = recursive_glob(rootdir=self.images_base, suffix='.png')
        self.files[split] = found
        if not found:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
        print("Found %d %s images" % (len(found), split))

    def __len__(self):
        """Return the number of images indexed for the current split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one image/label pair and run the optional transform on it."""
        img_path = self.files[self.split][index].rstrip()

        # The parent directory of the image is the city name; the annotation
        # lives in the same city folder under the annotation root.
        city = img_path.split(os.sep)[-2]
        # Drop the trailing 15 characters of the file name (the length of
        # 'leftImg8bit.png') and append the label-id suffix instead.
        label_name = os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png'
        lbl_path = os.path.join(self.annotations_base, city, label_name)

        image = Image.open(img_path).convert('RGB')
        raw_ids = np.array(Image.open(lbl_path), dtype=np.uint8)
        target = Image.fromarray(self.encode_segmap(raw_ids))

        sample = {'image': image, 'label': target}
        if self.transform:
            # Augmentation and conversion to tensors happen in the transform.
            sample = self.transform(sample)
        return sample

    def encode_segmap(self, mask):
        """Remap raw label ids in ``mask`` to training ids, in place.

        Void ids become ``ignore_index``; valid ids become 0..18 according to
        ``class_map``. The same (modified) array is returned.
        """
        for void_id in self.void_classes:
            mask[mask == void_id] = self.ignore_index
        # Each valid id is larger than the training id it maps to, so a value
        # written here is never picked up again by a later iteration.
        for raw_id in self.valid_classes:
            mask[mask == raw_id] = self.class_map[raw_id]
        return mask
if __name__ == '__main__':
    # Visual smoke test: load two augmented batches of the training split and
    # show every image above its colour-decoded label map.
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt  # to show image

    augmentations = [
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor(),
    ]
    composed_transforms_tr = transforms.Compose(augmentations)

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)
    dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        # Convert the whole batch once: tensors -> numpy, n x c x h x w.
        img = sample['image'].numpy()
        gt = sample['label'].numpy()

        for jj in range(img.shape[0]):
            # Drop the leading channel axis of the label: c x h x w -> h x w.
            tmp = np.squeeze(np.array(gt[jj]).astype(np.uint8), axis=0)
            segmap = decode_segmap(tmp, dataset='cityscapes')
            # Back to h x w x c so matplotlib can display the image.
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)

            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        # Stop after the second batch (ii is 0-based).
        if ii == 1:
            break

    plt.show(block=True)
下面这行代码是如何读取图片的,可以参考:https://blog.csdn.net/zz2230633069/article/details/84640867
# Call recursive_glob on the image directory with suffix '.png' and store
# its result as the file list of this split.
self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')
对每个样本所做的变换(数据增强)组合为:
# Augmentation pipeline applied, in order, to every {'image', 'label'} sample
# (per-step notes follow the transform implementations listed in this document).
composed_transforms_tr = transforms.Compose([
tr.RandomHorizontalFlip(),  # mirror image and label together with probability 0.5
tr.RandomScale((0.5, 0.75)),  # resize both by one random factor drawn from [0.5, 0.75]
tr.RandomCrop((512, 1024)),  # crop (or resize, if too small) to h=512, w=1024
tr.RandomRotate(5),  # rotate both by the same random angle within +/-5 degrees
tr.ToTensor()])  # convert both to float torch tensors, channel axis first
上面关于图像变换或者说增强的实现代码如下:
上面的前四个变换的输入和输出,原图和标签都始终是PIL图像对象(经过convert('RGB')、Image.fromarray以及各个变换之后,类型是PIL.Image.Image),像素的数据类型(uint8)不发生改变,结构也没有变化(原图为h x w x 3,标签图为h x w)。需要注意的是:缩放和裁剪会改变h和w的大小;原图使用双线性插值,像素值可能会变化;标签图使用最近邻插值,因此不会插值出新的类别编号。
class RandomHorizontalFlip(object):
    """Mirror image and label left-to-right together, with probability 0.5."""

    def __call__(self, sample):
        """Return a new sample dict, flipped or unchanged.

        Args:
            sample: dict with PIL images under 'image' and 'label'.
        """
        image, label = sample['image'], sample['label']
        should_flip = random.random() < 0.5
        if should_flip:
            # Flip both so the label stays aligned with the image.
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)
        return {'image': image, 'label': label}
class RandomScale(object):
    """Resize image and label by one random factor drawn from ``limit``."""

    def __init__(self, limit):
        """Store the scale range.

        Args:
            limit: sequence whose first two items are the lower and upper
                bound of the scale factor.
        """
        self.limit = limit

    def __call__(self, sample):
        """Return a new sample dict with both images rescaled.

        Raises:
            AssertionError: if image and label differ in size.
        """
        image, label = sample['image'], sample['label']
        assert image.size == label.size

        factor = random.uniform(self.limit[0], self.limit[1])
        # PIL sizes are (width, height).
        new_size = (int(factor * image.size[0]), int(factor * image.size[1]))

        return {
            'image': image.resize(new_size, Image.BILINEAR),
            # Nearest neighbour keeps label values from being blended.
            'label': label.resize(new_size, Image.NEAREST),
        }
class RandomCrop(object):
    """Cut the same random window of a fixed size out of image and label."""

    def __init__(self, size, padding=0):
        """Store the crop size and optional border padding.

        Args:
            size: a single number (square crop) or an ``(h, w)`` pair.
            padding: border width added on every side before cropping.
        """
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size  # h, w
        self.padding = padding

    def __call__(self, sample):
        """Return a new sample dict cropped (or resized) to the target size.

        Inputs already at the target size are returned as they are; inputs
        smaller than the target in either dimension are resized to it instead
        of being cropped.

        Raises:
            AssertionError: if image and label differ in size.
        """
        image, label = sample['image'], sample['label']
        if self.padding > 0:
            image = ImageOps.expand(image, border=self.padding, fill=0)
            label = ImageOps.expand(label, border=self.padding, fill=0)

        assert image.size == label.size
        width, height = image.size
        crop_h, crop_w = self.size

        if width == crop_w and height == crop_h:
            # Nothing to do.
            pass
        elif width < crop_w or height < crop_h:
            # Too small to crop from: stretch to the target size.
            image = image.resize((crop_w, crop_h), Image.BILINEAR)
            label = label.resize((crop_w, crop_h), Image.NEAREST)
        else:
            left = random.randint(0, width - crop_w)
            top = random.randint(0, height - crop_h)
            box = (left, top, left + crop_w, top + crop_h)
            image = image.crop(box)
            label = label.crop(box)

        return {'image': image, 'label': label}
class RandomRotate(object):
    """Rotate image and label by the same random angle."""

    def __init__(self, degree):
        """Store the maximum rotation magnitude, in degrees."""
        self.degree = degree

    def __call__(self, sample):
        """Return a new sample dict with both images rotated.

        The angle is ``random.random() * 2 * degree - degree``, i.e. it lies
        in ``[-degree, degree)``.
        """
        image, label = sample['image'], sample['label']
        angle = random.random() * 2 * self.degree - self.degree
        return {
            'image': image.rotate(angle, Image.BILINEAR),
            # Nearest neighbour keeps label values from being blended.
            'label': label.rotate(angle, Image.NEAREST),
        }
class ToTensor(object):
    """Turn the image and label of a sample into float torch tensors."""

    def __call__(self, sample):
        """Return ``{'image': C x H x W tensor, 'label': 1 x H x W tensor}``.

        Both inputs are copied into float32 arrays first, so the objects in
        ``sample`` are left untouched. Label value 255 is replaced by 0.
        """
        # Image: H x W x C (numpy layout) -> C x H x W (torch layout).
        pixels = np.array(sample['image']).astype(np.float32)
        pixels = np.transpose(pixels, (2, 0, 1))

        # Label: H x W -> 1 x H x W.
        labels = np.array(sample['label']).astype(np.float32)
        labels = np.expand_dims(labels, 0)
        # NOTE(review): after this line pixels that carried 255 can no longer
        # be told apart from class 0 -- confirm this is what training expects.
        labels[labels == 255] = 0

        return {
            'image': torch.from_numpy(pixels).float(),
            'label': torch.from_numpy(labels).float(),
        }
直到第五个也就是最后一个变换(ToTensor),原图才首先从PIL图像变为numpy数组,同时数据类型从uint8变为float32;然后维度从(h x w x c)调整为(c x h x w);最后从numpy数组变为torch的tensor,并强制数据类型为torch.FloatTensor。这样,就把原图转变为一个可以输入后面深度学习网络的tensor了。
与此相对应,标签图也是先从PIL图像变为numpy数组,同时数据类型从uint8变为float32;然后在最前面增加一维,最终维度从(h x w)变为(1 x h x w)(代码中是先得到(h x w x 1)再调整为(1 x h x w));接着对mask里面的数值进行处理:值为255的像素全部被重置为0,所以mask里面的值现在只有0-18这些数字了(注意:这样一来,原本要被忽略的像素就和第0类无法区分了);最后从numpy数组变为torch的tensor,并强制数据类型为torch.FloatTensor。这样,就把标签图转变为一个可以输入后面深度学习网络的tensor了。
把上面的两个tensor重新还原成图像并显示出来的代码如下:
for ii, sample in enumerate(dataloader):
    # Convert the whole batch once, from torch tensors to numpy arrays.
    img = sample['image'].numpy()  # n x 3 x h x w
    gt = sample['label'].numpy()  # n x 1 x h x w

    for jj in range(img.shape[0]):
        # 1 x h x w -> h x w: drop the leading channel axis of the label.
        tmp = np.squeeze(np.array(gt[jj]).astype(np.uint8), axis=0)
        segmap = decode_segmap(tmp, dataset='cityscapes')
        # 3 x h x w -> h x w x 3 so matplotlib can display the image.
        img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)

        plt.figure()
        plt.title('display')
        plt.subplot(211)
        plt.imshow(img_tmp)
        plt.subplot(212)
        plt.imshow(segmap)

    # Stop after the second batch (ii is 0-based).
    if ii == 1:
        break

plt.show(block=True)
其中标签图(h x w)的解码代码如下:
解码时,把属于同一类别的像素赋为该类别对应的R、G、B数值(分别得到三张h x w的单通道图),然后把这三张图合并成一张h x w x 3的彩色图。
segmap = decode_segmap(tmp, dataset='cityscapes') # tmp: h x w label map -> segmap: h x w x 3 colour image
def decode_segmap(label_mask, dataset, plot=False):
    """Turn a map of class labels into a colour image.

    Args:
        label_mask (np.ndarray): an (M, N) array holding the class label of
            each pixel.
        dataset (str): which palette to use, 'pascal' or 'cityscapes'.
        plot (bool, optional): show the colour image in a figure instead of
            returning it.

    Returns:
        np.ndarray of shape (M, N, 3) with channel values divided by 255, or
        None when ``plot`` is true.

    Raises:
        NotImplementedError: for any other ``dataset`` value.
    """
    if dataset == 'pascal':
        n_classes, label_colours = 21, get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes, label_colours = 19, get_cityscapes_labels()
    else:
        raise NotImplementedError

    # One working copy of the label map per colour channel (R, G, B). Pixels
    # whose label is outside range(n_classes) keep their label value.
    planes = [label_mask.copy(), label_mask.copy(), label_mask.copy()]
    for class_id in range(0, n_classes):
        selected = label_mask == class_id
        for channel, plane in enumerate(planes):
            plane[selected] = label_colours[class_id, channel]

    height, width = label_mask.shape[0], label_mask.shape[1]
    rgb = np.zeros((height, width, 3))
    for channel, plane in enumerate(planes):
        rgb[:, :, channel] = plane / 255.0

    if not plot:
        return rgb
    plt.imshow(rgb)
    plt.show()
下面是label_colours的定义,即每个类别对应的颜色;详情可以参考cityscapes的标签颜色对照表:https://blog.csdn.net/zz2230633069/article/details/84591532
def get_cityscapes_labels():
    """Return the colour palette for the 19 Cityscapes training classes.

    Returns:
        np.ndarray with dimensions (19, 3): row ``i`` is the (R, G, B) colour
        of training id ``i``. There is no row for unlabelled pixels.
    """
    # Class names follow the official Cityscapes training-id order.
    return np.array([
        [128, 64, 128],   # 0  road
        [244, 35, 232],   # 1  sidewalk
        [70, 70, 70],     # 2  building
        [102, 102, 156],  # 3  wall
        [190, 153, 153],  # 4  fence
        [153, 153, 153],  # 5  pole
        [250, 170, 30],   # 6  traffic light
        [220, 220, 0],    # 7  traffic sign
        [107, 142, 35],   # 8  vegetation
        [152, 251, 152],  # 9  terrain
        [70, 130, 180],   # 10 sky (official colour; red channel is 70, not 0)
        [220, 20, 60],    # 11 person
        [255, 0, 0],      # 12 rider
        [0, 0, 142],      # 13 car
        [0, 0, 70],       # 14 truck
        [0, 60, 100],     # 15 bus
        [0, 80, 100],     # 16 train
        [0, 0, 230],      # 17 motorcycle
        [119, 11, 32],    # 18 bicycle
    ])
def get_pascal_labels():
    """Return the colour assigned to each pascal class label.

    Returns:
        np.ndarray with dimensions (21, 3): row ``i`` is the (R, G, B) colour
        of class label ``i``.
    """
    palette = [
        [0, 0, 0],
        [128, 0, 0],
        [0, 128, 0],
        [128, 128, 0],
        [0, 0, 128],
        [128, 0, 128],
        [0, 128, 128],
        [128, 128, 128],
        [64, 0, 0],
        [192, 0, 0],
        [64, 128, 0],
        [192, 128, 0],
        [64, 0, 128],
        [192, 0, 128],
        [64, 128, 128],
        [192, 128, 128],
        [0, 64, 0],
        [128, 64, 0],
        [0, 192, 0],
        [128, 192, 0],
        [0, 64, 128],
    ]
    return np.asarray(palette)