python实现YUV转RGB

序言

因为在项目中要用到yuv格式的视频图像进行模型推理,但是我们模型通常都是只接收RGB格式的图片,所以在推理前就要先把YUV格式转换为RGB格式。

一、什么是YUV

在代码实现之前,需要去了解什么是YUV格式编码,明白了其编码方式后再看实现会简单的多,如果只是需要代码实现,那可以直接简单略过。

YUV是一种类似RGB的颜色模型,起源于黑白和彩电的过渡时期,是一种颜色编码方法,其中Y代表亮度,UV组合起来可以表示色度。YUV信息只有y的信息就足以显示黑白的图片(所以我们早期的黑白电视就是只用了Y),这样的设计很好地解决了彩色电视机与黑白电视的兼容问题。并且,YUV不像RGB那样要求三个独立的视频信号同时传输,所以用YUV方式传送占用极少的频宽。

YUV码流的存储格式其实与其采样的方式密切相关,主流的采样方式有三种,YUV4:4:4,YUV4:2:2,YUV4:2:0,关于其详细原理,可以参考这篇文章一文读懂 YUV 的采样与格式,YUV与RGB编码这里就不再过多介绍,本文是实现根据其采样格式来从码流中还原每个像素点的YUV值,然后通过YUV与RGB的转换公式提取出每个像素点的RGB值,然后显示保存下来。

2. 实现过程

在实现代码之前,首先要明确自己要转换的YUV是什么存储格式,因为不同的存储格式需要不同的提取方式。简单看下下面这两种不同的存储格式:

yuv420P(YU12 和 YV12 )格式如下,先存储Y分量,再存储U分量,最后存储V分量:
python实现YUV转RGB_第1张图片
yuv420SP(NV12 和 NV21 )格式如下,先存储Y分量,然后交替存储UV分量:
python实现YUV转RGB_第2张图片
作为对比,再上一张RGB的存储格式:

RGB存储方式:RGB三个分量按照B、G、R的顺序储存。(4:4:4)
python实现YUV转RGB_第3张图片

所以我们在实现之前需要明确是什么格式存储的,否则提取出来的UV分量可能是错的,导致转换后的图像色彩不对。

3. 代码实现

第一版实现,使用for循环逐点提取YUV值转换成RGB(非常耗时,不建议使用,为了更清楚实现的过程):

