学习 ChatGLM3 的项目内容,过程中使用 AI 代码工具,对代码进行解释,帮助自己快速理解代码。这篇文章记录 ChatGLM3 tool_registry.py 的代码解析内容。
from copy import deepcopy
import inspect
from pprint import pformat
import traceback
from types import GenericAlias
from typing import get_origin, Annotated
_TOOL_HOOKS = {}
_TOOL_DESCRIPTIONS = {}
这段代码定义了几个全局变量和导入了一些模块。让我来逐个解释:
from copy import deepcopy:从 copy 模块导入 deepcopy 函数,用于深拷贝对象。
import inspect:导入 inspect 模块,用于获取对象的信息。
from pprint import pformat:从 pprint 模块导入 pformat 函数,用于格式化打印对象。
import traceback:导入 traceback 模块,用于打印异常堆栈信息。
from types import GenericAlias:从 types 模块导入 GenericAlias 类,用于表示泛型类型。
from typing import get_origin, Annotated:从 typing 模块导入 get_origin 和 Annotated 函数,用于获取泛型类型的原始类型和注解信息。
_TOOL_HOOKS = {}:定义一个空的全局字典变量 _TOOL_HOOKS,用于存储工具的钩子函数。
_TOOL_DESCRIPTIONS = {}:定义一个空的全局字典变量 _TOOL_DESCRIPTIONS,用于存储工具的描述信息。
这段代码的作用可能是为后续的工具注册和存储钩子函数以及描述信息提供了一个全局的数据结构。
def register_tool(func: callable):
tool_name = func.__name__
tool_description = inspect.getdoc(func).strip()
python_params = inspect.signature(func).parameters
tool_params = []
for name, param in python_params.items():
annotation = param.annotation
if annotation is inspect.Parameter.empty:
raise TypeError(f"Parameter `{name}` missing type annotation")
if get_origin(annotation) != Annotated:
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
typ, (description, required) = annotation.__origin__, annotation.__metadata__
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
if not isinstance(description, str):
raise TypeError(f"Description for `{name}` must be a string")
if not isinstance(required, bool):
raise TypeError(f"Required for `{name}` must be a bool")
tool_params.append({
"name": name,
"description": description,
"type": typ,
"required": required
})
tool_def = {
"name": tool_name,
"description": tool_description,
"params": tool_params
}
print("[registered tool] " + pformat(tool_def))
_TOOL_HOOKS[tool_name] = func
_TOOL_DESCRIPTIONS[tool_name] = tool_def
return func
这段代码定义了一个名为 register_tool 的函数,该函数接受一个可调用对象 func 作为参数。
以下是代码的详细解析:
def dispatch_tool(tool_name: str, tool_params: dict) -> str:
if tool_name not in _TOOL_HOOKS:
return f"Tool `{tool_name}` not found. Please use a provided tool."
tool_call = _TOOL_HOOKS[tool_name]
try:
ret = tool_call(**tool_params)
except:
ret = traceback.format_exc()
return str(ret)
def get_tools() -> dict:
return deepcopy(_TOOL_DESCRIPTIONS)
这段代码定义了两个函数:dispatch_tool 和 get_tools。让我为你逐个解释:
函数 dispatch_tool(tool_name: str, tool_params: dict) -> str:
该函数接受两个参数 tool_name 和 tool_params,并返回一个字符串。
函数 get_tools() -> dict:
该函数不接受任何参数,返回一个字典。
这两个函数一起提供了工具的调度和获取工具信息的功能。dispatch_tool 函数用于调用具体的工具函数,而 get_tools 函数用于获取所有已注册工具的描述信息。
deepcopy: deepcopy 是一个函数,用于创建一个对象的深拷贝。深拷贝是指创建一个新对象,将原始对象的所有元素递归地复制到新对象中,包括嵌套的对象。换句话说,它会创建一个原始对象的完全独立副本,而不仅仅是引用原始对象的内存地址。
深拷贝对于需要完全独立的副本的情况非常有用,尤其是在处理可变对象时。通过深拷贝,可以确保修改一个对象的副本不会影响到原始对象,因为它们是相互独立的。
例如,假设有一个包含嵌套列表和字典的对象 obj,如果直接对 obj 进行赋值操作,那么新对象将只是原始对象的引用,而不是副本。这意味着对新对象的修改也会反映到原始对象中。但是,如果使用 deepcopy 函数创建一个新对象 new_obj,那么 new_obj 将是 obj 的深拷贝副本,对 new_obj 的修改不会影响到 obj。
@register_tool
def random_number_generator(
seed: Annotated[int, 'The random seed used by the generator', True],
range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
) -> int:
"""
Generates a random number x, s.t. range[0] <= x < range[1]
"""
if not isinstance(seed, int):
raise TypeError("Seed must be an integer")
if not isinstance(range, tuple):
raise TypeError("Range must be a tuple")
if not isinstance(range[0], int) or not isinstance(range[1], int):
raise TypeError("Range must be a tuple of integers")
import random
return random.Random(seed).randint(*range)
这段代码定义了一个名为 random_number_generator 的函数,并使用 @register_tool 装饰器将其注册为一个工具。
函数接受两个参数 seed 和 range,并返回一个整数。下面是对代码的详细解释:
@register_tool:@ 符号是装饰器语法,用于在函数定义之前修饰函数。@register_tool 表示将该函数注册为一个工具。具体工具注册的逻辑在你提供的代码中没有呈现,可以在其他地方找到。
def random_number_generator(seed: Annotated[int, ‘The random seed used by the generator’, True], range: Annotated[tuple[int, int], ‘The range of the generated numbers’, True]) -> int::这是函数的定义部分。函数名为 random_number_generator,接受两个参数 seed 和 range,并指定返回类型为整数。
“”" Generates a random number x, s.t. range[0] <= x < range[1] “”":这是函数的文档字符串(docstring),用于描述函数的功能。根据文档字符串的描述,该函数生成一个介于 range[0] 和 range[1] 之间的随机整数 x。
参数验证部分:在函数体内部,对传入的参数进行验证,确保它们具有正确的类型和值。
import random:导入 Python 标准库中的 random 模块,用于生成随机数。
return random.Random(seed).randint(*range):使用 random 模块生成一个随机整数,并将其作为函数的返回值。random.Random(seed) 创建了一个具有指定种子 seed 的随机数生成器对象,然后使用 randint(*range) 方法生成介于 range[0] 和 range[1] 之间的随机整数。
总之,这段代码定义了一个将参数验证和随机数生成结合在一起的函数。它使用装饰器将函数注册为一个工具,并在调用时生成指定范围内的随机整数。
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the current weather for `city_name`
"""
if not isinstance(city_name, str):
raise TypeError("City name must be a string")
key_selection = {
"current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
}
import requests
try:
resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
resp.raise_for_status()
resp = resp.json()
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
except:
import traceback
ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
return str(ret)
这段代码定义了一个名为 get_weather 的函数,并使用 @register_tool 装饰器将其注册为一个工具。
函数接受一个参数 city_name,并返回一个字符串。下面是对代码的详细解释:
@register_tool:@ 符号是装饰器语法,用于在函数定义之前修饰函数。@register_tool 表示将该函数注册为一个工具。具体工具注册的逻辑在你提供的代码中没有呈现,可以在其他地方找到。
def get_weather(city_name: Annotated[str, ‘The name of the city to be queried’, True]) -> str::这是函数的定义部分。函数名为 get_weather,接受一个 city_name 参数,指定返回类型为字符串。
“”" Get the current weather for city_name “”":这是函数的文档字符串(docstring),用于描述函数的功能。根据文档字符串的描述,该函数用于获取指定城市的当前天气情况。
参数验证部分:在函数体内部,对传入的参数进行验证,确保它们具有正确的类型和值。
key_selection = {…}:定义了一个字典变量 key_selection,用于存储需要从 API 响应中提取的天气信息的键值选择。该字典的键代表不同的天气信息,而对应的值是一个列表,包含了该天气信息所对应的子键。
import requests:导入 Python 第三方库 requests,用于发送 HTTP 请求。
try::尝试执行一段代码,并捕获可能的异常。
except::捕获可能的异常。
return str(ret):返回结果,将结果转换为字符串类型后返回。
总之,这段代码定义了一个用于获取指定城市天气的函数。它使用 requests 库发送 HTTP 请求获取天气数据,并从响应中提取指定的天气信息。如果发生任何异常,它会将错误提示信息和堆栈信息返回。
请注意,这段代码中的 @register_tool 装饰器和 requests 库是额外的依赖项,你可能需要在其他地方找到这些实现或库的定义。
完结!