AI制药 - AlphaFold Multimer 的 MSA Pairing 源码

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://blog.csdn.net/caroline_wendy/article/details/129794694

目前最新版本是v2.3.1,2023.1.12

  1. AlphaFold multimer v1 于 2021 年 7 月发布,同时发表了一篇描述其方法和结果的论文。AlphaFold multimer v1 使用了与 AlphaFold 单体相同的模型结构和训练方法,但增加了一些特征和损失函数来处理多条链。AlphaFold multimer v1 在几个蛋白质复合物的基准测试中取得了最先进的性能。

  2. AlphaFold multimer v2 于 2021 年 9 月 21 日发布,作为一个错误修复版本。AlphaFold multimer v2 没有改变模型参数或结构,但修复了松弛阶段的一些问题,并更新了一些第三方库。

  3. AlphaFold multimer v3 于 2021 年 12 月 3 日发布,带来了新的模型参数,预计在大型蛋白质复合物上更准确。AlphaFold multimer v3 使用了与 AlphaFold multimer v1 相同的模型结构和训练方法,但使用了不同的数据和超参数。AlphaFold multimer v3 还包括了一些内存优化和可用性改进。

入口函数:run_alphafold.py

调用逻辑:

def predict_structure(
    fasta_path: str,
    fasta_name: str,
    output_dir_base: str,
    data_pipeline: Union[pipeline.DataPipeline, pipeline_multimer.DataPipeline],
    model_runners: Dict[str, model.RunModel],
    amber_relaxer: relax.AmberRelaxation,
    benchmark: bool,
    random_seed: int):

其中,data_pipeline,选择pipeline_multimer.DataPipeline

  msa_output_dir = os.path.join(output_dir, 'msas')
  if not os.path.exists(msa_output_dir):
    os.makedirs(msa_output_dir)
  feature_dict = data_pipeline.process(
      input_fasta_path=fasta_path,
      msa_output_dir=msa_output_dir)

调用文件:pipeline_multimer.py,其中核心逻辑有三块:

  1. _process_single_chain,单链处理逻辑
  2. add_assembly_features,添加特征来区分不同的链
  3. pair_and_merge,配对合并MSA

源码及注释如下:

  def process(self,
              input_fasta_path: str,
              msa_output_dir: str) -> pipeline.FeatureDict:
    """Runs alignment tools on the input sequences and creates features."""
    
    # 打开输入的fasta文件并读取内容
    with open(input_fasta_path) as f:
      input_fasta_str = f.read()
    
    # 解析fasta文件中的序列和描述信息
    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
		
    # 根据序列和描述信息创建一个链ID的映射表
    chain_id_map = _make_chain_id_map(sequences=input_seqs,
                                      descriptions=input_descs)
    
    # 将链ID的映射表保存为json文件
    chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
    with open(chain_id_map_path, 'w') as f:
      chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain)
                           for chain_id, fasta_chain in chain_id_map.items()}
      json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)
		
    # 初始化一个空字典来存储所有链的特征
    all_chain_features = {}
    
    # 初始化一个空字典来存储已经处理过的序列的特征,避免重复计算
    sequence_features = {}
    
    # 判断输入的序列是否是复合物或单体,即是否只有一种不同的序列
    is_homomer_or_monomer = len(set(input_seqs)) == 1
    
    # 遍历每个链ID和对应的fasta信息
    for chain_id, fasta_chain in chain_id_map.items():
      
      # 如果该链的序列已经在sequence_features中,直接复制其特征到all_chain_features中
      if fasta_chain.sequence in sequence_features:
        all_chain_features[chain_id] = copy.deepcopy(
            sequence_features[fasta_chain.sequence])
        continue
      
      # 否则,调用另一个函数来处理单个链,包括运行比对工具,生成特征等
      chain_features = self._process_single_chain(
          chain_id=chain_id,
          sequence=fasta_chain.sequence,
          description=fasta_chain.description,
          msa_output_dir=msa_output_dir,
          is_homomer_or_monomer=is_homomer_or_monomer)
			
      # 将单个链的特征转换为单体特征,即添加一些额外的信息,如链ID等
      chain_features = convert_monomer_features(chain_features, chain_id=chain_id)
      
      # 将单个链的特征添加到all_chain_features中
      all_chain_features[chain_id] = chain_features
      # 将单个链的特征添加到sequence_features中,以备后用
      sequence_features[fasta_chain.sequence] = chain_features
		
    # 为所有链的特征添加组装特征,即考虑多个链之间的相互作用等
    all_chain_features = add_assembly_features(all_chain_features)
		
    # 将所有链的特征进行配对和合并,得到一个numpy数组格式的样本
    np_example = feature_processing.pair_and_merge(
        all_chain_features=all_chain_features)
		
    # 将所有链的特征进行配对和合并,得到一个numpy数组格式的样本
    # Pad MSA to avoid zero-sized extra_msa.
    np_example = pad_msa(np_example, 512)
		
    # 返回最终的样本
    return np_example

