欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/132602155
OpenFold Multimer 在训练过程的数据加载阶段,需要先将 MSA 与 Template 信息在线转换成 Feature 再进行训练,速度较慢。通过修改数据集类 `OpenFoldSingleMultimerDataset` 的 `__getitem__` 方法,直接加载预先处理好的特征,可以加速训练过程。
在训练过程中,需要读取 `mmcif_cache.json` 文件,其数据结构如下:
{
"4ewn": {
"release_date": "2012-12-05",
"chain_ids": [
"D"
],
"seqs": [
"MLAKRI..."
],
"no_chains": 1,
"resolution": 1.9
},
"5m9r": {
"release_date": "2017-02-22",
"chain_ids": [
"A",
"B"
],
"seqs": [
"MQDNS...",
"MQDNS..."
],
"no_chains": 2,
"resolution": 1.44
},
#...
}
当前的训练数据格式(例如 `train_200_mini.csv`)如下:
pdb_id,chain_id,resolution,release_date,seq,len,chain_type,filepath
7m5z,"A,B",3.06,2021-10-06,"LEDVV...,QNKLE...","263,264","protein,protein",[pdb_path]/structures/m5/pdb7m5z.ent.gz
7k05,"A,B",1.85,2021-10-06,"MSFPP...,MSFPP...","200,200","protein,protein",[pdb_path]/structures/k0/pdb7k05.ent.gz
# ...
同时,需要将特征(feature)文件夹的路径也写入训练文件 `mmcif_cache.json` 中。特征已通过预处理提前抽取好,训练时直接读取即可,无需在线进行特征抽取。特征根目录为:
[your folder]/multimer_train/features
训练时直接使用特征文件夹中已预处理好的特征文件 `features.pkl` 即可:
# 单个文件夹内容
chain_id_map.json
features.pkl
sequences.fasta
训练文件的转换命令,如下:
python openfold_scripts/main_mmcif_cache_transfer.py -i data/train_200_mini.csv -f [your folder]/multimer_train/features -o mydata/openfold/mmcif_cache_mini.json
源码如下:
#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/8/31
"""
import argparse
import json
import os
import sys
from pathlib import Path
import pandas as pd
from tqdm import tqdm
# Put the project root (the parent of this script's directory) on the import
# path so project-local modules resolve when the script is run directly.
p = os.path.dirname(os.path.abspath(__file__))  # this script's directory
p = os.path.dirname(p)                          # ...and its parent
if not any(entry == p for entry in sys.path):
    sys.path.append(p)
class MmcifCacheTransfer(object):
    """
    Convert a training CSV into OpenFold's mmcif_cache.json format.

    Each CSV row becomes one cache entry keyed by ``pdb_id``. The entry holds
    the fields OpenFold reads (release_date, chain_ids, seqs, no_chains,
    resolution) plus a ``feature_folder`` pointing at pre-computed features.
    """

    def __init__(self):
        pass

    @staticmethod
    def _split_field(value):
        """Split a comma-joined CSV cell (e.g. ``"A,B"``) into stripped strings.

        ``str()`` guards against pandas inferring a non-string dtype for the
        cell (e.g. a purely numeric chain id), which has no ``.split``.
        """
        return [item.strip() for item in str(value).split(",")]

    @staticmethod
    def _build_entry(row, feature_dir):
        """Build the mmcif_cache entry for a single CSV row.

        Args:
            row: mapping (e.g. a pandas row) with the keys ``pdb_id``,
                ``chain_id``, ``seq``, ``resolution`` and ``release_date``.
                ``chain_id`` and ``seq`` are comma-joined lists.
            feature_dir: root folder of the pre-computed features.

        Returns:
            ``(pdb_id, entry)`` where ``entry`` is the cache dict for the row.
        """
        pdb_id = str(row["pdb_id"])
        chain_ids = MmcifCacheTransfer._split_field(row["chain_id"])
        seqs = MmcifCacheTransfer._split_field(row["seq"])
        if len(seqs) != len(chain_ids):
            # seqs are expected to align with chain_ids one-to-one.
            print(f"[Warning] {pdb_id}: {len(chain_ids)} chain ids but {len(seqs)} seqs")
        # Sub-folder is the middle two characters of the id, mirroring the
        # structure paths in the CSV (e.g. structures/m5/pdb7m5z.ent.gz); the
        # leaf folder encodes the pdb id plus the concatenated chain ids.
        feature_folder = os.path.join(
            feature_dir, pdb_id[1:3], f"pdb{pdb_id}_{''.join(chain_ids)}"
        )
        entry = {
            "release_date": str(row["release_date"]),
            "chain_ids": chain_ids,
            "seqs": seqs,
            "no_chains": len(chain_ids),
            "resolution": float(row["resolution"]),
            "feature_folder": feature_folder,
        }
        return pdb_id, entry

    @staticmethod
    def process(input_path, feature_dir, output_path):
        """Read ``input_path`` (CSV) and write the cache JSON to ``output_path``.

        Args:
            input_path: training CSV with columns
                pdb_id,chain_id,resolution,release_date,seq,len,chain_type,filepath
            feature_dir: root folder of the pre-computed features.
            output_path: destination JSON file; parent dirs are created.

        Raises:
            FileNotFoundError: if ``input_path`` is not an existing file.
        """
        print(f"[Info] 输入文件: {input_path}")
        print(f"[Info] 特征文件夹: {feature_dir}")
        print(f"[Info] 输出文件: {output_path}")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f"input file not found: {input_path}")
        # Keep identifier/text columns as strings so pandas never coerces them.
        df = pd.read_csv(
            input_path,
            dtype={"pdb_id": str, "chain_id": str, "seq": str, "release_date": str},
        )
        print(f"[Info] 输入样本: {len(df)}")
        mmcif_cache_dict = dict()
        for _, row in tqdm(df.iterrows(), "[Info] pdb", total=len(df)):
            pdb_id, entry = MmcifCacheTransfer._build_entry(row, feature_dir)
            if pdb_id in mmcif_cache_dict:
                # The cache is keyed by pdb_id only; a later row replaces the earlier one.
                print(f"[Warning] duplicate pdb_id {pdb_id}, overwriting previous entry")
            mmcif_cache_dict[pdb_id] = entry
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as fp:
            json.dump(mmcif_cache_dict, fp, indent=4)
        print(f"[Info] 全部处理完成: {output_path}")
def main(argv=None):
    """Command-line entry point: convert a training CSV into mmcif_cache.json.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so the original zero-argument call
            keeps working; passing a list allows programmatic use.

    Exits with status 2 (via ``parser.error``) if the input file is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-path",
        help="the input file path.",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-f",
        "--feature-dir",
        help="the preprocess feature dir.",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output-path",
        help="the output file path.",
        type=Path,
        required=True,
    )
    args = parser.parse_args(argv)
    input_path = str(args.input_path)
    feature_dir = str(args.feature_dir)
    output_path = str(args.output_path)
    # parser.error instead of `assert`: asserts vanish under `python -O`, and
    # parser.error prints the usage line plus a readable message.
    if not os.path.isfile(input_path):
        parser.error(f"input file not found: {input_path}")
    mct = MmcifCacheTransfer()
    mct.process(input_path, feature_dir, output_path)


