2021SC@SDUSC
本篇分析plot.py文件。正如其名称所示,该文件主要是一些用于可视化展示的代码,并不是核心代码。
from copy import copy
from pathlib import Path
import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from PIL import Image, ImageDraw, ImageFont
from utils.general import user_config_dir, is_ascii, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness
copy:用于对象的拷贝操作,该模块主要提供两个方法:copy.copy与copy.deepcopy,分别表示浅复制和深复制
Path,cv2,math,numpy,pandas在general.py中已经介绍过了
matplotlib:是python最著名的绘图库,提供了一整套和matlab相似的命令API,是这个文件的主要外部库
seaborn:基于matplotlib的python可视化库,是在matplotlib的基础上进行了更高级的API封装。
class Colors:
    """Ultralytics color palette (https://ultralytics.com/).

    A fixed palette of 20 colors, indexed cyclically: calling the instance with any index ``i``
    returns palette entry ``int(i) % n``.
    """

    def __init__(self):
        # Hex strings without the leading '#'. (matplotlib.colors.TABLEAU_COLORS.values() is an alternative source.)
        # Named hex_colors rather than `hex` so the builtin hex() is not shadowed.
        hex_colors = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334',
                      '00D4BB', '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF',
                      'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb('#' + c) for c in hex_colors]  # list of (r, g, b) int tuples
        self.n = len(self.palette)  # number of colors in the palette

    def __call__(self, i, bgr=False):
        """Return color number ``int(i) % n`` as an (r, g, b) tuple, or (b, g, r) when ``bgr`` is True (OpenCV order)."""
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):
        """Convert a '#RRGGBB' string to an (r, g, b) tuple of ints (rgb order, as used by PIL)."""
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
hex:十六进制格式的颜色
palette:RGB格式的颜色列表
n:数组长度
hex2rgb函数将以十六进制表示的颜色转换为RGB格式
__call__方法在调用时返回索引为i的颜色(bgr=True时按BGR顺序返回),当i大于等于n时用i模n的索引来取得颜色
def check_font(font='Arial.ttf', size=10):
    """Return a PIL TrueType font, downloading it into CONFIG_DIR if it cannot be loaded.

    Args:
        font: font file name or path. If that path does not exist, ``CONFIG_DIR / font.name`` is used
            instead (CONFIG_DIR is a module-level directory defined outside this snippet).
        size: font size passed to ``ImageFont.truetype``.

    Returns:
        The font object returned by ``ImageFont.truetype``.
    """
    font = Path(font)
    font = font if font.exists() else (CONFIG_DIR / font.name)
    try:
        # With no local file, pass only the bare name so PIL's truetype() can also search system font dirs.
        return ImageFont.truetype(str(font) if font.exists() else font.name, size)
    except OSError:  # truetype() raises OSError when the font cannot be read -> download it
        url = "https://ultralytics.com/assets/" + font.name
        print(f'Downloading {url} to {font}...')
        # The downloader writes into the destination directory, so make sure it exists.
        font.parent.mkdir(parents=True, exist_ok=True)
        torch.hub.download_url_to_file(url, str(font))
        return ImageFont.truetype(str(font), size)