其中,_process_single_chain的核心逻辑,如下:

  1. 调用self._monomer_data_pipeline.process(),生成单链的MSA信息
  2. 针对于多链,调用self._all_seq_msa_features

源码及注释如下:

  def _process_single_chain(
      self,
      chain_id: str,
      sequence: str,
      description: str,
      msa_output_dir: str,
      is_homomer_or_monomer: bool) -> pipeline.FeatureDict:
    """Runs the monomer pipeline on a single chain."""
    # 为单个链生成fasta字符串
    chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n'
    # 为单个链创建msa输出目录
    chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
    # 如果目录不存在,就创建它
    if not os.path.exists(chain_msa_output_dir):
      os.makedirs(chain_msa_output_dir)
    # 使用临时fasta文件运行单体数据流程
    with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
      logging.info('Running monomer pipeline on chain %s: %s',
                   chain_id, description)
      # 获取单个链的特征字典
      chain_features = self._monomer_data_pipeline.process(
          input_fasta_path=chain_fasta_path,
          msa_output_dir=chain_msa_output_dir)
			
      # 如果有两个或更多不同的序列,就构建配对特征
      # We only construct the pairing features if there are 2 or more unique
      # sequences.
      if not is_homomer_or_monomer:
        all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
                                                          chain_msa_output_dir)
        # 更新单个链的特征字典
        chain_features.update(all_seq_msa_features)
    # 返回单个链的特征字典
    return chain_features

其中,_all_seq_msa_features的核心逻辑,如下:

  • 额外添加MSA的物种信息,即msa_species_identifiers
  def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
    """Get MSA features for unclustered uniprot, for pairing."""
    # 为未聚类的uniprot获取msa输出路径
    out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
    # 运行msa工具,获取sto格式的结果
    result = pipeline.run_msa_tool(
        self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
        self.use_precomputed_msas)
    # 解析sto格式的结果,得到msa对象
    msa = parsers.parse_stockholm(result['sto'])
    # 截断msa对象,使其序列数不超过最大值
    msa = msa.truncate(max_seqs=self._max_uniprot_hits)
    # 从msa对象中提取特征
    all_seq_features = pipeline.make_msa_features([msa])
    # 筛选出有效的特征
    valid_feats = msa_pairing.MSA_FEATURES + (
        'msa_species_identifiers',  # MSA物种标识符
    )
    # 为特征添加前缀
    feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
             if k in valid_feats}
    # 返回特征字典
    return feats

其中,add_assembly_features的核心逻辑,如下:

  • 区分同源二聚体和异源二聚体,使用不同的链名标识。
  • 同时,添加不同的特征,用于区分同源和异源,如asym_id、sym_id、entity_id。