if __name__ == '__main__':
    main()
OpenFold Multimer 的特征读取逻辑位于 `openfold/data/data_modules.py` 的 `OpenFoldSingleMultimerDataset` 类中,即:
# Upstream logic: build features on the fly from the raw structure file and
# its alignments (train/eval), or from a FASTA file otherwise. Per the article,
# this per-sample conversion is the slow path being replaced.
if self.mode == 'train' or self.mode == 'eval':
    path = os.path.join(self.data_dir, f"{mmcif_id}")
    # Probe the supported extensions in order; the first existing file wins.
    ext = None
    for e in self.supported_exts:
        if os.path.exists(path + e):
            ext = e
            break
    if ext is None:
        raise ValueError("Invalid file type")
    # TODO: Add pdb and core exts to data_pipeline for multimer
    path += ext
    # Only .cif is handled for multimer; any other found extension errors out.
    if ext == ".cif":
        data = self._parse_mmcif(
            path, mmcif_id, self.alignment_dir, alignment_index)
    else:
        raise ValueError("Extension branch missing")
else:
    # Non-train/eval modes: build features from <data_dir>/<mmcif_id>.fasta.
    path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
    data = self.data_pipeline.process_fasta(
        fasta_path=path,
        alignment_dir=self.alignment_dir)
