工作太忙，导致完全没有时间学习了。国庆期间，抽空找了个 baseline 继续学习一波。
参考 notebook（Open Book LLM Science Exam）：https://www.kaggle.com/code/jjinho/open-book-llm-science-exam/notebook
# Load the competition's training questions.
trn = pd.read_csv("/kaggle/input/kaggle-llm-science-exam/train.csv")

# Sentence-embedding model: placed on CUDA, inputs truncated to MAX_LENGTH
# tokens, weights cast to fp16.
model = SentenceTransformer(MODEL, device='cuda')
model.max_seq_length = MAX_LENGTH
model = model.half()

# Pre-built vector index over the 2023-07 Wikipedia dump (FAISS, per its path).
sentence_index = read_index("/kaggle/input/wikipedia-2023-07-faiss-index/wikipedia_202307.index")

# Embed every training-set prompt (L2-normalised), keep it in fp16, and move
# it to host memory as a NumPy array for the index search below.
prompt_embeddings = (
    model.encode(
        trn.prompt.values,
        batch_size=BATCH_SIZE,
        device=DEVICE,
        show_progress_bar=True,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    .half()
    .detach()
    .cpu()
    .numpy()
)

# Retrieve the 3 nearest index entries for each prompt: search_score holds
# their scores and search_index their row positions, one row per prompt.
search_score, search_index = sentence_index.search(prompt_embeddings, 3)
# Lookup table from index row position to Wikipedia article id and the
# parquet file (under WIKI_PATH) that holds the article's text.
# NOTE(review): assumes df's index equals the FAISS row positions (e.g. a
# default RangeIndex) so search_index values can be used as .loc labels.
df = pd.read_parquet("/kaggle/input/wikipedia-20230701/wiki_2023_index.parquet", columns=['id', 'file'])

# One small frame per prompt holding its retrieved rows, tagged with the
# prompt's position in the training set.
# All top-k hits are kept. To filter by score instead, select e.g.
# idx[scr <= 0.85] using the matching row of search_score.
wikipedia_file_data = pd.concat(
    [
        df.loc[idx].assign(prompt_id=i)
        for i, idx in tqdm(enumerate(search_index), total=len(search_index))
    ]
).reset_index(drop=True)

# Keep unique (article, prompt) pairs, sorted by file so that the next step
# reads each parquet file only once.
wikipedia_file_data = (
    wikipedia_file_data[['id', 'prompt_id', 'file']]
    .drop_duplicates()
    .sort_values(['file', 'id'])
    .reset_index(drop=True)
)
# Load the text of every retrieved article. Grouping the (file, id) table by
# file visits each parquet file once, in ascending file order, without
# re-scanning the whole table for every file.
wiki_text_data = []
for file, _id in tqdm(
    wikipedia_file_data.groupby('file')['id'],
    total=wikipedia_file_data['file'].nunique(),
):
    # Ids are cast to str, presumably because the text files store 'id' as
    # strings — confirm against the parquet schema.
    _id = {str(i) for i in _id}
    _df = pd.read_parquet(f"{WIKI_PATH}/{file}", columns=['id', 'text'])
    # Rebinding _df drops the reference to the full file, keeping only the
    # matching rows alive.
    _df = _df[_df['id'].isin(_id)]
    wiki_text_data.append(_df)
    _ = gc.collect()
wiki_text_data = pd.concat(wiki_text_data).drop_duplicates().reset_index(drop=True)
_ = gc.collect()
# Embed the Wikipedia passages with the same fp16 model and L2 normalisation
# as the prompts, then copy the result to host memory as a NumPy array.
# NOTE(review): `processed_wiki_text_data` is not defined anywhere in this
# excerpt (only `wiki_text_data` is). It presumably comes from an omitted
# document-splitting step — confirm before running.
wiki_data_embeddings = (
    model.encode(
        processed_wiki_text_data.text,
        batch_size=BATCH_SIZE,
        device=DEVICE,
        show_progress_bar=True,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    .half()
    .detach()
    .cpu()
    .numpy()
)
至此得到经 Wikipedia 检索增强的数据。