def add_assembly_features(
    all_chain_features: MutableMapping[str, pipeline.FeatureDict],
    ) -> MutableMapping[str, pipeline.FeatureDict]:
    """添加特征来区分不同的链。

  Args:
    all_chain_features: 一个字典,将链的id映射到每条链的特征字典。

  Returns:
    all_chain_features: 一个字典,将形式为`_`的字符串映射到相应的链特征。例如,一个同源二聚体的两条链会有键A_1和A_2。一个异源二聚体的两条链会有键A_1和B_1。
  """
  # 按序列分组链
  # 创建一个空字典,用来存储序列和实体id的对应关系
  seq_to_entity_id = {}
  # 创建一个默认字典,用来按序列分组链的特征
  grouped_chains = collections.defaultdict(list)
  # 遍历所有链的特征
  for chain_id, chain_features in all_chain_features.items():
    # 获取链的序列
    seq = str(chain_features['sequence'])
    # 如果序列不在字典中,就给它分配一个新的实体id
    if seq not in seq_to_entity_id:
      seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
    # 将链的特征添加到对应序列的列表中
    grouped_chains[seq_to_entity_id[seq]].append(chain_features)

  # 创建一个新的空字典,用来存储添加了新特征的链
  new_all_chain_features = {}
  # 初始化一个链的id
  chain_id = 1
  # 遍历按序列分组的链的特征
  for entity_id, group_chain_features in grouped_chains.items():
    # 遍历每个序列中的链,给它们分配一个对称id
    for sym_id, chain_features in enumerate(group_chain_features, start=1):
      # 用实体id和对称id构造一个新的键,如A_1或B_2
      new_all_chain_features[
          f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
      # 获取链的长度
      seq_length = chain_features['seq_length']
      # 添加不对称id,对称id和实体id作为特征
      chain_features['asym_id'] = chain_id * np.ones(seq_length)
      chain_features['sym_id'] = sym_id * np.ones(seq_length)
      chain_features['entity_id'] = entity_id * np.ones(seq_length)
      # 更新链的id
      chain_id += 1

  # 返回添加了新特征的链的字典
  return new_all_chain_features

其中,pair_and_merge的核心逻辑,如下:

  • process_unmerged_features,合并预处理。
  • create_paired_features,配对特征。
  • deduplicate_unpaired_sequences,移除与配对序列重复的未配对序列
  • merge_chain_features,合并链特征。
def pair_and_merge(
    all_chain_features: MutableMapping[str, pipeline.FeatureDict]
    ) -> pipeline.FeatureDict:
  """对特征进行增强、配对和合并的处理。

  Args:
    all_chain_features: 一个可变映射,存储每条链的特征字典。

  Returns:
    一个特征字典。
  """
	# 对未合并的特征进行处理
  process_unmerged_features(all_chain_features)
	
  # 将所有链的特征转换为列表
  np_chains_list = list(all_chain_features.values())
	
  # 判断是否需要对MSA序列进行配对, _is_homomer_or_monomer中true是同源,false是异源
  # pair_msa_sequences表示异源
  pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)

  if pair_msa_sequences:  # 异源
    # 使用msa_pairing模块创建配对的特征
    np_chains_list = msa_pairing.create_paired_features(
        chains=np_chains_list)
    # 去除未配对的重复序列
    np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
  
  # 裁剪链的长度,限制MSA和模板的数量
  np_chains_list = crop_chains(
      np_chains_list,
      msa_crop_size=MSA_CROP_SIZE,
      pair_msa_sequences=pair_msa_sequences,
      max_templates=MAX_TEMPLATES)
  # 合并链的特征
  np_example = msa_pairing.merge_chain_features(
      np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
      max_templates=MAX_TEMPLATES)
  
  # 对最终的特征进行处理
  np_example = process_final(np_example)
  return np_example

其中,process_unmerged_features的核心逻辑,在chain_features中添加若干特征:

  • 包括deletion_matrix、deletion_matrix_all_seq、deletion_mean、all_atom_mask、all_atom_positions、entity_mask
  • 与多链相关的assembly_num_chains

源码如下:

def process_unmerged_features(
    all_chain_features: MutableMapping[str, pipeline.FeatureDict]):
  """对合并前的每条链的特征进行后处理。"""
  num_chains = len(all_chain_features)
  for chain_features in all_chain_features.values():
    # 将删除矩阵转换为浮点数。
    chain_features['deletion_matrix'] = np.asarray(
        chain_features.pop('deletion_matrix_int'), dtype=np.float32)
    if 'deletion_matrix_int_all_seq' in chain_features:
      chain_features['deletion_matrix_all_seq'] = np.asarray(
          chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32)

    # 计算删除矩阵的均值。
    chain_features['deletion_mean'] = np.mean(
        chain_features['deletion_matrix'], axis=0)

    # 根据aatype添加all_atom_mask和虚拟的all_atom_positions。
    all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
        chain_features['aatype']]
    chain_features['all_atom_mask'] = all_atom_mask
    chain_features['all_atom_positions'] = np.zeros(
        list(all_atom_mask.shape) + [3])

    # 添加assembly_num_chains。
    chain_features['assembly_num_chains'] = np.asarray(num_chains)

  # 添加entity_mask。
  for chain_features in all_chain_features.values():
    chain_features['entity_mask'] = (
        chain_features['entity_id'] != 0).astype(np.int32)