font:检查的字体
该函数检查是否存在对应的字体文件;如果无法加载,则从网上下载到对应的路径(CONFIG_DIR目录下)
PIL的ImageFont模块定义了同名的ImageFont类,这个类的实例存储bitmap(位图)字体,用于ImageDraw类的text()方法;本函数使用的ImageFont.truetype()则用于加载TrueType字体。这里不多讲解,感兴趣的可以参考ImageFont 模块 — Pillow (PIL Fork) 8.4.0 文档
class Annotator:
    """YOLOv5 annotator for train/val mosaics and jpgs and detect/hub inference annotations.

    Draws xyxy boxes with optional text labels onto an image, using PIL (``pil=True``) or OpenCV.
    """

    # Runs once, when the class body is executed at import time: make sure the default font is available.
    check_font()  # download TTF if necessary

    def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=True):
        """Prepare ``im`` for drawing.

        Args:
            im: image to draw on. A numpy array (must be contiguous), or a PIL.Image.Image when ``pil`` is True.
            line_width: box line width in pixels; derived from the image size when not given.
            font_size: font size for PIL text; derived from the image size when not given.
            font: font file name/path, resolved through check_font() (PIL mode only).
            pil: draw with PIL if True, otherwise draw with OpenCV directly on ``im``.

        Raises:
            AssertionError: if an array ``im`` is not contiguous (unless Python runs with -O).
        """
        is_pil_image = isinstance(im, Image.Image)
        # PIL images expose no `.data` buffer, so the contiguity check only applies to array inputs.
        if not (pil and is_pil_image):
            assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
        self.pil = pil
        if self.pil:  # use PIL
            self.im = im if is_pil_image else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            self.font = check_font(font, size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
            self.fh = self._text_size('a')[1] - 3  # font height
        else:  # use cv2
            self.im = im
        # Arrays: sum over im.shape as before; PIL images have no .shape, so use their (w, h) size instead.
        dims = self.im.size if is_pil_image else im.shape
        self.lw = line_width or max(round(sum(dims) / 2 * 0.003), 2)  # line width

    def _text_size(self, text):
        """Return (width, height) of ``text`` rendered in ``self.font`` (PIL mode only)."""
        # FreeTypeFont.getsize() was removed in Pillow 10. On newer Pillow, use the right/bottom edges of
        # getbbox(), which approximate what getsize() returned.
        if hasattr(self.font, 'getsize'):
            return self.font.getsize(text)
        _, _, right, bottom = self.font.getbbox(text)
        return right, bottom

    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
        """Draw one xyxy box on the image and, if ``label`` is non-empty, a filled label tag at its top-left corner.

        Args:
            box: box corners (x1, y1, x2, y2).
            label: text to draw; an empty string draws the box only.
            color: box outline and label background color.
            txt_color: label text color.
        """
        # NOTE(review): labels for which is_ascii() is False take the PIL branch even when pil=False. In that
        # case self.draw/self.font were never created and this raises AttributeError; use pil=True for them.
        if self.pil or not is_ascii(label):
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w, h = self._text_size(label)  # text width, height
                self.draw.rectangle([box[0], box[1] - self.fh, box[0] + w + 1, box[1] + 1], fill=color)
                # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                self.draw.text((box[0], box[1] - h), label, fill=txt_color, font=self.font)
        else:  # cv2
            c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(self.im, c1, c2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
                c2 = c1[0] + w, c1[1] - h - 3  # label tag extends upwards from the box's top-left corner
                cv2.rectangle(self.im, c1, c2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(self.im, label, (c1[0], c1[1] - 2), 0, self.lw / 3, txt_color, thickness=tf,
                            lineType=cv2.LINE_AA)

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Draw a rectangle with corners ``xy`` (PIL-only)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255)):
        """Draw ``text`` at x = xy[0], shifted up by the text height so that it ends just below y = xy[1] (PIL-only)."""
        _, h = self._text_size(text)  # text height
        self.draw.text((xy[0], xy[1] - h + 1), text, fill=txt_color, font=self.font)

    def result(self):
        """Return the annotated image as a numpy array."""
        return np.asarray(self.im)
init方法:
im:图片
line_width:线宽
font_size:字体大小
font:字体名称
pil:是否使用pillow
如果使用pillow,将图片格式转换为pillow的格式,fh为字体高度
ImageDraw模块为Image对象提供简单的二维图形绘制功能,可以使用此模块创建新图像、对现有图像进行注释或润色,具体参考ImageDraw 模块 — Pillow (PIL Fork) 8.4.0 文档
lw为线宽
box_label方法:向图片中增加一个xyxy的box,并且加上标签
box:xyxy的box
label:标签
无论使用PIL或者opencv都是在对图像加一个box,其格式是xyxy,即box左上角的点坐标和右下角点的坐标,并且标注box的标签
rectangle 方法:
向图像中画一个长方形
text方法:
向图像中添加box的标签
result方法:
返回最终的图像,其格式是numpy数组
该类实现了向图片中画出预测框并且添加标签
如图是经过操作后的图像,标注出了预测框以及预测出来的类别以及置信度
def hist2d(x, y, n=100):
    """Return, for every point (x[k], y[k]), the log of the 2-D histogram count of the bin it falls in.

    2d histogram used in labels.png and evolve.png.

    Args:
        x, y: 1-D numpy arrays of equal length (point coordinates).
        n: number of bin edges per axis, i.e. n - 1 bins per axis.

    Returns:
        numpy array with one log-count per point.
    """
    x_edges = np.linspace(x.min(), x.max(), n)
    y_edges = np.linspace(y.min(), y.max(), n)
    counts, x_edges, y_edges = np.histogram2d(x, y, (x_edges, y_edges))
    # digitize() is 1-based and returns len(edges) for the maximum value: shift down and clamp to valid bins.
    col = np.clip(np.digitize(x, x_edges) - 1, 0, counts.shape[0] - 1)
    row = np.clip(np.digitize(y, y_edges) - 1, 0, counts.shape[1] - 1)
    return np.log(counts[col, row])
根据x,y的二维直方图分布,返回每个点所在区间计数的对数值,用于给点着色:落在点数多的区间的点颜色更亮,反之更暗
x和y都是np数组
np.linspace(start,stop,num,endpoint,retstep,dtype)
在指定的间隔内返回均匀间隔的数字 ,返回num均匀分布的样本在[start,stop]之间
np.clip(a,a_min,a_max,out=None)是将a限定在a_min和a_max之间,当a大于a_max时返回a_max,a小于a_min返回a_min,否则返回a本身
np.histogram2d用于计算两个一维数组x、y(即各点的横、纵坐标)的二维直方图,返回每个区间的计数以及x、y方向的区间边界
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    """Zero-phase low-pass filter ``data`` with a Butterworth filter.

    Args:
        data: signal to filter (scipy's filtfilt filters along the last axis by default).
        cutoff: cutoff frequency, in the same units as ``fs``.
        fs: sampling frequency.
        order: filter order.

    Returns:
        The filtered signal, same shape as ``data``.
    """
    from scipy.signal import butter, filtfilt  # imported here, presumably so scipy is only needed when called
    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    nyquist = 0.5 * fs
    b, a = butter(order, cutoff / nyquist, btype='low', analog=False)
    # filtfilt applies the filter forward and then backward, so the result has no phase shift.
    return filtfilt(b, a, data)
data:原数据
cutoff:截止频率,高于该频率的成分会被衰减
fs:采样频率(cutoff / (0.5 * fs) 即归一化后的截止频率)
order:滤波器的阶数
这个函数实现了低通滤波,即保留数据(信号)中频率比较低的部分,丢掉频率高的部分,“低通”就是低频能够通过,高频无法通过。filtfilt会正向、反向各滤波一次,因此结果没有相位偏移。
butter为配置滤波器,filtfilt实现滤波
具体可参考官网scipy.signal.butter — SciPy v1.7.1 Manual
def output_to_target(output):
    """Convert model output to target format rows [batch_id, class_id, x, y, w, h, conf].

    Args:
        output: sequence with one tensor per image; each row is unpacked as (*box, conf, cls), and
            ``box`` is converted with xyxy2xywh (i.e. it holds xyxy corners).

    Returns:
        numpy array with one row per detection (an empty 1-D array when there are none).
    """
    rows = []
    for batch_id, detections in enumerate(output):
        for *xyxy, score, class_id in detections.cpu().numpy():
            xywh = xyxy2xywh(np.array(xyxy)[None])[0]  # (1, 4) -> first (only) row
            rows.append([batch_id, class_id, *xywh, score])
    return np.array(rows)
output:模型的输出
该函数将模型的输出转换为我们想要的格式,即[batch_id,class_id,x,y,w,h,conf]
output中每张图片的每一行预测格式为[x1,y1,x2,y2,conf,cls],分别代表了预测框(xyxy格式)、置信度、类别
标签的格式为[batch_id,class_id,x,y,w,h,conf]* M,M为整个batch的预测框数量。
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
    """Tile up to max_subplots images into a square mosaic, draw their boxes/labels and filenames, and save it to fname."""
images:一个batch的图片
targets:一个batch的标签或预测结果,每行格式为[batch_id,class_id,x,y,w,h],预测结果还多一列conf
paths:一个batch的文件名
fname:保存可视化之后大图的文件路径
names:类别名
max_size:限制拼接后整张大图的最大边长(超过时整体缩小)
max_subplots:最多可视化的图片数量,默认16张
# Plot image grid with labels
if isinstance(images, torch.Tensor):
    images = images.cpu().float().numpy()
if isinstance(targets, torch.Tensor):
    targets = targets.cpu().numpy()
if np.max(images[0]) <= 1:
    # Scale out-of-place: `images` may share memory with the caller's data (a numpy input, or a CPU float
    # tensor, for which .cpu().float().numpy() performs no copy), and `*=` would rescale the caller's data too.
    images = images * 255.0  # de-normalise (optional)
将images和targets从tensor转换为numpy类型
如果images为0-1,将其乘上255转换为0-255
bs, _, h, w = images.shape  # batch size, _, height, width (the unpacking requires a rank-4 array)
bs = min(bs, max_subplots)  # limit plot images
ns = np.ceil(bs ** 0.5)  # number of subplots (square): tiles per row/column; a float, hence the int() casts below
bs,h,w分别为batch_size(不超过max_subplots)、图片的高度和宽度;ns为正方形拼图每行(每列)的子图数量,即ceil(sqrt(bs))
# Build Image
# White canvas holding an ns x ns grid of h x w tiles with 3 channels.
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
for i, im in enumerate(images):
    if i == max_subplots:  # stop once max_subplots tiles are filled (the batch may contain more images)
        break
    # Tiles fill column by column: i // ns selects the column (x offset), i % ns the row (y offset).
    x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
    im = im.transpose(1, 2, 0)  # CHW -> HWC
    # NOTE(review): assigning float pixels into the uint8 mosaic truncates them; assumes 3-channel images.
    mosaic[y:y + h, x:x + w, :] = im
mosaic为初始化大图
对images进行遍历,x和y为当前图片在mosaic中对应子块左上角的像素坐标
这块代码将每张图片从CHW转换为HWC后,原样复制到mosaic中对应的子块(这里并没有放大)
# Resize (optional)
# Shrink so the mosaic's longest side is about max_size; a mosaic that is already small is never enlarged.
scale = max_size / ns / max(h, w)
if scale < 1:
    h = math.ceil(scale * h)  # new tile height
    w = math.ceil(scale * w)  # new tile width
    # cv2.resize expects dsize as (width, height)
    mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
对mosaic进行resize,h和w分别为新的高和宽,scale为缩小倍数
# Annotate
fs = int((h + w) * ns * 0.01)  # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs)
# bs (computed above) is the number of tiles actually filled. Don't reuse the leaked loop variable `i`:
# when the batch was cut off at max_subplots it is one past the last tile and would add an off-canvas tile.
for i in range(bs):
    x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
    annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
    if paths:
        annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
    if len(targets) > 0:
        ti = targets[targets[:, 0] == i]  # image targets: rows whose batch_id is i
        boxes = xywh2xyxy(ti[:, 2:6]).T  # one column per box: rows are x1, y1, x2, y2
        classes = ti[:, 1].astype('int')
        labels = ti.shape[1] == 6  # labels if no conf column
        conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)

        if boxes.shape[1]:
            if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                boxes[[0, 2]] *= w  # scale to pixels
                boxes[[1, 3]] *= h
            elif scale < 1:  # absolute coords need scale if image scales
                boxes *= scale
        # Shift from tile-local to mosaic coordinates.
        boxes[[0, 2]] += x
        boxes[[1, 3]] += y
        for j, box in enumerate(boxes.T.tolist()):
            cls = classes[j]
            color = colors(cls)  # `colors` is defined outside this view (presumably a module-level Colors())
            cls = names[cls] if names else cls
            if labels or conf[j] > 0.25:  # 0.25 conf thresh
                label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
                annotator.box_label(box, label, color=color)
