Recommendation using embeddings and nearest neighbor search (OpenAI Cookbook case study)

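  • Study notes on this example from the OpenAI Cookbook:
  • openai-cookbook/examples/Recommendation_using_embeddings.ipynb at 45c6406e8bb42e502d0394a9f1d217e5494ba4a2 · openai/openai-cookbook (github.com)

    Contents

    • Recommendation using embeddings and nearest neighbor search (OpenAI Cookbook case study)
    • Importing libraries
    • Loading and exploring the data
    • Building a cache to store embeddings
    • Recommending similar articles based on embeddings
    • Example recommendations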

Importing libraries



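# Import the required libraries
import pandas as pd  # data loading and analysis
import pickle  # serializing and deserializing Python objects

# Import helper functions from openai.embeddings_utils
# (this module ships with pre-1.0 releases of the openai package)
from openai.embeddings_utils import (
    get_embedding,  # fetch the embedding vector for a text
    distances_from_embeddings,  # compute distances between embedding vectors
    tsne_components_from_embeddings,  # compute a t-SNE projection of embedding vectors
    chart_from_components,  # plot the t-SNE components
    indices_of_nearest_neighbors_from_distances,  # sort indices by distance, nearest first
)

# Constant naming the text embedding model to use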
EMBEDDING_MODEL = "text-embedding-ada-002"

Loading and exploring the data

# Path to the dataset file
dataset_path = "data/AG_news_samples.csv"

# Read the dataset with pandas' read_csv and store the result in df
df = pd.read_csv(dataset_path)

# Number of dataset examples to display
n_examples = 5

# Show the first n_examples rows of the dataset with DataFrame.head
df.head(n_examples)

Building a cache to store embeddings

# Build an embedding cache to avoid recomputing embeddings
# The cache is a dict mapping (text, model) tuples to embedding vectors, saved as a pickle file

# Path to the embedding cache file
embedding_cache_path = "data/recommendations_embeddings_cache.pkl"

# Load the cache if the file exists; otherwise start with an empty dict
try:
    embedding_cache = pd.read_pickle(embedding_cache_path)
except FileNotFoundError:
    embedding_cache = {}

# Save the cache to disk
with open(embedding_cache_path, "wb") as embedding_cache_file:
    pickle.dump(embedding_cache, embedding_cache_file)

# Define a function that returns the embedding from the cache if present, and otherwise requests it from the API
def embedding_from_string(
    string: str,
    model: str = EMBEDDING_MODEL,
    embedding_cache=embedding_cache
) -> list:
    """Return the embedding of the given string, using the cache to avoid recomputation."""
    if (string, model) not in embedding_cache:
        # Cache miss: request the embedding, then persist the updated cache
        embedding_cache[(string, model)] = get_embedding(string, model)
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
    return embedding_cache[(string, model)]

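This code builds an embedding cache to avoid recomputing embeddings. The cache is a dictionary that maps (text, model) tuples to embedding vectors and is saved as a pickle file. The code first sets the path of the cache file and tries to load the cache from it; if the file does not exist, it starts from an empty dictionary. It then writes the cache to disk. Finally, it defines a function that returns the embedding from the cache when the key is present and otherwise requests it from the API. Any embedding obtained from the API is added to the cache, and the cache is saved to disk again.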
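To see the caching behavior without calling the API, here is a minimal sketch that follows the same load, look up, and save pattern. The names `make_cached_embedder` and `fake_embedding` are hypothetical and are not part of the cookbook; `fake_embedding` is a stand-in for `get_embedding` that records how often it is called.

```python
import os
import pickle
import tempfile


def make_cached_embedder(embed_fn, cache_path):
    """Wrap embed_fn so each (text, model) pair is computed at most once."""
    # Load an existing cache from disk, or start with an empty one
    try:
        with open(cache_path, "rb") as f:
            cache = pickle.load(f)
    except FileNotFoundError:
        cache = {}

    def cached(text, model="stub-model"):
        key = (text, model)
        if key not in cache:
            # Cache miss: compute the embedding and persist the whole cache
            cache[key] = embed_fn(text, model)
            with open(cache_path, "wb") as f:
                pickle.dump(cache, f)
        return cache[key]

    return cached


calls = []  # records every real (uncached) computation


def fake_embedding(text, model):
    """Hypothetical stand-in for an API call; returns a tiny 2-d vector."""
    calls.append(text)
    return [float(len(text)), float(text.count(" "))]


cache_file = os.path.join(tempfile.mkdtemp(), "cache.pkl")
embed = make_cached_embedder(fake_embedding, cache_file)
first = embed("hello world")   # computed and written to disk
second = embed("hello world")  # served from the in-memory cache

# A fresh wrapper simulates restarting the notebook: it reloads the pickle
embed_again = make_cached_embedder(fake_embedding, cache_file)
third = embed_again("hello world")  # served from the cache loaded from disk
```

Note that this pattern rewrites the entire pickle file on every cache miss, which is simple but may become slow as the cache grows.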

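Recommending similar articles based on embeddings

  • There are three main steps:

    • Get the embeddings of all the article descriptions
    • Compute the distance between the source article and all the other articles
    • Print the articles closest to the source article
  • The code is as follows
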

    def print_recommendations_from_strings(
        strings: list[str],  # list of all strings
        index_of_source_string: int,  # index of the source string in the list
        k_nearest_neighbors: int = 1,  # number of nearest neighbors to print (default 1)
        model=EMBEDDING_MODEL,  # embedding model to use (default EMBEDDING_MODEL)
    ) -> list[int]:
        """Print the k nearest neighbors of the given string."""
        # Get the embeddings of all strings
        embeddings = [embedding_from_string(string, model=model) for string in strings]
        # Get the embedding of the source string
        query_embedding = embeddings[index_of_source_string]
        # Get the distances between the source embedding and all embeddings (function from embeddings_utils.py)
        distances = distances_from_embeddings(query_embedding, embeddings, distance_metric="cosine")
        # Get the indices of the nearest neighbors (function from embeddings_utils.py)
        indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)
    
        # Print the source string
        query_string = strings[index_of_source_string]
        print(f"Source string: {query_string}")
        # Print its k nearest neighbors
        k_counter = 0
        for i in indices_of_nearest_neighbors:
            # Skip any string that exactly matches the source string
            if query_string == strings[i]:
                continue
            # Stop after printing k articles
            if k_counter >= k_nearest_neighbors:
                break
            k_counter += 1
    
            # Print the similar string and its distance
            print(
                f"""
            --- Recommendation #{k_counter} (nearest neighbor {k_counter} of {k_nearest_neighbors}) ---
            String: {strings[i]}
            Distance: {distances[i]:0.3f}"""
            )
    
        return indices_of_nearest_neighbors
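
The two helper calls in the function above can be pictured with a small pure-Python sketch. It assumes that the "cosine" metric means 1 minus cosine similarity and that the neighbor helper sorts indices by ascending distance; `cosine_distance` and `nearest_neighbor_indices` are hypothetical names used for illustration, not the library's implementation.

```python
import math


def cosine_distance(a, b):
    """Return 1 - cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def nearest_neighbor_indices(distances):
    """Return indices ordered from smallest to largest distance."""
    return sorted(range(len(distances)), key=lambda i: distances[i])


toy_embeddings = [
    [1.0, 0.0],  # index 0: the query itself
    [0.0, 1.0],  # orthogonal to the query
    [1.0, 1.0],  # 45 degrees away from the query
    [2.0, 0.0],  # same direction as the query, different length
]
query = toy_embeddings[0]
distances = [cosine_distance(query, e) for e in toy_embeddings]
order = nearest_neighbor_indices(distances)
print(order)
```

The query is always at distance 0 from itself, so it appears at the front of the ordering. This is why the function above skips any string equal to the source string before counting neighbors. Cosine distance also ignores vector length, so index 3 ties with the query at distance 0.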
    

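Example recommendations

# Convert all article descriptions to a list of strings
article_descriptions = df["description"].tolist()

# Use print_recommendations_from_strings to print the 5 articles most similar to the first article, which is about Tony Blair
tony_blair_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # base the similarity comparison on the article descriptions
    index_of_source_string=0,  # look at articles similar to the first one, about Tony Blair
    k_nearest_neighbors=5,  # look at the 5 most similar articles
)
  • The output is as follows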

    Source string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.
    
            --- Recommendation #1 (nearest neighbor 1 of 5) ---
            String: THE re-election of British Prime Minister Tony Blair would be seen as an endorsement of the military action in Iraq, Prime Minister John Howard said today.
            Distance: 0.153
    
            --- Recommendation #2 (nearest neighbor 2 of 5) ---
            String: LONDON, England -- A US scientist is reported to have observed a surprising jump in the amount of carbon dioxide, the main greenhouse gas.
            Distance: 0.160
    
            --- Recommendation #3 (nearest neighbor 3 of 5) ---
            String: The anguish of hostage Kenneth Bigley in Iraq hangs over Prime Minister Tony Blair today as he faces the twin test of a local election and a debate by his Labour Party about the divisive war.
            Distance: 0.160
    
            --- Recommendation #4 (nearest neighbor 4 of 5) ---
            String: Israel is prepared to back a Middle East conference convened by Tony Blair early next year despite having expressed fears that the British plans were over-ambitious and designed
            Distance: 0.171
    
            --- Recommendation #5 (nearest neighbor 5 of 5) ---
            String: AFP - A battle group of British troops rolled out of southern Iraq on a US-requested mission to deadlier areas near Baghdad, in a major political gamble for British Prime Minister Tony Blair.
            Distance: 0.173
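
If the recommendations are needed as data rather than printed text, the same selection loop can return a list instead. The sketch below is one possible variant, not part of the cookbook: `top_k_recommendations` is a hypothetical helper, and the strings and distances are toy values chosen for illustration, not real embedding results.

```python
def top_k_recommendations(strings, distances, source_index, k=1):
    """Return up to k (string, distance) pairs nearest to strings[source_index]."""
    source = strings[source_index]
    # Rank all indices from nearest to farthest
    ranked = sorted(range(len(strings)), key=lambda i: distances[i])
    results = []
    for i in ranked:
        # Skip the source itself and any exact duplicates of it
        if strings[i] == source:
            continue
        # Stop once k recommendations have been collected
        if len(results) >= k:
            break
        results.append((strings[i], distances[i]))
    return results


# Toy inputs: index 2 duplicates the source text and should be skipped
toy_strings = ["article A", "article B", "article A", "article C"]
toy_distances = [0.0, 0.25, 0.0, 0.10]
recs = top_k_recommendations(toy_strings, toy_distances, source_index=0, k=2)
print(recs)
```

Returning pairs keeps the ranking logic separate from presentation, so the caller can decide how to display or filter the results.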
    
