图像通用预处理操作

index = 0 # 用于记录帧数
while(1):
	success, frame = capture.read() # 一帧帧的读取,帧为bgr格式的
	if not success: break
	index += 1
	print('detect frame: %d' % (index))
	
	# 将frame输入网络中,在这之前,我们要对frame进行一些图象预处理
	# 和read_img的操作一样,直接抄过来就好了
	# 将img_mean和img_scale都直接用实数表示了
	img = frame[:,:,::-1] 	# bgr->rgb,对通道进行翻转
	# 图片被线性拉伸或者压缩了,后面要进行坐标的重新映射
	img = cv2.resize(img, (640, 640), interpolation=cv2.INTER_LINEAR)
	img = (img - mean) * scale 		# 像素归一化
	img = np.asarray(img, dtype=np.float32) # 转成float32位的数组
	img = np.expand_dims(img, 0)			# 在axis=0处加入一个维度,img.shape=(1,640,640,3)
	frame = img.transpose(0,3,1,2)			# 调整维度顺序:(batchsize,channel,height,width)
	
	# 因为也是一张张的传进去推理的,所以batchsize=1
	output = model_inference(model_path, frame) # 这意味着,参数里也要有model_path
	# 判断该帧是否有目标,没有目标的话,就直接原样输出该帧
	if output[0].shape[0] == 0:
		writer.write(frame)
		continue
	
	# 有目标的话就对图片进行后处理
	det_bboxes = output[0][:, 0:4]
	det_scores = output[0][4]
	det_labels = output[0][5]
	kpts = output[0][:, 6:]
	
	for idx in range(len(det_bboxes)):
		det_bbox = det_bboxes[idx]		# 选取一个目标的检测框
		kpt = kpts[idx] 				# 选取其对应的关键点数据
		if det_scores[idx] > score_threshold: # 参数列表里还得有score_threshold
			color_map = _CLASS_COLOR_MAP[int(det_labels[idx])]
			# 这里要对坐标进行转换,已知视频的height和width,变换后为640,可知:
			# 真实坐标truth = now * original_size / 640
			# 画识别框
			start_point = (det_bbox[0]*width / 640, det_bbox[1]*height / 640)
			end_point = (det_bbox[2]*width / 640, det_bbox[3]*height / 640)
			frame = cv2.rectangle(frame, start_point, end_point, color_map[::-1],2)
			# 写点文字到图上去
			cv2.putText(frame, "id:{}".format(int(det_labels[idx])), 
						(int(det_bbox[0]*width / 640 + 5),int(det_bbox[1]*width / 640 + 15)),
						cv2.FONT_HERSHEY_SIMPLEX, 0.5, color_map[::-1], 2)
			cv2.putText(frame, "score:{:2.1f}".format(int(det_labels[idx])), 
						(int(det_bbox[0]*width / 640 + 5),int(det_bbox[1]*width / 640 + 30)),
						cv2.FONT_HERSHEY_SIMPLEX, 0.5, color_map[::-1], 2)
			# 关键点需要重新定位,kpt是一个51个元素的array,(x, y, conf)排列,处理感觉好麻烦
			# 感觉实时的话,这里真会浪费很多时间,终于知道为啥要求640 x 640的视频了
			plot_skeleton_kpts(frame, kpt)
			
			writer.write(frame)
			if camera_id != -1:
				cv2.imshow('Mask Detection', frame)
				if cv2.waitKey(1) & 0xFF == ord('q'): break
		writer.release()	


 

你可能感兴趣的:(python,计算机视觉,人工智能,通用代码)