def yuv2rgb(Y, U, V):
    bgr_data = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)
    for h_idx in range(Y_HEIGHT):
        for w_idx in range(Y_WIDTH):
            y = Y[h_idx, w_idx]
            u = U[int(h_idx // 2), int(w_idx // 2)]
            v = V[int(h_idx // 2), int(w_idx // 2)]

            c = (y - 16) * 298
            d = u - 128
            e = v - 128

            r = (c + 409 * e + 128) // 256
            g = (c - 100 * d - 208 * e + 128) // 256
            b = (c + 516 * d + 128) // 256

            bgr_data[h_idx, w_idx, 2] = 0 if r < 0 else (255 if r > 255 else r)
            bgr_data[h_idx, w_idx, 1] = 0 if g < 0 else (255 if g > 255 else g)
            bgr_data[h_idx, w_idx, 0] = 0 if b < 0 else (255 if b > 255 else b)

    return bgr_data

第二版实现,使用numpy数组运算进行加速(速度非常快,建议用这版):

def np_yuv2rgb(Y,U,V):
    bgr_data = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)
    V = np.repeat(V, 2, 0)
    V = np.repeat(V, 2, 1)
    U = np.repeat(U, 2, 0)
    U = np.repeat(U, 2, 1)

    c = (Y-np.array([16])) * 298
    d = U - np.array([128])
    e = V - np.array([128])

    r = (c + 409 * e + 128) // 256
    g = (c - 100 * d - 208 * e + 128) // 256
    b = (c + 516 * d + 128) // 256

    r = np.where(r < 0, 0, r)
    r = np.where(r > 255,255,r)

    g = np.where(g < 0, 0, g)
    g = np.where(g > 255,255,g)

    b = np.where(b < 0, 0, b)
    b = np.where(b > 255,255,b)

    bgr_data[:, :, 2] = r
    bgr_data[:, :, 1] = g
    bgr_data[:, :, 0] = b

    return bgr_data

实验对比:输入同一张(1152*648)的YUV图像。

第一版耗时:5.698601007461548
第二版耗时:0.04670834541320801

速度提升了一百多倍,图片越大,提升效果越明显,最后转换出来的图像如图所示:

全部代码:

import os
import cv2
import numpy as np

IMG_WIDTH = 1152
IMG_HEIGHT = 648
IMG_SIZE = int(IMG_WIDTH * IMG_HEIGHT * 3 / 2)

Y_WIDTH = IMG_WIDTH
Y_HEIGHT = IMG_HEIGHT
Y_SIZE = int(Y_WIDTH * Y_HEIGHT)

U_V_WIDTH = int(IMG_WIDTH / 2)
U_V_HEIGHT = int(IMG_HEIGHT / 2)
U_V_SIZE = int(U_V_WIDTH * U_V_HEIGHT)


def from_I420(yuv_data, frames):
    Y = np.zeros((frames, IMG_HEIGHT, IMG_WIDTH), dtype=np.uint8)
    U = np.zeros((frames, U_V_HEIGHT, U_V_WIDTH), dtype=np.uint8)
    V = np.zeros((frames, U_V_HEIGHT, U_V_WIDTH), dtype=np.uint8)

    for frame_idx in range(0, frames):
        y_start = frame_idx * IMG_SIZE
        u_start = y_start + Y_SIZE
        v_start = u_start + U_V_SIZE
        v_end = v_start + U_V_SIZE

        Y[frame_idx, :, :] = yuv_data[y_start : u_start].reshape((Y_HEIGHT, Y_WIDTH))
        U[frame_idx, :, :] = yuv_data[u_start : v_start].reshape((U_V_HEIGHT, U_V_WIDTH))
        V[frame_idx, :, :] = yuv_data[v_start : v_end].reshape((U_V_HEIGHT, U_V_WIDTH))
    return Y, U, V

def from_YV12(yuv_data, frames):
    Y = np.zeros((frames, IMG_HEIGHT, IMG_WIDTH), dtype=np.uint8)
    U = np.zeros((frames, U_V_HEIGHT, U_V_WIDTH), dtype=np.uint8)
    V = np.zeros((frames, U_V_HEIGHT, U_V_WIDTH), dtype=np.uint8)

    for frame_idx in range(0, frames):
        y_start = frame_idx * IMG_SIZE
        v_start = y_start + Y_SIZE
        u_start = v_start + U_V_SIZE
        u_end = u_start + U_V_SIZE

        Y[frame_idx, :, :] = yuv_data[y_start : v_start].reshape((Y_HEIGHT, Y_WIDTH))
        V[frame_idx, :, :] = yuv_data[v_start : u_start].reshape((U_V_HEIGHT, U_V_WIDTH))
        U[frame_idx, :, :] = yuv_data[u_start : u_end].reshape((U_V_HEIGHT, U_V_WIDTH))
    return Y, U, V


def from_NV12(yuv_data, frames):
    Y = np.zeros((frames, IMG_HEIGHT, IMG_WIDTH), dtype=np.uint8)
    U = np.zeros((frames, U_V_HEIGHT, U_V_WIDTH), dtype=np.uint8)
    V = np.zeros((frames, U_V_HEIGHT, U_V_WIDTH), dtype=np.uint8)

    for frame_idx in range(0, frames):
        y_start = frame_idx * IMG_SIZE
        u_v_start = y_start + Y_SIZE
        u_v_end = u_v_start + (U_V_SIZE * 2)

        Y[frame_idx, :, :] = yuv_data[y_start : u_v_start].reshape((Y_HEIGHT, Y_WIDTH))
        U_V = yuv_data[u_v_start : u_v_end].reshape((U_V_SIZE, 2))
        U[frame_idx, :, :] = U_V[:, 0].reshape((U_V_HEIGHT, U_V_WIDTH))
        V[frame_idx, :, :] = U_V[:, 1].reshape((U_V_HEIGHT, U_V_WIDTH))
    return Y, U, V


def from_NV21(yuv_data, frames):
    Y = np.zeros((frames, IMG_HEIGHT, IMG_WIDTH), dtype=np.uint8)
    U = np.zeros((frames, U_V_HEIGHT, U_V_WIDTH), dtype=np.uint8)
    V = np.zeros((frames, U_V_HEIGHT, U_V_WIDTH), dtype=np.uint8)

    for frame_idx in range(0, frames):
        y_start = frame_idx * IMG_SIZE
        u_v_start = y_start + Y_SIZE
        u_v_end = u_v_start + (U_V_SIZE * 2)

        Y[frame_idx, :, :] = yuv_data[y_start : u_v_start].reshape((Y_HEIGHT, Y_WIDTH))
        U_V = yuv_data[u_v_start : u_v_end].reshape((U_V_SIZE, 2))
        V[frame_idx, :, :] = U_V[:, 0].reshape((U_V_HEIGHT, U_V_WIDTH))
        U[frame_idx, :, :] = U_V[:, 1].reshape((U_V_HEIGHT, U_V_WIDTH))
    return Y, U, V

def np_yuv2rgb(Y,U,V):
    bgr_data = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)
    V = np.repeat(V, 2, 0)
    V = np.repeat(V, 2, 1)
    U = np.repeat(U, 2, 0)
    U = np.repeat(U, 2, 1)

    c = (Y-np.array([16])) * 298
    d = U - np.array([128])
    e = V - np.array([128])

    r = (c + 409 * e + 128) // 256
    g = (c - 100 * d - 208 * e + 128) // 256
    b = (c + 516 * d + 128) // 256

    r = np.where(r < 0, 0, r)
    r = np.where(r > 255,255,r)

    g = np.where(g < 0, 0, g)
    g = np.where(g > 255,255,g)

    b = np.where(b < 0, 0, b)
    b = np.where(b > 255,255,b)

    bgr_data[:, :, 2] = r
    bgr_data[:, :, 1] = g
    bgr_data[:, :, 0] = b

    return bgr_data

def yuv2rgb(Y, U, V):
    bgr_data = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)
    for h_idx in range(Y_HEIGHT):
        for w_idx in range(Y_WIDTH):
            y = Y[h_idx, w_idx]
            u = U[int(h_idx // 2), int(w_idx // 2)]
            v = V[int(h_idx // 2), int(w_idx // 2)]

            c = (y - 16) * 298
            d = u - 128
            e = v - 128

            r = (c + 409 * e + 128) // 256
            g = (c - 100 * d - 208 * e + 128) // 256
            b = (c + 516 * d + 128) // 256

            bgr_data[h_idx, w_idx, 2] = 0 if r < 0 else (255 if r > 255 else r)
            bgr_data[h_idx, w_idx, 1] = 0 if g < 0 else (255 if g > 255 else g)
            bgr_data[h_idx, w_idx, 0] = 0 if b < 0 else (255 if b > 255 else b)

    return bgr_data

if __name__ == '__main__':
    import time

    yuv = "request/YUV/2021-05-06/test.yuv"
    frames = int(os.path.getsize(yuv) / IMG_SIZE)

    with open(yuv, "rb") as yuv_f:
        time1 = time.time()
        yuv_bytes = yuv_f.read()
        yuv_data = np.frombuffer(yuv_bytes, np.uint8)

        # Y, U, V = from_I420(yuv_data, frames)
        # Y, U, V = from_YV12(yuv_data, frames)
        # Y, U, V = from_NV12(yuv_data, frames)
        Y, U, V = from_NV21(yuv_data, frames)

        rgb_data = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)
        for frame_idx in range(frames):
            # bgr_data = yuv2rgb(Y[frame_idx, :, :], U[frame_idx, :, :], V[frame_idx, :, :])            # for 
            bgr_data = np_yuv2rgb(Y[frame_idx, :, :], U[frame_idx, :, :], V[frame_idx, :, :])           # numpy 
            time2 = time.time()
            print(time2 - time1)
            if bgr_data is not None:
                cv2.imwrite("frame_{}.jpg".format(frame_idx), bgr_data)
                frame_idx +=1

代码和测试用的YUV文件已上传到github,使用时请参考文档。

你可能感兴趣的:(笔记,python)