其中,merge_chain_features,合并链特征:

  • 区分同源体,以及配对和不配对的特征合并。
def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
                         pair_msa_sequences: bool,
                         max_templates: int) -> pipeline.FeatureDict:
  """将多条链的特征合并为单个FeatureDict.

  参数:
    np_chains_list: 每条链的FeatureDict的列表.
    pair_msa_sequences: 是否合并配对的MSA.
    max_templates: 包含的模板的最大数量.

  返回:
    整个复合物的单个FeatureDict.
  """
  # 对模板进行填充,使其数量不超过最大值
  np_chains_list = _pad_templates(
      np_chains_list, max_templates=max_templates)
  # 对同源体的密集MSA进行合并
  np_chains_list = _merge_homomers_dense_msa(np_chains_list)
  # 不配对的MSA特征将始终被分块对角化;配对的MSA特征将被连接.
  np_example = _merge_features_from_multiple_chains(
      np_chains_list, pair_msa_sequences=False)
  if pair_msa_sequences:
    # 将配对和不配对的特征连接起来
    np_example = _concatenate_paired_and_unpaired_features(np_example)
  # 根据合并后的特征进行修正
  np_example = _correct_post_merged_feats(
      np_example=np_example,
      np_chains_list=np_chains_list,
      pair_msa_sequences=pair_msa_sequences)

  return np_example

其中,create_paired_features的核心逻辑,主要步骤:

  1. 对链进行序列配对,得到配对的行索引:pair_sequences
  2. 对配对的行进行重新排序:reorder_paired_rows
  3. 特征填充:pad_features
def create_paired_features(
    chains: Iterable[pipeline.FeatureDict]) ->  List[pipeline.FeatureDict]:
  """返回原始链的特征,其中包含配对的 NUM_SEQ 特征。

  Args:
    chains:  每条链的特征字典的列表。

  Returns:
    一个特征字典的列表,其中序列特征只包含要配对的行。
  """
  chains = list(chains)
  chain_keys = chains[0].keys()

  if len(chains) < 2:
    return chains
  else:
    updated_chains = []
    # 对链进行序列配对,得到配对的行索引
    paired_chains_to_paired_row_indices = pair_sequences(chains)
    # 对配对的行进行重新排序
    paired_rows = reorder_paired_rows(
        paired_chains_to_paired_row_indices)

    for chain_num, chain in enumerate(chains):
      # 创建一个新的链特征字典,不包含_all_seq后缀的特征
      new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
      for feature_name in chain_keys:
        if feature_name.endswith('_all_seq'):
          # 对特征进行填充
          feats_padded = pad_features(chain[feature_name], feature_name)
          # 只保留配对的行
          new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
      # 添加num_alignments_all_seq特征
      new_chain['num_alignments_all_seq'] = np.asarray(
          len(paired_rows[:, chain_num]))
      updated_chains.append(new_chain)
    return updated_chains

其中,pair_sequences的核心逻辑,主要根据物种配对MSA的信息,如下:

  • 根据序列相似度匹配MSA行:_match_rows_by_sequence_similarity
def pair_sequences(examples: List[pipeline.FeatureDict]
                   ) -> Dict[int, np.ndarray]:
  """返回跨链配对的MSA序列的索引。"""

  num_examples = len(examples)

  # 创建一个列表,存储每条链的物种字典
  all_chain_species_dict = []
  # 创建一个集合,存储共同的物种
  common_species = set()
  for chain_features in examples:
    # 将链的特征转换为MSA数据框
    msa_df = _make_msa_df(chain_features)
    # 根据MSA数据框创建物种字典
    species_dict = _create_species_dict(msa_df)
    all_chain_species_dict.append(species_dict)
    # 将物种字典中的物种添加到共同物种集合中
    common_species.update(set(species_dict))

  # 对共同物种进行排序
  common_species = sorted(common_species)
  common_species.remove(b'')  # 移除目标序列的物种。

  # 创建一个列表,存储配对的MSA行
  all_paired_msa_rows = [np.zeros(len(examples), int)]
  # 创建一个字典,按照出现在多少条链中分组配对的MSA行
  all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
  all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

  # 遍历共同物种
  for species in common_species:
    if not species:
      continue
    # 创建一个列表,存储每条链中该物种的MSA数据框
    this_species_msa_dfs = []
    # 记录该物种出现在多少条链中
    species_dfs_present = 0
    for species_dict in all_chain_species_dict:
      if species in species_dict:
        this_species_msa_dfs.append(species_dict[species])
        species_dfs_present += 1
      else:
        this_species_msa_dfs.append(None)

    # 跳过只出现在一条链中的物种
    if species_dfs_present <= 1:
      continue

    # 跳过MSA数据框过大的物种
    if np.any(
        np.array([len(species_df) for species_df in
                  this_species_msa_dfs if
                  isinstance(species_df, pd.DataFrame)]) > 600):
      continue

    # 根据序列相似度匹配MSA行
    paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
    # 将匹配的MSA行添加到列表和字典中
    all_paired_msa_rows.extend(paired_msa_rows)
    all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
  # 将字典中的值转换为数组
  all_paired_msa_rows_dict = {
      num_examples: np.array(paired_msa_rows) for
      num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
  }
  return all_paired_msa_rows_dict

