Open Book LLM Science Exam

Work has been too busy lately, leaving no time to study at all. Over the National Day holiday I found some time to work through a baseline.

https://www.kaggle.com/code/jjinho/open-book-llm-science-exam/notebook

  • First, the Wikipedia data is stored as a faiss index; for each prompt in the train set, retrieve the 3 most similar entries
# Imports (MODEL, MAX_LENGTH, BATCH_SIZE, DEVICE, WIKI_PATH are config constants defined earlier in the notebook)
import gc
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from faiss import read_index
from sentence_transformers import SentenceTransformer

trn = pd.read_csv("/kaggle/input/kaggle-llm-science-exam/train.csv")

model = SentenceTransformer(MODEL, device='cuda')
model.max_seq_length = MAX_LENGTH
model = model.half()

sentence_index = read_index("/kaggle/input/wikipedia-2023-07-faiss-index/wikipedia_202307.index")

# Encode the train-set prompts into embeddings
prompt_embeddings = model.encode(trn.prompt.values, batch_size=BATCH_SIZE, device=DEVICE, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True).half()
prompt_embeddings = prompt_embeddings.detach().cpu().numpy()

# Retrieve the top-3 most similar index entries for each prompt
search_score, search_index = sentence_index.search(prompt_embeddings, 3)
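
The index above is prebuilt and simply loaded from disk. As a minimal sketch (not the notebook's actual build code; the passage list and output path are placeholders), a flat inner-product faiss index over normalized passage embeddings could be built like this:

# Hypothetical sketch: build a flat inner-product index from normalized passage embeddings
import faiss

passages = ["..."]  # placeholder: list of Wikipedia passage strings
emb = model.encode(passages, normalize_embeddings=True, convert_to_numpy=True)
index = faiss.IndexFlatIP(emb.shape[1])  # inner product == cosine similarity on normalized vectors
index.add(emb.astype(np.float32))        # faiss expects float32
faiss.write_index(index, "wikipedia.index")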
  • Use the retrieved indices to find the corresponding Wikipedia files
# Mapping from Wikipedia article id to the parquet file that contains it
df = pd.read_parquet("/kaggle/input/wikipedia-20230701/wiki_2023_index.parquet", columns=['id', 'file'])

wikipedia_file_data = []

for i, (scr, idx) in tqdm(enumerate(zip(search_score, search_index)), total=len(search_score)):
    
    ## Get indices by score threshold
    #scr_idx = idx[np.where(scr <= 0.85)]
    scr_idx = idx
    _df = df.loc[scr_idx].copy()
    _df['prompt_id'] = i
    wikipedia_file_data.append(_df)
wikipedia_file_data = pd.concat(wikipedia_file_data).reset_index(drop=True)
wikipedia_file_data = wikipedia_file_data[['id', 'prompt_id', 'file']].drop_duplicates().sort_values(['file', 'id']).reset_index(drop=True)
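
At this point each prompt_id is mapped to (at most) three Wikipedia article ids plus the parquet shard that holds them. A quick sanity check of what was retrieved for the first question:

# Inspect the articles retrieved for the first prompt
print(wikipedia_file_data[wikipedia_file_data.prompt_id == 0])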
  • Load the article text from those files
wiki_text_data = []

for file in tqdm(wikipedia_file_data.file.unique(), total=len(wikipedia_file_data.file.unique())):
    _id = [str(i) for i in wikipedia_file_data[wikipedia_file_data['file']==file]['id'].tolist()]
    _df = pd.read_parquet(f"{WIKI_PATH}/{file}", columns=['id', 'text'])

    _df = _df[_df['id'].isin(_id)]
    wiki_text_data.append(_df)
    _ = gc.collect()
wiki_text_data = pd.concat(wiki_text_data).drop_duplicates().reset_index(drop=True)
_ = gc.collect()
  • Parse documents into sentences (see the sketch after this list)
  • Compute sentence embeddings for the parsed text
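
The sentence-parsing step itself is not shown in this excerpt. A minimal sketch (assuming blingfire's text_to_sentences for splitting; the notebook's exact filtering may differ) that produces the processed_wiki_text_data used below:

# Assumed sketch: split each retrieved article into sentences
from blingfire import text_to_sentences

rows = []
for _, row in wiki_text_data.iterrows():
    for sent in text_to_sentences(row['text']).split('\n'):
        sent = sent.strip()
        if len(sent) > 10:  # drop very short fragments
            rows.append({'id': row['id'], 'text': sent})
processed_wiki_text_data = pd.DataFrame(rows)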
wiki_data_embeddings = model.encode(processed_wiki_text_data.text, batch_size=BATCH_SIZE, device=DEVICE, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True).half()
wiki_data_embeddings = wiki_data_embeddings.detach().cpu().numpy()

This yields the Wikipedia-augmented data used as extra context for each question.
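
The final matching step is not included in this excerpt. As a simplified sketch (matching each question against all retrieved sentences by cosine similarity; the notebook's per-question logic and the top_k value are assumptions), the sentence embeddings can be used to attach a retrieved context to each training row:

# Simplified sketch: attach the top-k most similar Wikipedia sentences to each question
top_k = 5
sent_emb = wiki_data_embeddings.astype(np.float32)  # (num_sentences, dim), already normalized
q_emb = prompt_embeddings.astype(np.float32)        # (num_questions, dim), already normalized

contexts = []
for i in range(q_emb.shape[0]):
    sims = sent_emb @ q_emb[i]                      # cosine similarity (embeddings are normalized)
    best = np.argsort(-sims)[:top_k]
    contexts.append(" ".join(processed_wiki_text_data.text.iloc[best]))
trn['context'] = contexts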
