序号 | 函数名 | 简要说明 |
---|---|---|
1 | round_by_factor(number: int, factor: int) -> int |
返回最接近 number ,且能被 factor 整除的整数。 |
2 | ceil_by_factor(number: int, factor: int) -> int |
返回大于等于 number ,且能被 factor 整除的最小整数。 |
3 | floor_by_factor(number: int, factor: int) -> int |
返回小于等于 number ,且能被 factor 整除的最大整数。 |
4 | smart_resize(height: int, width: int, ...) -> tuple[int, int] |
根据给定的高和宽,调整图像尺寸,使其满足特定条件(如可被因数整除、像素数在范围内、保持长宽比)。 |
5 | to_rgb(pil_image: Image.Image) -> Image.Image |
将 PIL 图像转换为 RGB 模式,处理 RGBA 图像的透明通道。 |
6 | fetch_image(ele: dict, size_factor: int = IMAGE_FACTOR) -> Image.Image |
从各种输入(URL、本地路径、Base64、PIL.Image)获取图像,并进行尺寸调整。 |
7 | smart_nframes(ele: dict, total_frames: int, video_fps: int) -> int |
计算用于模型输入的视频帧数,确保帧数满足特定因数要求,并在最小和最大帧数范围内。 |
8 | _read_video_torchvision(ele: dict) -> (torch.Tensor, float) |
使用 torchvision 库读取视频,返回视频帧和帧率。 |
9 | is_decord_available() -> bool |
检查是否安装了 decord 库。 |
10 | _read_video_decord(ele: dict) -> (torch.Tensor, float) |
使用 decord 库读取视频,返回视频帧和帧率。 |
11 | get_video_reader_backend() -> str |
获取用于读取视频的后端库名称,优先使用 decord 。 |
12 | fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor |
读取并处理视频,返回处理后的视频帧。 |
13 | extract_vision_info(conversations: list) -> list[dict] |
从对话中提取与视觉相关的信息,如图像或视频。 |
14 | process_vision_info(conversations: list, return_video_kwargs: bool = False) -> tuple |
处理视觉信息,获取图像和视频数据,以供模型使用。 |
process_vision_info
├── extract_vision_info
├── fetch_image (对于图像)
│ ├── to_rgb
│ └── smart_resize
│ ├── round_by_factor
│ ├── ceil_by_factor
│ └── floor_by_factor
└── fetch_video (对于视频)
├── get_video_reader_backend
│ └── is_decord_available
├── _read_video_torchvision 或 _read_video_decord
│ └── smart_nframes
│ ├── round_by_factor
│ ├── ceil_by_factor
│ └── floor_by_factor
└── smart_resize
├── round_by_factor
├── ceil_by_factor
└── floor_by_factor
process_vision_info
是核心函数,它根据对话内容,分别处理图像和视频。fetch_image
、to_rgb
和 smart_resize
来获取并调整图像。fetch_video
、get_video_reader_backend
、_read_video_torchvision
或 _read_video_decord
、smart_nframes
和 smart_resize
来获取并处理视频帧。round_by_factor
、ceil_by_factor
、floor_by_factor
被多次调用,用于确保尺寸和帧数满足特定的因数要求。IMAGE_FACTOR = 28
:这是图像尺寸调整的因数,图像的高度和宽度都将被调整为 28
的倍数。
MIN_PIXELS = 4 * 28 * 28
:图像的最小像素数,确保图像不小于特定大小。
MAX_PIXELS = 16384 * 28 * 28
:图像的最大像素数,限制图像的最大尺寸,防止过大的图像占用过多内存。
MAX_RATIO = 200
:图像的最大宽高比,用于防止过度拉伸或压缩的图像。
VIDEO_MIN_PIXELS = 128 * 28 * 28
:视频帧的最小像素数。
VIDEO_MAX_PIXELS = 768 * 28 * 28
:视频帧的最大像素数。
FRAME_FACTOR = 2
:视频帧数需要是此因数的倍数。
FPS = 2.0
:默认的视频采样帧率。
FPS_MIN_FRAMES = 4
:视频采样的最小帧数。
FPS_MAX_FRAMES = 768
:视频采样的最大帧数。
VIDEO_TOTAL_PIXELS
:从环境变量 VIDEO_MAX_PIXELS
中获取视频的总像素数,如果未设置,则默认使用 128000 * 28 * 28 * 0.9
,并将其转换为整数。这是对视频输入尺寸的限制。
extract_vision_info
函数功能:
extract_vision_info
函数用于从对话内容(conversations
)中提取所有与视觉相关的信息(如图像或视频),并将这些信息以列表的形式返回。
参数:
conversations
: 类型为 list[dict]
或 list[list[dict]]
,表示对话的列表。每个对话可以是一个包含消息字典的列表,或者直接是消息字典。返回值:
vision_infos
: 类型为 list[dict]
,包含所有提取的视觉信息的字典。代码解析:
初始化空列表 vision_infos
:
vision_infos = []
确保 conversations
是列表的列表格式:
if isinstance(conversations[0], dict):
conversations = [conversations]
conversations
的第一个元素是字典,说明传入的是单个对话,而不是对话的列表。为了统一处理,将其包装成列表的列表形式。遍历每个对话和消息:
for conversation in conversations:
for message in conversation:
检查消息的内容是否为列表:
if isinstance(message["content"], list):
提取视觉信息:
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele["type"] in ("image", "image_url", "video")
):
vision_infos.append(ele)
ele
。"image"
、"image_url"
或 "video"
键,或者其类型(ele["type"]
)是 "image"
、"image_url"
或 "video"
,则将该元素添加到 vision_infos
列表中。返回提取的视觉信息列表:
return vision_infos
process_vision_info
函数功能:
process_vision_info
函数用于处理从对话内容中提取的视觉信息,包括读取和处理图像和视频数据,最终返回处理后的结果。
参数:
conversations
: 类型为 list[dict]
或 list[list[dict]]
,表示对话的列表。return_video_kwargs
: 类型为 bool
,默认为 False
。如果为 True
,则在返回值中包含视频的额外参数(如帧率)。返回值:
根据 return_video_kwargs
的值,返回不同的内容:
如果 return_video_kwargs
为 False
:
(image_inputs, video_inputs)
image_inputs
: 处理后的图像列表(list[Image.Image]
),如果没有图像,则为 None
。video_inputs
: 处理后的视频列表(list[torch.Tensor]
或 list[list[Image.Image]]
),如果没有视频,则为 None
。如果 return_video_kwargs
为 True
:
(image_inputs, video_inputs, {'fps': video_sample_fps_list})
代码解析:
提取视觉信息:
vision_infos = extract_vision_info(conversations)
extract_vision_info
函数,从对话中提取所有的视觉信息,得到 vision_infos
列表。初始化存储变量:
image_inputs = []
video_inputs = []
video_sample_fps_list = []
image_inputs
: 用于存储处理后的图像数据。video_inputs
: 用于存储处理后的视频数据。video_sample_fps_list
: 用于存储每个视频的采样帧率。处理每个视觉信息:
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
elif "video" in vision_info:
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
video_sample_fps_list.append(video_sample_fps)
video_inputs.append(video_input)
else:
raise ValueError("image, image_url or video should in content.")
遍历 vision_infos
列表,对每个视觉信息进行处理。
处理图像:
vision_info
中包含 "image"
或 "image_url"
键,调用 fetch_image
函数处理图像。image_inputs
列表中。处理视频:
vision_info
中包含 "video"
键,调用 fetch_video
函数处理视频,参数 return_video_sample_fps=True
表示需要返回视频的采样帧率。video_input
和视频帧率 video_sample_fps
。video_inputs
列表,将帧率添加到 video_sample_fps_list
列表。异常处理:
ValueError
,提示内容中应包含 "image"
、"image_url"
或 "video"
。处理可能的空列表:
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
image_inputs
或 video_inputs
列表为空,则将其设置为 None
。根据参数返回结果:
if return_video_kwargs:
return image_inputs, video_inputs, {'fps': video_sample_fps_list}
return image_inputs, video_inputs
return_video_kwargs
为 True
,则返回包含视频帧率信息的字典。False
(默认情形),则只返回图像和视频数据。示例:
假设有如下对话内容:
conversations = [
# 第一个对话
[
{'role': 'user', 'content': [
{'type': 'text', 'data': '请查看这张图片。'},
{'type': 'image', 'image_url': 'http://example.com/image1.jpg'}
]},
{'role': 'assistant', 'content': '好的,我正在查看。'}
],
# 第二个对话
[
{'role': 'user', 'content': [
{'type': 'text', 'data': '这是一个视频。'},
{'type': 'video', 'video': 'http://example.com/video1.mp4'}
]},
{'role': 'assistant', 'content': '我正在处理视频。'}
]
]
调用 process_vision_info(conversations)
:
提取视觉信息:
extract_vision_info
函数遍历对话,找到包含视觉信息的元素。vision_infos
列表,包含两个元素:
{'type': 'image', 'image_url': 'http://example.com/image1.jpg'}
{'type': 'video', 'video': 'http://example.com/video1.mp4'}
处理视觉信息:
fetch_image
读取并处理图像,结果添加到 image_inputs
列表。fetch_video
读取并处理视频,结果添加到 video_inputs
列表,同时帧率添加到 video_sample_fps_list
。返回结果:
image_inputs
是一个包含处理后图像的列表。video_inputs
是一个包含处理后视频的列表。return_video_kwargs
为 True
,还会返回视频帧率信息。tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]
用于描述函数的返回值类型。让我们逐步解析这一复杂的类型注解,理解每个部分的含义。
tuple[...]
tuple[...]
表示一个元组类型,元组中的每个元素的位置和类型都是固定的。list[Image.Image] | None
list[Image.Image]
:表示一个 Image.Image
对象(来自 PIL 库)的列表,即图像对象的列表。| None
:符号 |
表示类型的联合(Union),即该值可以是前面的类型或后面的类型。list[Image.Image] | None
:表示该元素要么是一个 Image.Image
对象的列表,要么是 None
。解释:函数可能返回一个包含图像的列表,如果没有图像,则返回 None
。
list[torch.Tensor | list[Image.Image]] | None
torch.Tensor | list[Image.Image]
:
torch.Tensor
:表示一个 PyTorch 的张量,一般用于表示视频数据(如视频帧序列)。list[Image.Image]
:表示一个 Image.Image
对象的列表,即图像对象的列表。torch.Tensor | list[Image.Image]
:表示该元素可以是 torch.Tensor
或者 list[Image.Image]
。list[...]
:表示上述类型的列表,即列表中的每个元素可以是 torch.Tensor
或 list[Image.Image]
。| None
:表示该值也可以是 None
。list[torch.Tensor | list[Image.Image]] | None
:表示该元素要么是一个列表,列表中的每个元素是 torch.Tensor
或 list[Image.Image]
,要么是 None
。解释:函数可能返回一个视频数据的列表,如果没有视频,则返回 None
。
Optional[dict]
Optional[dict]
:Optional
是 typing
模块中的一个泛型类型,用于表示可选类型,即类型可以是指定的类型或 None
。Optional[dict]
等价于 dict | None
。解释:函数可能返回一个字典(如视频的额外参数),如果没有额外参数,则返回 None
。
示例
假设:
Image.Image
对象的列表。torch.Tensor
,表示视频帧数据。fps
。返回值可能是:
(
[image1, image2], # list[Image.Image]
[video_tensor], # list[torch.Tensor]
{'fps': [video_fps_value]} # dict
)
或者,如果没有图像,只有视频:
(
None, # 没有图像
[video_tensor], # list[torch.Tensor]
{'fps': [video_fps_value]} # dict
)
或者,如果只有图像,没有视频:
(
[image1, image2], # list[Image.Image]
None, # 没有视频
None # 没有额外参数
)
round_by_factor(number: int, factor: int) -> int
功能:将给定的数字 number
调整为最接近的、能被 factor
整除的整数。
实现:
def round_by_factor(number: int, factor: int) -> int:
return round(number / factor) * factor
ceil_by_factor(number: int, factor: int) -> int
功能:将给定的数字 number
调整为大于或等于它的、能被 factor
整除的最小整数。
实现:
def ceil_by_factor(number: int, factor: int) -> int:
return math.ceil(number / factor) * factor
floor_by_factor(number: int, factor: int) -> int
功能:将给定的数字 number
调整为小于或等于它的、能被 factor
整除的最大整数。
实现:
def floor_by_factor(number: int, factor: int) -> int:
return math.floor(number / factor) * factor
smart_resize(...) -> tuple[int, int]
功能:根据给定的高度和宽度,智能地调整图像尺寸,使其满足以下条件:
factor
整除。min_pixels
和 max_pixels
之间。参数:
height
: 原始高度。width
: 原始宽度。factor
: 因数,默认为 IMAGE_FACTOR
(28)。min_pixels
: 最小像素数,默认为 MIN_PIXELS
。max_pixels
: 最大像素数,默认为 MAX_PIXELS
。实现:
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
# 检查宽高比是否过大
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
# 调整高度和宽度,使其能被factor整除
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
# 调整像素数在指定范围内
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
解释:
MAX_RATIO
,如果超过,则抛出错误,防止图像过于拉伸或压缩。round_by_factor
将高度和宽度调整为最接近的、能被 factor
整除的值,且不小于 factor
。max_pixels
和 min_pixels
的关系,调整高度和宽度:
max_pixels
,则计算一个缩放系数 beta
,通过 floor_by_factor
函数减少高度和宽度。min_pixels
,则计算一个放大系数 beta
,通过 ceil_by_factor
函数增加高度和宽度。to_rgb(pil_image: Image.Image) -> Image.Image
功能:将给定的 PIL 图像对象转换为 RGB 模式。如果图像是带有透明度的 RGBA 模式,则将其转换为 RGB 模式,并填充白色背景。
实现:
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == 'RGBA':
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # 使用alpha通道作为掩码
return white_background
else:
return pil_image.convert("RGB")
解释:
RGBA
模式,表示图像带有透明度通道,需要将透明部分填充为白色。white_background
,大小与原图相同。paste
方法,将原始图像粘贴到白色背景上,使用 alpha 通道作为掩码,以保留透明度信息。RGB
模式并返回。fetch_image(ele: dict, size_factor: int = IMAGE_FACTOR) -> Image.Image
功能:根据给定的图像信息,从多种来源(如 URL、本地路径、Base64 编码、PIL.Image 对象)获取图像,并进行预处理,包括转换为 RGB 模式和调整尺寸。
参数:
ele
: 包含图像信息的字典。size_factor
: 调整尺寸的因数,默认为 IMAGE_FACTOR
(28)。实现:
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
# 获取图像数据
if "image" in ele:
image = ele["image"]
else:
image = ele["image_url"]
image_obj = None
# 根据图像数据类型进行处理
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
response = requests.get(image, stream=True)
image_obj = Image.open(BytesIO(response.content))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
# 转换为RGB模式
image = to_rgb(image_obj)
## 调整尺寸
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=size_factor,
)
else:
width, height = image.size
min_pixels = ele.get("min_pixels", MIN_PIXELS)
max_pixels = ele.get("max_pixels", MAX_PIXELS)
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
解释:
获取图像数据:
ele
中获取图像信息,优先使用键 "image"
,否则使用 "image_url"
。image_obj
为 None
。根据图像数据的类型进行处理:
image
是一个 Image.Image
对象,直接赋值给 image_obj
。image
是以 "http://"
或 "https://"
开头的字符串,表示是网络 URL:
requests
库获取图像内容。Image.open
读取图像。image
是以 "file://"
开头的字符串,表示是本地文件路径:
"file://"
,然后使用 Image.open
读取图像。image
以 "data:image"
开头,表示是 Base64 编码的图像数据:
Image.open
读取图像。image
是本地文件路径,直接使用 Image.open
读取。检查图像是否成功读取:
image_obj
仍为 None
,则抛出错误,提示无法识别的图像输入格式。转换为 RGB 模式:
to_rgb
函数,将图像转换为 RGB 模式,处理透明度问题。调整图像尺寸:
ele
中提供了 "resized_height"
和 "resized_width"
,则使用这些值进行尺寸调整,调用 smart_resize
函数。min_pixels
和 max_pixels
(如果未提供,则使用默认值)。smart_resize
函数,根据原始尺寸、因数和像素范围,计算新的高度和宽度。image.resize
方法调整图像尺寸。返回处理后的图像:
smart_nframes
def smart_nframes(
ele: dict,
total_frames: int,
video_fps: int | float,
) -> int:
...
功能:
smart_nframes
函数用于计算用于模型输入的视频帧数,确保帧数满足一定的条件和限制。
参数:
ele
: 包含视频配置信息的字典,支持以下键:
nframes
: 希望提取的帧数。fps
: 希望以多少帧率来提取帧。min_frames
: 当使用 fps
时,指定最小帧数。max_frames
: 当使用 fps
时,指定最大帧数。total_frames
: 视频的总帧数。video_fps
: 视频的原始帧率。流程:
检查冲突参数:
assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
这一步确保 ele
字典中不能同时既有 fps
又有 nframes
,否则抛出断言错误。
根据配置计算帧数:
如果提供了 nframes
:
if "nframes" in ele:
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
使用 round_by_factor
函数将 nframes
四舍五入到最近的 FRAME_FACTOR
的倍数,确保帧数是特定因子的整数倍。
如果提供了 fps
:
else:
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
nframes = total_frames / video_fps * fps
if nframes > total_frames:
logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
nframes = floor_by_factor(nframes, FRAME_FACTOR)
fps
,如果未提供则使用默认值 FPS
。min_frames
和 max_frames
,确保它们是 FRAME_FACTOR
的倍数。nframes
。nframes
超过了总帧数。nframes
限制在 min_frames
和 max_frames
之间,并确保不超过总帧数。floor_by_factor
将 nframes
向下取整到最近的 FRAME_FACTOR
的倍数。验证帧数是否合理:
if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
确保计算出的 nframes
在有效范围内,否则抛出 ValueError
。
返回计算的帧数:
return nframes
_read_video_torchvision
def _read_video_torchvision(
ele: dict,
) -> (torch.Tensor, float):
...
功能:
使用 torchvision
库的 io.read_video
函数读取视频文件,并返回视频帧的张量和采样后的帧率。
参数:
ele
: 包含视频配置信息的字典,支持以下键:
video
: 视频路径,支持本地路径、file://
、http://
、https://
。video_start
: 视频起始时间(秒)。video_end
: 视频结束时间(秒)。流程:
处理视频路径:
video_path = ele["video"]
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
if "http://" in video_path or "https://" in video_path:
warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
if "file://" in video_path:
video_path = video_path[7:]
如果 torchvision
版本低于 0.19.0:
http://
或 https://
读取视频,提示用户升级。file://
开头,去掉前面的 file://
。读取视频:
st = time.time()
video, audio, info = io.read_video(
video_path,
start_pts=ele.get("video_start", 0.0),
end_pts=ele.get("video_end", None),
pts_unit="sec",
output_format="TCHW",
)
使用 io.read_video
读取视频,指定起始和结束时间,输出格式为 (T, C, H, W)
,即帧数、通道数、高度、宽度。
获取视频信息:
total_frames, video_fps = video.size(0), info["video_fps"]
logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
获取视频的总帧数和原始帧率,记录读取时间。
计算需要的帧数:
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
调用之前的 smart_nframes
函数计算需要的帧数。
从视频中采样帧:
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
video = video[idx]
torch.linspace
生成一个索引列表,从视频帧中均匀采样 nframes
帧。sample_fps
。返回视频张量和采样帧率:
return video, sample_fps
is_decord_available
def is_decord_available() -> bool:
import importlib.util
return importlib.util.find_spec("decord") is not None
功能:
检查 decord
库是否可用。
流程:
importlib.util.find_spec("decord")
检查是否可以找到 decord
模块的规格(spec)。True
,否则返回 False
。_read_video_decord
def _read_video_decord(
ele: dict,
) -> (torch.Tensor, float):
...
功能:
使用 decord
库的 VideoReader
读取视频文件,并返回视频帧的张量和采样后的帧率。
参数:
ele
: 包含视频配置信息的字典,支持以下键:
video
: 视频路径,支持本地路径、file://
、http://
、https://
。video_start
: 视频起始时间(暂不支持)。video_end
: 视频结束时间(暂不支持)。流程:
导入 decord
库:
import decord
处理视频路径:
video_path = ele["video"]
st = time.time()
获取视频路径,记录开始时间。
创建 VideoReader
实例:
vr = decord.VideoReader(video_path)
使用 decord
的 VideoReader
读取视频。
暂不支持起始和结束时间:
if 'video_start' in ele or 'video_end' in ele:
raise NotImplementedError("not support start_pts and end_pts in decord for now.")
目前暂不支持通过 decord
指定起始和结束时间,如果发现有这样的参数,抛出 NotImplementedError
。
获取视频信息:
total_frames, video_fps = len(vr), vr.get_avg_fps()
logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
获取视频的总帧数和平均帧率,记录读取时间。
计算需要的帧数:
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
调用 smart_nframes
计算需要的帧数。
从视频中采样帧:
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
torch.linspace
生成索引列表,均匀采样 nframes
帧。vr.get_batch(idx)
获取对应帧,转换为 NumPy 数组。(T, C, H, W)
。返回视频张量和采样帧率:
return video, sample_fps
get_video_reader_backend
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
elif is_decord_available():
video_reader_backend = "decord"
else:
video_reader_backend = "torchvision"
print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
return video_reader_backend
功能:
根据环境变量或库的可用性,确定使用哪个视频读取后端。
流程:
检查环境变量:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
如果环境变量 FORCE_QWENVL_VIDEO_READER
被设置,则强制使用该后端。
检查 decord
库是否可用:
elif is_decord_available():
video_reader_backend = "decord"
如果 decord
库可用,则使用 decord
。
默认使用 torchvision
:
else:
video_reader_backend = "torchvision"
如果不满足上述条件,默认使用 torchvision
。
输出使用的后端信息并返回:
print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
return video_reader_backend
打印使用的后端信息,返回后端名称。
注解:
@lru_cache(maxsize=1)
装饰器,表示函数的返回值会被缓存,当再次调用时直接返回缓存值,避免重复计算。fetch_video
def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
...
功能:
根据提供的配置,获取并处理视频数据,返回适用于模型输入的视频张量或图像列表。
参数:
ele
: 包含视频配置信息的字典,支持以下键:
video
: 视频路径,或包含一系列图像的列表。min_pixels
、max_pixels
、resized_height
、resized_width
等,用于调整视频尺寸。image_factor
: 调整尺寸时使用的因子,默认值为 IMAGE_FACTOR
。return_video_sample_fps
: 是否返回采样后的帧率,布尔值。流程:
判断 ele["video"]
的类型:
if isinstance(ele["video"], str):
...
else:
...
处理视频文件:
video_reader_backend = get_video_reader_backend()
try:
video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
except Exception as e:
logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
get_video_reader_backend()
确定后端,然后调用对应的读取函数获取视频张量和采样帧率。torchvision
读取视频。获取视频尺寸信息和像素限制:
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
max_pixels_supposed = ele.get("max_pixels", max_pixels)
if max_pixels_supposed > max_pixels:
logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
max_pixels = min(max_pixels_supposed, max_pixels)
min_pixels
和 max_pixels
,以限制视频的总像素数,避免内存占用过大。调整视频帧尺寸:
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=image_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=image_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
resized_height
和 resized_width
,则使用这些值进行尺寸调整。smart_resize
根据原始尺寸和像素限制计算新的高度和宽度。transforms.functional.resize
调整视频帧尺寸。返回结果:
if return_video_sample_fps:
return video, sample_fps
return video
(video, sample_fps)
。处理帧图像列表:
else:
assert isinstance(ele["video"], (list, tuple))
process_info = ele.copy()
process_info.pop("type", None)
process_info.pop("video", None)
images = [
fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
for video_element in ele["video"]
]
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
if len(images) < nframes:
images.extend([images[-1]] * (nframes - len(images)))
if return_video_sample_fps:
return images, process_info.pop("fps", 2.0)
return images
ele["video"]
是一个图像列表,遍历每一帧图像,调用 fetch_image
处理。FRAME_FACTOR
的倍数,不足的话用最后一帧填充。Decord 是一个专为深度学习和视频处理设计的高性能视频读取库。它旨在提供高效、简洁、易用的视频数据加载接口,方便在深度学习模型中使用视频数据。
import decord
from decord import VideoReader
decord.bridge.set_bridge('torch') # 设置与 PyTorch 兼容的桥接
# 创建视频读取器
vr = VideoReader('path/to/your/video.mp4')
# 获取视频的总帧数
total_frames = len(vr)
# 读取特定帧,例如第10帧
frame_10 = vr[9] # 索引从0开始
# 批量读取帧
indices = [0, 5, 10, 15, 20]
frames = vr.get_batch(indices) # 返回指定帧的批量数据
Torchvision 是 PyTorch 官方的计算机视觉工具包,提供了常用的数据集、模型和图像视频处理工具。它是 PyTorch 生态系统中处理视觉数据的核心库。
torchvision.datasets
: 提供常用的计算机视觉数据集,如 MNIST、CIFAR10、ImageNet 等的下载和加载接口。torchvision.models
: 包含预训练的深度学习模型,如 ResNet、AlexNet、VGG 等,可用于迁移学习和特征提取。torchvision.transforms
: 提供一系列图像预处理和数据增强的方法,如裁剪、缩放、翻转、归一化等。torchvision.io
: 提供读取和写入图像、视频数据的接口,包括 read_image
、read_video
等方法。图像处理:
from torchvision import transforms
from PIL import Image
# 定义图像转换方法
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(), # 将图像转换为张量,并将像素值归一化到 [0,1]
transforms.Normalize(mean=[0.485, 0.456, 0.406], # 标准化
std=[0.229, 0.224, 0.225])
])
# 加载和处理图像
image = Image.open('path/to/your/image.jpg')
image_tensor = transform(image)
视频处理:
import torchvision.io as io
# 读取视频
video_path = 'path/to/your/video.mp4'
video, audio, info = io.read_video(video_path, pts_unit='sec')
# video 是形状为 [T, H, W, C] 的张量,T 是帧数
# 可以进行帧采样或其他处理
!