其中,_match_rows_by_sequence_similarity的核心逻辑:

def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
                                       ) -> List[List[int]]:
  """根据序列相似度找出跨链的MSA序列配对。

  首先,将每条链的MSA序列按照它们与各自目标序列的相似度进行排序。然后,从最相似的序列开始进行配对。

  Args:
    this_species_msa_dfs: 一个列表,包含了特定物种的MSA特征的数据框。

  Returns:
   一个列表的列表,每个列表包含M个索引,对应于配对的MSA行,其中M是链的数量。
  """
  all_paired_msa_rows = []

  # 获取每个数据框中的序列数量
  num_seqs = [len(species_df) for species_df in this_species_msa_dfs
              if species_df is not None]
  # 取最小的序列数量
  take_num_seqs = np.min(num_seqs)

  # 定义一个函数,按照相似度对数据框进行排序
  sort_by_similarity = (
      lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))

  for species_df in this_species_msa_dfs:
    if species_df is not None:
      # 对该物种的数据框进行排序
      species_df_sorted = sort_by_similarity(species_df)
      # 获取前take_num_seqs个MSA行的索引
      msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
    else:
      # 如果该物种不存在,则取最后一行(填充行)的索引
      msa_rows = [-1] * take_num_seqs  
    all_paired_msa_rows.append(msa_rows)
  # 将所有链的MSA行索引转置
  all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
  return all_paired_msa_rows

其中,deduplicate_unpaired_sequences的核心逻辑:

def deduplicate_unpaired_sequences(
    np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
  """移除与配对序列重复的未配对序列。"""

  # 获取特征的名称
  feature_names = np_chains[0].keys()
  # 获取MSA相关的特征
  msa_features = MSA_FEATURES

  for chain in np_chains:
    # 将msa_all_seq特征转换为元组,方便哈希
    sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
    keep_rows = []
    # 遍历未配对的MSA序列,移除任何与已配对序列相同的行
    for row_num, seq in enumerate(chain['msa']):
      if tuple(seq) not in sequence_set:
        keep_rows.append(row_num)
    # 更新MSA相关的特征,只保留需要的行
    for feature_name in feature_names:
      if feature_name in msa_features:
        chain[feature_name] = chain[feature_name][keep_rows]
    # 更新num_alignments特征
    chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
  return np_chains

其中,reorder_paired_rows的核心逻辑:

  • 创建一个包含跨链配对MSA行的索引列表.
  • 对配对链的数量进行降序遍历
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
                        ) -> np.ndarray:
  """创建一个包含跨链配对MSA行的索引列表.

  参数:
    all_paired_msa_rows_dict: 一个映射,从配对链的数量到配对索引.

  返回:
    一个列表的列表,每个列表包含跨链配对MSA行的索引.
    配对索引列表按以下顺序排序:
      1) 配对比对中的链的数量,即,所有链的配对将排在前面.
      2) e值
  """
  # 初始化一个空列表
  all_paired_msa_rows = []

  # 对配对链的数量进行降序遍历
  for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
    # 获取当前数量的配对索引
    paired_rows = all_paired_msa_rows_dict[num_pairings]
    # 计算每个配对索引的乘积的绝对值
    paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows]))
    # 按照乘积的大小进行升序排序
    paired_rows_sort_index = np.argsort(paired_rows_product)
    # 将排序后的配对索引添加到列表中
    all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])

  # 将列表转换为数组并返回
  return np.array(all_paired_msa_rows)

你可能感兴趣的:(深度学习,人工智能,python,深度学习)