annotator.im.save(fname)  # save
接下来就是对图片进行标注,fs为font size,annotator为上面定义的类
x和y为每个子块左上角的点,然后用annotator画出子块的白色边框;如果paths不为空,还会在子块中标注出图片的文件名
ti为当前图片的标签(targets中batch_id等于i的行),boxes、classes、labels、conf分别是预测框、类别、是否可视化标签、置信度,其中labels表示当ti.shape[1]==6(没有置信度列)时需要可视化的是标签而不是预测框。
如果预测框是归一化坐标(最大值不超过1.01),将其乘以子图的宽和高;否则(绝对坐标)在图片被缩小时乘以scale
接下来对boxes的坐标加上左上角的坐标,boxes原先的坐标是基于当前grid的左上角的相对坐标,加上左上角的坐标变换为全局坐标
接下来在子图上画框,cls、color为类别和颜色(如果提供了names,则把类别编号换成类别名)。画的是标签时全部画出;画的是预测框时只画conf>0.25的框。0.25是一个置信度阈值,用于过滤掉低置信度的预测框,使可视化结果更清晰。
最后将其保存在相应的路径下。
本篇文章比较重要的部分就是对图片进行画框和标注类别的处理,还有一些方法还没有介绍到,将在下一篇文章继续介绍这部分内容。