Pytorch torch.save() 保存特征向量

文章目录

  • 1 需求
  • 2 实现

1 需求

Pytorch torch.save() 保存特征向量_第1张图片

存取上述特征向量

2 实现

  • 数据结构: 使用list存储这些向量,[(r_emb, query), ...]
  • 工具: torch.save()tensor保存为.pth,存取对象是字典
"""
保存特征向量,推荐使用torch保存,直接保存为tensor
"""
import torch


def save_feature(feature_list, feature_path):
    feature = {

    }
    for i, (r_emb, query) in enumerate(feature_list):
        feature[f"r_emb_{i}"] = r_emb
        feature[f"query_{i}"] = query

    torch.save(feature, feature_path)
    pass

def load_feature(feature_path):
    feature = torch.load(feature_path)
    feature_list = []
    for i in range(len(feature.keys()) // 2):
        r_emb = feature[f"r_emb_{i}"]
        query = feature[f"query_{i}"]
        feature_list.append((r_emb, query))
        ...
    return feature_list
    ...

if __name__ == "__main__":
    r_emb_1 = torch.randn((32, 75, 512))
    query_1 = torch.randn((32, 22, 512))

    r_emb_2 = torch.randn((32, 75, 512))
    query_2 = torch.randn((32, 26, 512))

    feature_list = [(r_emb_1, query_1), (r_emb_2, query_2)]
    feature_path = "./save_feature.pth"

    # save_feature(feature_list, feature_path)
    feature = load_feature(feature_path)
    print("query_1 shape:", feature[0][1].shape)
    pass

在这里插入图片描述


你可能感兴趣的:(pytorch,深度学习,python)