将其修改为直接加载预处理特征(`features.pkl`)的形式(注意:需要在 `data_modules.py` 顶部添加 `import pickle`),即:
if self.mode == 'train' or self.mode == 'eval':
    # Train/eval: use pre-computed features instead of parsing mmCIF/MSA.
    # Requires a cache whose entries carry a 'feature_folder' key (written by
    # main_mmcif_cache_transfer.py); a cache without it raises KeyError here.
    feat_folder = self.mmcif_data_cache[mmcif_id]['feature_folder']
    feat_path = os.path.join(feat_folder, "features.pkl")
    # logger.info(f"[Info] feat_path: {feat_path}")
    data = {}
    # NOTE(review): `pickle` must be imported at the top of data_modules.py.
    # Unpickling can execute arbitrary code; only load feature files you
    # generated yourself or otherwise trust.
    with open(feat_path, "rb") as f:
        feat_dict = pickle.load(f)
    # Shallow-copy the unpickled mapping into a fresh dict.
    data.update(feat_dict)
    # logger.info(f"[Info] data: {data.keys()}")
else:
    # Non-train/eval modes keep the original FASTA-based pipeline.
    path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
    data = self.data_pipeline.process_fasta(
        fasta_path=path,
        alignment_dir=self.alignment_dir)
同时,还需要修改 `__len__`,使训练数据总数等于 `mmcif_data_cache` 中的条目数:
def __len__(self):
    """Return the number of training samples.

    All samples come from ``self.mmcif_data_cache`` (one entry per key), so
    the dataset length is the number of entries in that mapping. This replaces
    the upstream ``len(self._chain_ids)``.
    """
    # ``len(self.mmcif_data_cache.keys)`` would raise TypeError because
    # ``keys`` is a bound method; take the length of the mapping itself.
    return len(self.mmcif_data_cache)
模型训练的参数,如下:
python3 train_openfold.py \
--train_data_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
--train_alignment_dir mydata/alignment_dir/ \
--train_mmcif_data_cache_path [your folder]/multimer_train/openfold_cache/mmcif_cache_mini.json \
--template_mmcif_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
--output_dir mydata/output_dir/ \
--max_template_date "2021-10-10" \
--config_preset "model_1_multimer_v3" \
--template_release_dates_cache_path mmcif_cache.json \
--precision bf16 \
--gpus 1 \
--replace_sampler_ddp=True \
--seed 42 \
--deepspeed_config_path deepspeed_config.json \
--checkpoint_every_epoch \
--obsolete_pdbs_file_path [your folder]/af2-data-v230/pdb_mmcif/obsolete.dat
模型训练占用显存较多,目前在 V100 上无法直接运行。因此需要调低 `crop_size` 与 `num_workers` 以降低资源占用,相关配置位于 `openfold/config.py` 中,即:
# crop_size
elif "multimer" in name:
    c.update(multimer_config_update.copy_and_resolve_references())
    c.data.train.crop_size = 64  # TODO: testing only; smaller crop to reduce GPU memory
# num_workers
"data_module": {
    "use_small_bfd": False,
    "data_loaders": {
        "batch_size": 1,
        # "num_workers": 16,  # previous value, before lowering for testing
        "num_workers": 2,  # TODO: testing only; fewer loader workers to reduce resource usage
        "pin_memory": True,
    },
},
其中,`crop_size = 64` 时显存占用约为 5141 MiB。
训练日志,如下:
Epoch 0: 0%| | 0/199 [00:00<?, ?it/s]INFO:openfold/data/data_modules.py:mmcif_id is: 7poc, idx: 148 and has 4 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7u49, idx: 97 and has 3 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7z7h, idx: 114 and has 6 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7nup, idx: 111 and has 4 chains
cum_loss: tensor([84.1698], device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>) losses: {'distogram': tensor(4.1562, device='cuda:0', dtype=torch.float64), 'experimentally_resolved': tensor(0.6914, device='cuda:0'), 'fape': tensor(1.6598, device='cuda:0', dtype=torch.float64), 'plddt_loss': tensor(3.9062, device='cuda:0', dtype=torch.float64), 'masked_msa': tensor(3.0938, device='cuda:0'), 'supervised_chi': tensor(0.7941, device='cuda:0', dtype=torch.float64), 'violation': tensor(3.6495, device='cuda:0'), 'tm': tensor(4.1562, device='cuda:0', dtype=torch.float64), 'chain_center_of_mass': tensor([1.3754], device='cuda:0', dtype=torch.float64), 'unscaled_loss': tensor([10.5212], device='cuda:0', dtype=torch.float64), 'loss': tensor([84.1698], device='cuda:0', dtype=torch.float64)}
Epoch 0: 1%|▉ | 1/199 [02:55<9:38:06, 175.18s/it, loss=84.2, v_num=]