多个模版结构特征提取

HhsearchHitFeaturizer和HmmsearchHitFeaturizer类的get_templates方法返回TemplateSearchResult。TemplateSearchResult含有features(TEMPLATE_FEATURES字典类型)以及errors(列表类型)和warnings(列表类型)。模版特征字典的值都为np.array类型,第一维度为模版数,如本示例中template_aatype特征的维度为(3, 396, 22),template_all_atom_positions特征的维度为(3, 396, 37, 3)。其中3:模版数;396:查询序列长度;22:氨基酸one-hot编码向量的长度;37:肽链中所有原子类型数;3:每个原子的xyz坐标值。与单一模版特征提取相比,多了一个维度(即模版数这一维度)。

### 多个模版特征提取
import dataclasses
from typing import Optional, List, Sequence, Tuple, Mapping, Any, Dict
import datetime
import abc
import glob
from absl import logging
import numpy as np
import re
import functools
from Bio import PDB
import io
import collections
from Bio.Data import SCOPData


# Names of the 37 atom slots of the per-residue atom layout. An atom's position
# in this list is the column it occupies in the coordinate / mask arrays built
# below (see `atom_order`).
atom_types = [
    'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]
# Atom name -> column index, in the same order as `atom_types`.
atom_order = dict(zip(atom_types, range(len(atom_types))))
# Number of atom slots per residue.
atom_type_num = len(atom_types)  # := 37.


# One-letter residue code (plus '-') -> id in the 22-symbol alphabet used to
# one-hot encode the template sequence (see `template_aatype`):
#   0-19: the 20 standard residues, ordered alphabetically by one-letter code;
#   20:   shared by 'X', 'J' and 'O';
#   21:   the gap character '-'.
# Some codes reuse a standard residue's id: 'B' -> 2 (as 'D'),
# 'Z' -> 3 (as 'E'), 'U' -> 1 (as 'C').
HHBLITS_AA_TO_ID = {
    'A': 0,
    'B': 2,
    'C': 1,
    'D': 2,
    'E': 3,
    'F': 4,
    'G': 5,
    'H': 6,
    'I': 7,
    'J': 20,
    'K': 8,
    'L': 9,
    'M': 10,
    'N': 11,
    'O': 20,
    'P': 12,
    'Q': 13,
    'R': 14,
    'S': 15,
    'T': 16,
    'U': 1,
    'V': 17,
    'W': 18,
    'X': 20,
    'Y': 19,
    'Z': 3,
    '-': 21,
}


class Error(Exception):
  """Base class for exceptions."""


class NoChainsError(Error):
  """An error indicating that template mmCIF didn't have any chains."""


class SequenceNotInTemplateError(Error):
  """An error indicating that template mmCIF didn't contain the sequence."""


class PrefilterError(Exception):
  """A base class for template prefilter exceptions."""


# Parsed mmCIF data: each data item name maps to its list of string values.
MmCIFDict = Mapping[str, Sequence[str]]

# Expected dtype of every template feature.
TEMPLATE_FEATURES = {
    'template_aatype': np.float32,
    'template_all_atom_masks': np.float32,
    'template_all_atom_positions': np.float32,
    'template_domain_names': object,
    'template_sequence': object,
    'template_sum_probs': np.float32,
}


@dataclasses.dataclass(frozen=True)
class TemplateHit:
  """Class representing a template hit."""
  index: int
  name: str
  aligned_cols: int
  sum_probs: Optional[float]
  query: str  # Aligned query segment; may contain '-' gap characters.
  hit_sequence: str  # Aligned hit segment; may contain '-' gap characters.
  indices_query: List[int]  # Query residue index per column, -1 for gaps.
  indices_hit: List[int]  # Hit residue index per column, -1 for gaps.


@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
  """Template features plus the errors/warnings collected while building them."""
  features: Mapping[str, Any]
  errors: Sequence[str]
  warnings: Sequence[str]


# Every template error below derives from the single `Error` base declared at
# the top of this block. `Error` / `NoChainsError` must not be re-declared:
# a second declaration would shadow the first and split the hierarchy, so that
# `except Error` would no longer catch e.g. SequenceNotInTemplateError.


class NoAtomDataInTemplateError(Error):
  """An error indicating that template mmCIF didn't contain atom positions."""


class TemplateAtomMaskAllZerosError(Error):
  """An error indicating that template mmCIF had all atom positions masked."""


class AlignRatioError(PrefilterError):
  """An error indicating that the hit align ratio to the query was too small."""


class CaDistanceError(Error):
  """An error indicating that a CA atom distance exceeds a threshold."""


####### start: 处理mmCIF 格式字符串##########
# Type aliases:
ChainId = str  # A chain identifier (mmCIF or author chain id, by context).
SeqRes = str  # One-letter amino-acid sequence of a chain.
PdbHeader = Mapping[str, Any]  # Header dict as built by _get_header.
PdbStructure = PDB.Structure.Structure  # Biopython structure object.


@dataclasses.dataclass(frozen=True)
class ResiduePosition:
  """Location of a residue in the Biopython structure (author numbering)."""
  chain_id: str  # Author chain id.
  residue_number: int  # Author residue number (_atom_site.auth_seq_id).
  insertion_code: str  # Insertion code; ' ' when unset in the mmCIF.


@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
  """A SEQRES residue together with where (if anywhere) it sits in the model."""
  position: Optional[ResiduePosition]  # None when the residue is missing.
  name: str  # mmCIF component id of the residue.
  is_missing: bool  # True if the residue has no coordinates in the model.
  hetflag: str  # Biopython hetero flag: ' ', 'W' (water) or 'H_<name>'.


@dataclasses.dataclass(frozen=True)
class SingleHitResult:
  """Outcome of featurizing one template hit."""
  features: Optional[Mapping[str, Any]]  # None if the hit was skipped/failed.
  error: Optional[str]  # Error message, if the hit is treated as an error.
  warning: Optional[str]  # Warning message, if any.


@dataclasses.dataclass(frozen=True)
class Monomer:
  """One entry of the mmCIF `_entity_poly_seq` loop."""
  id: str  # Component id (_entity_poly_seq.mon_id).
  num: int  # Sequence number (_entity_poly_seq.num).


@dataclasses.dataclass(frozen=True)
class MmcifObject:
  """Representation of a parsed mmCIF file.

  Contains:
    file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
      files being processed.
    header: Biopython header.
    structure: Biopython structure.
    chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
      {'A': 'ABCDEFG'}
    seqres_to_structure: Dict; for each chain_id contains a mapping between
      SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
                                                        1: ResidueAtPosition,
                                                        ...}}
    raw_string: Despite the name, mmcif_parse in this file stores the parsed
      mmCIF dict (the Biopython parser's _mmcif_dict) here, not the raw text.
  """
  file_id: str
  header: PdbHeader
  structure: PdbStructure
  chain_to_seqres: Mapping[ChainId, SeqRes]
  seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
  raw_string: Any


@dataclasses.dataclass(frozen=True)
class ParsingResult:
  """Returned by the parse function.

  Contains:
    mmcif_object: A MmcifObject, may be None if no chain could be successfully
      parsed.
    errors: A dict mapping (file_id, chain_id) to any exception generated, or
      to a message string (e.g. when no protein chains were found).
  """
  mmcif_object: Optional[MmcifObject]
  errors: Mapping[Tuple[str, str], Any]


@dataclasses.dataclass(frozen=True)
class AtomSite:
  """One row of the mmCIF `_atom_site` loop (built by _get_atom_site_list).

  All values are the raw strings from the parsed mmCIF dict; callers convert
  them (e.g. with int()) where needed.
  """
  residue_name: str  # _atom_site.label_comp_id
  author_chain_id: str  # _atom_site.auth_asym_id
  mmcif_chain_id: str  # _atom_site.label_asym_id
  author_seq_num: str  # _atom_site.auth_seq_id
  mmcif_seq_num: str  # _atom_site.label_seq_id (a string; int() applied later)
  insertion_code: str  # _atom_site.pdbx_PDB_ins_code; '.'/'?' when unset.
  hetatm_atom: str  # _atom_site.group_PDB; compared against 'HETATM'.
  model_num: str  # _atom_site.pdbx_PDB_model_num; compared against '1'.


def _is_set(data: str) -> bool:
  """Returns False if data is a special mmCIF character indicating 'unset'."""
  return data not in ('.', '?')


def mmcif_loop_to_dict(prefix: str,
                       index: str,
                       parsed_info: MmCIFDict,
                       ) -> Mapping[str, Mapping[str, str]]:
  """Returns an mmCIF loop as a dict of rows keyed by one of its columns.

  Args:
    prefix: Prefix shared by each of the data items in the loop.
      e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
      _entity_poly_seq.mon_id. Should include the trailing period.
    index: Full name of the column whose value keys each row
      (e.g. '_chem_comp.id').
    parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
      parser.

  Returns:
    A dict from the `index` column value to the row dict (as produced by
    mmcif_loop_to_list). When several rows share a key, the last one wins.
  """
  rows_by_key = {}
  for row in mmcif_loop_to_list(prefix, parsed_info):
    rows_by_key[row[index]] = row
  return rows_by_key


def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
  """Builds one AtomSite per row of the `_atom_site` loop.

  These rows carry data (mmCIF chain/sequence ids, model number, ...) that is
  not present in the Biopython structure.
  """
  # Column order must match the field order of AtomSite.
  column_names = (
      '_atom_site.label_comp_id',
      '_atom_site.auth_asym_id',
      '_atom_site.label_asym_id',
      '_atom_site.auth_seq_id',
      '_atom_site.label_seq_id',
      '_atom_site.pdbx_PDB_ins_code',
      '_atom_site.group_PDB',
      '_atom_site.pdbx_PDB_model_num',
  )
  columns = [parsed_info[name] for name in column_names]
  return [AtomSite(*row) for row in zip(*columns)]


def _get_atom_positions(
    mmcif_object: MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
  """Gets per-residue atom coordinates and masks for one chain of a structure.

  Residues are taken in SEQRES order (`mmcif_object.chain_to_seqres`). Each one
  is located in the Biopython chain via `mmcif_object.seqres_to_structure`.
  Every atom goes into the column given by `atom_order`. Atom names not in
  `atom_types` are ignored, except the selenium of MSE, which is stored in the
  'SD' column.

  Args:
    mmcif_object: Parsed mmCIF structure containing the chain.
    auth_chain_id: Author chain id of the chain to extract.
    max_ca_ca_distance: Threshold passed on to `_check_residue_distances`
      (defined elsewhere; presumably raises CaDistanceError when consecutive
      CA atoms are farther apart than this — TODO confirm).

  Returns:
    A tuple `(all_positions, all_positions_mask)`:
    * a float64 array of shape [num_res, atom_type_num, 3];
    * an int64 array of shape [num_res, atom_type_num].
    Rows of residues missing from the structure are left as zeros.

  Raises:
    MultipleChainsError: If the structure does not have exactly one chain with
      id `auth_chain_id`. NOTE(review): this class is not defined in the
      visible part of the file — confirm it exists.
  """
  num_res = len(mmcif_object.chain_to_seqres[auth_chain_id])

  relevant_chains = [c for c in mmcif_object.structure.get_chains()
                     if c.id == auth_chain_id]
  if len(relevant_chains) != 1:
    raise MultipleChainsError(
        f'Expected exactly one chain in structure with id {auth_chain_id}.')
  chain = relevant_chains[0]

  all_positions = np.zeros([num_res, atom_type_num, 3])
  all_positions_mask = np.zeros([num_res, atom_type_num],
                                dtype=np.int64)
  for res_index in range(num_res):
    pos = np.zeros([atom_type_num, 3], dtype=np.float32)
    mask = np.zeros([atom_type_num], dtype=np.float32)
    res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index]
    if not res_at_position.is_missing:
      # Biopython residues are looked up by (hetflag, number, insertion code).
      res = chain[(res_at_position.hetflag,
                   res_at_position.position.residue_number,
                   res_at_position.position.insertion_code)]
      for atom in res.get_atoms():
        atom_name = atom.get_name()
        x, y, z = atom.get_coord()
        if atom_name in atom_order.keys():
          pos[atom_order[atom_name]] = [x, y, z]
          mask[atom_order[atom_name]] = 1.0
        elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
          # Put the coordinates of the selenium atom in the sulphur column.
          pos[atom_order['SD']] = [x, y, z]
          mask[atom_order['SD']] = 1.0

      # Fix naming errors in arginine residues where NH2 is incorrectly
      # assigned to be closer to CD than NH1.
      cd = atom_order['CD']
      nh1 = atom_order['NH1']
      nh2 = atom_order['NH2']
      if (res.get_resname() == 'ARG' and
          all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
          (np.linalg.norm(pos[nh1] - pos[cd]) >
           np.linalg.norm(pos[nh2] - pos[cd]))):
        pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
        mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()

    # Missing residues keep their all-zero row.
    all_positions[res_index] = pos
    all_positions_mask[res_index] = mask
  _check_residue_distances(
      all_positions, all_positions_mask, max_ca_ca_distance)
  return all_positions, all_positions_mask


def _get_protein_chains(
    *, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]:
  """Collects the monomer sequence of every protein chain in an mmCIF file.

  A polymer entity is kept if at least one of its components has a
  `_chem_comp.type` containing 'peptide'. Entities without such a component
  (e.g. DNA/RNA) are dropped.

  Args:
    parsed_info: _mmcif_dict produced by the Biopython parser.

  Returns:
    A dict mapping mmcif chain id to a list of Monomers. Chains of the same
    entity share the same list object.
  """
  # Monomers of every polymer entity, in _entity_poly_seq order.
  monomers_by_entity = collections.defaultdict(list)
  for row in mmcif_loop_to_list('_entity_poly_seq.', parsed_info):
    entity_id = row['_entity_poly_seq.entity_id']
    monomers_by_entity[entity_id].append(
        Monomer(id=row['_entity_poly_seq.mon_id'],
                num=int(row['_entity_poly_seq.num'])))

  # Chemical component table keyed by component id; its type tells which
  # monomers are peptide-like.
  chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info)

  # mmCIF chain ids of each entity, so the result can be keyed by chain.
  chains_by_entity = collections.defaultdict(list)
  for row in mmcif_loop_to_list('_struct_asym.', parsed_info):
    chain_id = row['_struct_asym.id']
    chains_by_entity[row['_struct_asym.entity_id']].append(chain_id)

  protein_chains = {}
  for entity_id, monomers in monomers_by_entity.items():
    # Build the full list (no short-circuit): a monomer missing from
    # _chem_comp raises KeyError even if an earlier one is a peptide.
    peptide_flags = [
        'peptide' in chem_comps[monomer.id]['_chem_comp.type'].lower()
        for monomer in monomers]
    if not any(peptide_flags):
      continue
    for chain_id in chains_by_entity[entity_id]:
      protein_chains[chain_id] = monomers
  return protein_chains


def get_release_date(parsed_info: MmCIFDict) -> str:
  """Returns the earliest revision date recorded in the mmCIF data.

  Dates are compared as strings, so they are expected in a sortable form
  (e.g. 'YYYY-MM-DD').
  """
  all_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
  earliest = min(all_dates)
  return earliest


def mmcif_loop_to_list(prefix: str,
                       parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
  """Returns the rows of an mmCIF loop as a list of dicts.

  Reference for loop_ in mmCIF:
    http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

  Args:
    prefix: Prefix shared by each of the data items in the loop.
      e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
      _entity_poly_seq.mon_id. Should include the trailing period.
    parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
      parser.

  Returns:
    One dict per loop row, mapping each full column name to that row's value.
    An empty list when no key starts with `prefix`.
  """
  loop_items = [(key, value) for key, value in parsed_info.items()
                if key.startswith(prefix)]
  cols = [key for key, _ in loop_items]
  data = [value for _, value in loop_items]

  assert all(len(column) == len(data[0]) for column in data), (
      'mmCIF error: Not all loops are the same length: %s' % cols)

  # zip(*data) walks the columns in lockstep, yielding one row at a time.
  return [dict(zip(cols, row)) for row in zip(*data)]


def _get_first_model(structure: PdbStructure) -> PdbStructure:
  """Picks the first model out of a Biopython structure."""
  models = structure.get_models()
  return next(models)


def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
  """Builds a small header dict: method(s), release date and resolution.

  Keys:
    structure_method: comma-joined, lower-cased `_exptl.method` values.
    release_date: oldest revision date; absent (with a warning) if unknown.
    resolution: float; 0.0 unless one of the known resolution items parses.
      Items are checked in a fixed order and the last parsable one wins.
  """
  header = {}

  methods = [row['_exptl.method'].lower()
             for row in mmcif_loop_to_list('_exptl.', parsed_info)]
  header['structure_method'] = ','.join(methods)

  # The oldest revision is preferred over the deposition date for dataset
  # filtering.
  if '_pdbx_audit_revision_history.revision_date' not in parsed_info:
    logging.warning('Could not determine release_date: %s',
                    parsed_info['_entry.id'])
  else:
    header['release_date'] = get_release_date(parsed_info)

  header['resolution'] = 0.00
  resolution_keys = ('_refine.ls_d_res_high',
                     '_em_3d_reconstruction.resolution',
                     '_reflns.d_resolution_high')
  for key in resolution_keys:
    if key not in parsed_info:
      continue
    try:
      header['resolution'] = float(parsed_info[key][0])
    except ValueError:
      logging.debug('Invalid resolution format: %s', parsed_info[key])

  return header


@functools.lru_cache(16, typed=False)
def mmcif_parse(*,
                file_id: str,
                mmcif_string: str,
                catch_all_errors: bool = True) -> ParsingResult:
  """Entry point, parses an mmcif_string.

  Only the first model of the structure is used. Results are memoised by
  `functools.lru_cache` (16 entries), so identical calls return the very same
  ParsingResult object. Callers must not mutate it.

  Args:
    file_id: A string identifier for this file. Should be unique within the
      collection of files being processed.
    mmcif_string: Contents of an mmCIF file.
    catch_all_errors: If True, all exceptions are caught and error messages are
      returned as part of the ParsingResult. If False exceptions will be allowed
      to propagate.

  Returns:
    A ParsingResult. Its `mmcif_object` is None in two cases:
    * the file has no protein chains (a message is stored in `errors`);
    * parsing raised and `catch_all_errors` is True (the exception is stored
      under `errors[(file_id, '')]`).
  """
  # Log instead of printing: this is library code and should not write to
  # stdout.
  logging.debug('Parsing mmCIF file %s.', file_id)
  errors = {}
  try:
    parser = PDB.MMCIFParser(QUIET=True)
    handle = io.StringIO(mmcif_string)
    full_structure = parser.get_structure('', handle)
    first_model_structure = _get_first_model(full_structure)
    # Extract the _mmcif_dict from the parser, which contains useful fields not
    # reflected in the Biopython structure.
    parsed_info = parser._mmcif_dict  # pylint:disable=protected-access

    # Ensure all values are lists, even if singletons.
    for key, value in parsed_info.items():
      if not isinstance(value, list):
        parsed_info[key] = [value]

    header = _get_header(parsed_info)

    # Determine the protein chains, and their start numbers according to the
    # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
    valid_chains = _get_protein_chains(parsed_info=parsed_info)
    if not valid_chains:
      return ParsingResult(
          None, {(file_id, ''): 'No protein chains found in this file.'})
    seq_start_num = {chain_id: min([monomer.num for monomer in seq])
                     for chain_id, seq in valid_chains.items()}

    # Loop over the atoms for which we have coordinates. Populate two mappings:
    # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
    # the authors / Biopython).
    # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
    mmcif_to_author_chain_id = {}
    seq_to_structure_mappings = {}
    for atom in _get_atom_site_list(parsed_info):
      if atom.model_num != '1':
        # We only process the first model at the moment.
        continue

      mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

      if atom.mmcif_chain_id in valid_chains:
        hetflag = ' '
        if atom.hetatm_atom == 'HETATM':
          # Water atoms are assigned a special hetflag of W in Biopython. We
          # need to do the same, so that this hetflag can be used to fetch
          # a residue from the Biopython structure by id.
          if atom.residue_name in ('HOH', 'WAT'):
            hetflag = 'W'
          else:
            hetflag = 'H_' + atom.residue_name
        insertion_code = atom.insertion_code
        if not _is_set(atom.insertion_code):
          insertion_code = ' '
        position = ResiduePosition(chain_id=atom.author_chain_id,
                                   residue_number=int(atom.author_seq_num),
                                   insertion_code=insertion_code)
        seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
        current = seq_to_structure_mappings.get(atom.author_chain_id, {})
        current[seq_idx] = ResidueAtPosition(position=position,
                                             name=atom.residue_name,
                                             is_missing=False,
                                             hetflag=hetflag)
        seq_to_structure_mappings[atom.author_chain_id] = current

    # Add missing residue information to seq_to_structure_mappings.
    for chain_id, seq_info in valid_chains.items():
      author_chain = mmcif_to_author_chain_id[chain_id]
      current_mapping = seq_to_structure_mappings[author_chain]
      for idx, monomer in enumerate(seq_info):
        if idx not in current_mapping:
          current_mapping[idx] = ResidueAtPosition(position=None,
                                                   name=monomer.id,
                                                   is_missing=True,
                                                   hetflag=' ')

    # One-letter SEQRES sequence per author chain; non-standard or unknown
    # residues become 'X'.
    author_chain_to_sequence = {}
    for chain_id, seq_info in valid_chains.items():
      author_chain = mmcif_to_author_chain_id[chain_id]
      seq = []
      for monomer in seq_info:
        code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
        seq.append(code if len(code) == 1 else 'X')
      seq = ''.join(seq)
      author_chain_to_sequence[author_chain] = seq

    mmcif_object = MmcifObject(
        file_id=file_id,
        header=header,
        structure=first_model_structure,
        chain_to_seqres=author_chain_to_sequence,
        seqres_to_structure=seq_to_structure_mappings,
        raw_string=parsed_info)

    return ParsingResult(mmcif_object=mmcif_object, errors=errors)
  except Exception as e:  # pylint:disable=broad-except
    errors[(file_id, '')] = e
    if not catch_all_errors:
      raise
    return ParsingResult(mmcif_object=None, errors=errors)


@functools.lru_cache(16, typed=False)
def _read_file(path):
  with open(path, 'r') as f:
    file_data = f.read()
  return file_data

######## end: 处理mmCIF 格式字符串##########

def _find_template_in_pdb(
    template_chain_id: str,
    template_sequence: str,
    mmcif_object: MmcifObject) -> Tuple[str, str, int]:
  """Locates the template sequence inside one of the chains of a PDB entry.

  Three strategies are tried in order, and the first success is returned:
    1. The chain named `template_chain_id` contains the template sequence
       exactly.
    2. Any chain contains the template sequence exactly.
    3. Any chain matches the template sequence fuzzily, where 'X' in either
       sequence acts as a wildcard.

  Args:
    template_chain_id: The template chain ID.
    template_sequence: The template chain sequence.
    mmcif_object: The PDB object to search for the template in.

  Returns:
    A tuple `(chain_sequence, chain_id, mapping_offset)`:
    * the chain sequence that matched;
    * the id of that chain;
    * the offset at which the template sequence starts inside it.

  Raises:
    SequenceNotInTemplateError: If none of the three strategies finds a match.
  """
  pdb_id = mmcif_object.file_id
  seqres_by_chain = mmcif_object.chain_to_seqres

  # 1. Exact (sub)sequence match on the requested chain.
  candidate = seqres_by_chain.get(template_chain_id)
  if candidate and template_sequence in candidate:
    logging.info(
        'Found an exact template match %s_%s.', pdb_id, template_chain_id)
    return candidate, template_chain_id, candidate.find(template_sequence)

  # 2. Exact (sub)sequence match on any chain.
  for chain_id, chain_sequence in seqres_by_chain.items():
    if not chain_sequence or template_sequence not in chain_sequence:
      continue
    logging.info('Found a sequence-only match %s_%s.', pdb_id, chain_id)
    return chain_sequence, chain_id, chain_sequence.find(template_sequence)

  # 3. Fuzzy match. Non-capturing groups (?:...) avoid the regex engine's
  # limit of 100 named groups on long sequences.
  pattern = re.compile(''.join(
      '.' if residue == 'X' else '(?:%s|X)' % residue
      for residue in template_sequence))
  for chain_id, chain_sequence in seqres_by_chain.items():
    hit = pattern.search(chain_sequence)
    if hit:
      logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, chain_id)
      return chain_sequence, chain_id, hit.start()

  raise SequenceNotInTemplateError(
      'Could not find the template sequence in %s_%s. Template sequence: %s, '
      'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence,
                               mmcif_object.chain_to_seqres))


def _extract_template_features(
    mmcif_object: MmcifObject,
    pdb_id: str,
    mapping: Mapping[int, int],
    template_sequence: str,
    query_sequence: str,
    template_chain_id: str,
    kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]:
  """Parses atom positions in the target structure and aligns with the query.

  Atoms for each residue in the template structure are indexed to coincide
  with their corresponding residue in the query sequence, according to the
  alignment mapping provided. Query positions with no aligned template residue
  get zero coordinates, a zero mask and a '-' in the template sequence.

  Args:
    mmcif_object: mmcif_parsing.MmcifObject representing the template.
    pdb_id: PDB code for the template.
    mapping: Dictionary mapping indices in the query sequence to indices in
      the template sequence.
    template_sequence: String describing the amino acid sequence for the
      template protein.
    query_sequence: String describing the amino acid sequence for the query
      protein.
    template_chain_id: String ID describing which chain in the structure proto
      should be used.
    kalign_binary_path: The path to a kalign executable used for template
        realignment.

  Returns:
    A tuple with:
    * A dictionary containing the extra features derived from the template
      protein structure. Per-residue features have one row per query residue:
      'template_all_atom_positions', 'template_all_atom_masks',
      'template_sequence' (bytes), 'template_aatype' (sequence_to_onehot of the
      template sequence with HHBLITS_AA_TO_ID) and 'template_domain_names'
      (bytes, '<pdb_id>_<chain_id>').
    * A warning message if the hit was realigned to the actual mmCIF sequence.
      Otherwise None.

  Raises:
    NoChainsError: If the mmcif object doesn't contain any chains.
    SequenceNotInTemplateError: If the given chain id / sequence can't
      be found in the mmcif object.
    QueryToTemplateAlignError: If the actual template in the mmCIF file
      can't be aligned to the query.
    NoAtomDataInTemplateError: If the mmcif object doesn't contain
      atom positions.
    TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
      unmasked residues (including when `mapping` is empty).
  """
  if mmcif_object is None or not mmcif_object.chain_to_seqres:
    raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id))

  warning = None
  try:
    seqres, chain_id, mapping_offset = _find_template_in_pdb(
        template_chain_id=template_chain_id,
        template_sequence=template_sequence,
        mmcif_object=mmcif_object)
  except SequenceNotInTemplateError:
    # If PDB70 contains a different version of the template, we use the sequence
    # from the mmcif_object.
    chain_id = template_chain_id
    warning = (
        f'The exact sequence {template_sequence} was not found in '
        f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.')
    logging.warning(warning)
    # This throws an exception if it fails to realign the hit.
    seqres, mapping = _realign_pdb_template_to_query(
        old_template_sequence=template_sequence,
        template_chain_id=template_chain_id,
        mmcif_object=mmcif_object,
        old_mapping=mapping,
        kalign_binary_path=kalign_binary_path)
    logging.info('Sequence in %s_%s: %s successfully realigned to %s',
                 pdb_id, chain_id, template_sequence, seqres)
    # The template sequence changed.
    template_sequence = seqres
    # No mapping offset, the query is aligned to the actual sequence.
    mapping_offset = 0

  try:
    # Essentially set to infinity - we don't want to reject templates unless
    # they're really really bad.
    all_atom_positions, all_atom_mask = _get_atom_positions(
        mmcif_object, chain_id, max_ca_ca_distance=150.0)
  except (CaDistanceError, KeyError) as ex:
    raise NoAtomDataInTemplateError(
        'Could not get atom data (%s_%s): %s' % (pdb_id, chain_id, str(ex))
        ) from ex

  # Start from an all-gap template: every query residue that is not aligned
  # to a template residue keeps zero coordinates, a zero mask and '-'.
  num_query_res = len(query_sequence)
  templates_all_atom_positions = [
      np.zeros((atom_type_num, 3)) for _ in range(num_query_res)]
  templates_all_atom_masks = [
      np.zeros(atom_type_num) for _ in range(num_query_res)]
  output_templates_sequence = ['-'] * num_query_res

  # Copy the aligned residues' rows straight out of the chain arrays. Direct
  # indexing replaces np.split(x, x.shape[0]), which fails on a zero-length
  # chain because the number of sections would be 0.
  for query_index, template_index in mapping.items():
    chain_index = template_index + mapping_offset
    templates_all_atom_positions[query_index] = all_atom_positions[chain_index]
    templates_all_atom_masks[query_index] = all_atom_mask[chain_index]
    output_templates_sequence[query_index] = template_sequence[template_index]

  # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
  if np.sum(templates_all_atom_masks) < 5:
    # An empty mapping must still produce this error rather than a ValueError
    # from min()/max() on an empty sequence.
    if mapping:
      residue_range = '%d-%d' % (min(mapping.values()) + mapping_offset,
                                 max(mapping.values()) + mapping_offset)
    else:
      residue_range = 'none (empty alignment mapping)'
    raise TemplateAtomMaskAllZerosError(
        'Template all atom mask was all zeros: %s_%s. Residue range: %s' %
        (pdb_id, chain_id, residue_range))

  output_templates_sequence = ''.join(output_templates_sequence)

  templates_aatype = sequence_to_onehot(
      output_templates_sequence, HHBLITS_AA_TO_ID)

  return (
      {
          'template_all_atom_positions': np.array(templates_all_atom_positions),
          'template_all_atom_masks': np.array(templates_all_atom_masks),
          'template_sequence': output_templates_sequence.encode(),
          'template_aatype': np.array(templates_aatype),
          'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
      },
      warning)


def _is_after_cutoff(
    pdb_id: str,
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: Optional[datetime.datetime]) -> bool:
  """Checks if the template date is after the release date cutoff.

  Args:
    pdb_id: 4 letter pdb code.
    release_dates: Dictionary mapping PDB ids to their structure release dates.
    release_date_cutoff: Max release date that is valid for this query.

  Returns:
    True if the template release date is after the cutoff, False otherwise.
  """
  if release_date_cutoff is None:
    raise ValueError('The release_date_cutoff must not be None.')
  if pdb_id in release_dates:
    return release_dates[pdb_id] > release_date_cutoff
  else:
    # Since this is just a quick prefilter to reduce the number of mmCIF files
    # we need to parse, we don't have to worry about returning True here.
  return False

def _build_query_to_hit_index_mapping(
    hit_query_sequence: str,
    hit_sequence: str,
    indices_hit: Sequence[int],
    indices_query: Sequence[int],
    original_query_sequence: str) -> Mapping[int, int]:
  """Gets mapping from indices in original query sequence to indices in the hit.

  hit_query_sequence and hit_sequence are two aligned sequences containing gap
  characters. hit_query_sequence contains only the part of the original query
  sequence that matched the hit. When interpreting the indices from the .hhr, we
  need to correct for this to recover a mapping from original query sequence to
  the hit sequence.

  Args:
    hit_query_sequence: The portion of the query sequence that is in the .hhr
      hit
    hit_sequence: The portion of the hit sequence that is in the .hhr
    indices_hit: The indices for each aminoacid relative to the hit sequence
    indices_query: The indices for each aminoacid relative to the original query
      sequence
    original_query_sequence: String describing the original query sequence.

  Returns:
    Dictionary with indices in the original query sequence as keys and indices
    in the hit sequence as values.
  """
  # If the hit is empty (no aligned residues), return empty mapping
  if not hit_query_sequence:
    return {}

  # Remove gaps and find the offset of hit.query relative to original query.
  hhsearch_query_sequence = hit_query_sequence.replace('-', '')
  hit_sequence = hit_sequence.replace('-', '')
  hhsearch_query_offset = original_query_sequence.find(hhsearch_query_sequence)

  # Index of -1 used for gap characters. Subtract the min index ignoring gaps.
  min_idx = min(x for x in indices_hit if x > -1)
  fixed_indices_hit = [
      x - min_idx if x > -1 else -1 for x in indices_hit
  ]

  min_idx = min(x for x in indices_query if x > -1)
  fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]

  # Zip the corrected indices, ignore case where both seqs have gap characters.
  mapping = {}
  for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
    if q_t != -1 and q_i != -1:
      if (q_t >= len(hit_sequence) or
          q_i + hhsearch_query_offset >= len(original_query_sequence)):
        continue
      mapping[q_i + hhsearch_query_offset] = q_t

  return mapping


def _process_single_hit(
    query_sequence: str,
    hit: TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, Optional[str]],
    kalign_binary_path: str,
    strict_error_check: bool = False) -> SingleHitResult:
  """Tries to extract template features from a single HHSearch hit.

  Args:
    query_sequence: Amino acid sequence of the query.
    hit: The template hit to featurize.
    mmcif_dir: Directory holding the template structures; the file read is
      `<mmcif_dir>/<pdb_code>.cif`.
    max_template_date: Templates released after this date are rejected.
    release_dates: Mapping from PDB id to release date, used by the prefilter.
    obsolete_pdbs: Mapping from an obsolete PDB id to its replacement id, or
      to None when the entry was removed without a replacement.
    kalign_binary_path: Path to kalign, forwarded to feature extraction.
    strict_error_check: If True, some skipped hits (date/duplicate prefilter
      failures, late release dates, missing experimental data) are reported as
      errors instead of being dropped silently or as warnings.

  Returns:
    A SingleHitResult. `features` is None whenever the hit was skipped;
    `error` / `warning` explain why (both may be None for silent skips).

  Raises:
    ValueError: If the hit name does not start with PDBID_chain.
    Whatever `_read_file` raises when the mmCIF file cannot be read; this is
    deliberately not caught.
  """
  # Fail hard if we can't get the PDB ID and chain name from the hit.
  hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

  # This hit has been removed (obsoleted) from PDB, skip it.
  if hit_pdb_code in obsolete_pdbs and obsolete_pdbs[hit_pdb_code] is None:
    return SingleHitResult(
        features=None, error=None, warning=f'Hit {hit_pdb_code} is obsolete.')

  # Only follow the obsolete -> replacement mapping when we have no release
  # date for the original id.
  if hit_pdb_code not in release_dates:
    if hit_pdb_code in obsolete_pdbs:
      hit_pdb_code = obsolete_pdbs[hit_pdb_code]

  # Pass hit_pdb_code since it might have changed due to the pdb being obsolete.
  try:
    _assess_hhsearch_hit(
        hit=hit,
        hit_pdb_code=hit_pdb_code,
        query_sequence=query_sequence,
        release_dates=release_dates,
        release_date_cutoff=max_template_date)
  except PrefilterError as e:
    msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
    logging.info(msg)
    if strict_error_check and isinstance(e, (DateError, DuplicateError)):
      # In strict mode we treat some prefilter cases as errors.
      return SingleHitResult(features=None, error=msg, warning=None)

    return SingleHitResult(features=None, error=None, warning=None)

  mapping = _build_query_to_hit_index_mapping(
      hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query,
      query_sequence)

  # The mapping is from the query to the actual hit sequence, so we need to
  # remove gaps (which regardless have a missing confidence score).
  template_sequence = hit.hit_sequence.replace('-', '')

  cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
  logging.debug('Reading PDB entry from %s. Query: %s, template: %s', cif_path,
                query_sequence, template_sequence)
  # Fail if we can't find the mmCIF file.
  cif_string = _read_file(cif_path)

  parsing_result = mmcif_parse(
      file_id=hit_pdb_code, mmcif_string=cif_string)

  # Re-check the release date using the header of the parsed structure.
  if parsing_result.mmcif_object is not None:
    hit_release_date = datetime.datetime.strptime(
        parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d')
    if hit_release_date > max_template_date:
      error = ('Template %s date (%s) > max template date (%s).' %
               (hit_pdb_code, hit_release_date, max_template_date))
      if strict_error_check:
        return SingleHitResult(features=None, error=error, warning=None)
      else:
        logging.debug(error)
        return SingleHitResult(features=None, error=None, warning=None)

  try:
    features, realign_warning = _extract_template_features(
        mmcif_object=parsing_result.mmcif_object,
        pdb_id=hit_pdb_code,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=hit_chain_id,
        kalign_binary_path=kalign_binary_path)
    if hit.sum_probs is None:
      features['template_sum_probs'] = [0]
    else:
      features['template_sum_probs'] = [hit.sum_probs]

    # It is possible there were some errors when parsing the other chains in the
    # mmCIF file, but the template features for the chain we want were still
    # computed. In such case the mmCIF parsing errors are not relevant.
    return SingleHitResult(
        features=features, error=None, warning=realign_warning)
  except (NoChainsError, NoAtomDataInTemplateError,
          TemplateAtomMaskAllZerosError) as e:
    # These 3 errors indicate missing mmCIF experimental data rather than a
    # problem with the template search, so turn them into warnings.
    warning = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
               '%s, mmCIF parsing errors: %s'
               % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
                  str(e), parsing_result.errors))
    if strict_error_check:
      return SingleHitResult(features=None, error=warning, warning=None)
    else:
      return SingleHitResult(features=None, error=None, warning=warning)
  except Error as e:
    # sum_probs may be None (handled explicitly above), and '%.2f' % None would
    # raise TypeError here, hiding the real error. Keep two-decimal formatting
    # for real scores; '%s' renders an int rank exactly like '%d' did.
    sum_probs_str = ('%.2f' % hit.sum_probs if hit.sum_probs is not None
                     else 'None')
    error = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
             '%s, mmCIF parsing errors: %s'
             % (hit_pdb_code, hit_chain_id, sum_probs_str, hit.index,
                str(e), parsing_result.errors))
    return SingleHitResult(features=None, error=error, warning=None)
                           

def _get_pdb_id_and_chain(hit: TemplateHit) -> Tuple[str, str]:
  """Returns PDB id and chain id for an HHSearch Hit."""
  # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
  id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
  if not id_match:
    raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}')
  pdb_id, chain_id = id_match.group(0).split('_')
  return pdb_id.lower(), chain_id


def _assess_hhsearch_hit(
    hit: TemplateHit,
    hit_pdb_code: str,
    query_sequence: str,
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: datetime.datetime,
    max_subsequence_ratio: float = 0.95,
    min_align_ratio: float = 0.05) -> bool:
  """Determines if template is valid (without parsing the template mmcif file).

  Args:
    hit: HhrHit for the template.
    hit_pdb_code: The 4 letter pdb code of the template hit. This might be
      different from the value in the actual hit since the original pdb might
      have become obsolete.
    query_sequence: Amino acid sequence of the query.
    release_dates: Dictionary mapping pdb codes to their structure release
      dates.
    release_date_cutoff: Max release date that is valid for this query.
    max_subsequence_ratio: Exclude any exact matches with this much overlap.
    min_align_ratio: Minimum overlap between the template and query. The
      original default was 0.1; it was lowered to 0.05 here for the demo.

  Returns:
    True if the hit passed the prefilter. Raises an exception otherwise.

  Raises:
    DateError: If the hit date was after the max allowed date.
    AlignRatioError: If the hit align ratio to the query was too small.
    DuplicateError: If the hit was an exact subsequence of the query.
    LengthError: If the hit was too short.
  """
  aligned_cols = hit.aligned_cols
  align_ratio = aligned_cols / len(query_sequence)

  template_sequence = hit.hit_sequence.replace('-', '')
  length_ratio = float(len(template_sequence)) / len(query_sequence)

  # Diagnostics go through the logger instead of unconditional prints.
  logging.debug('align_ratio %s, length_ratio %s', align_ratio, length_ratio)

  # Check whether the template is a large subsequence or duplicate of original
  # query. This can happen due to duplicate entries in the PDB database.
  duplicate = (template_sequence in query_sequence and
               length_ratio > max_subsequence_ratio)

  if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
    raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date '
                    f'({release_date_cutoff}).')

  if align_ratio <= min_align_ratio:
    raise AlignRatioError('Proportion of residues aligned to query too small. '
                          f'Align ratio: {align_ratio}.')

  if duplicate:
    raise DuplicateError('Template is an exact subsequence of query with large '
                         f'coverage. Length ratio: {length_ratio}.')

  if len(template_sequence) < 10:
    raise LengthError(f'Template too short. Length: {len(template_sequence)}.')

  return True


def _check_residue_distances(all_positions: np.ndarray,
                             all_positions_mask: np.ndarray,
                             max_ca_ca_distance: float):
  """Checks if the distance between unmasked neighbor residues is ok.

  Only directly adjacent residues (i - 1, i) whose CA atoms are both present
  are compared. A residue with a masked CA breaks the chain, so no distance is
  checked across it.

  Args:
    all_positions: Per-residue atom coordinates; `coords[atom_order['CA']]`
      selects a residue's CA position (presumably an array of shape
      (num_res, atom_type_num, 3) -- confirm against callers).
    all_positions_mask: Per-residue atom mask indexed the same way; a truthy
      CA entry means the CA atom is present.
    max_ca_ca_distance: Maximum allowed Euclidean CA-CA distance between
      neighbouring residues.

  Raises:
    CaDistanceError: If a neighbouring pair exceeds `max_ca_ca_distance`.
  """
  ca_position = atom_order['CA']
  prev_is_unmasked = False
  prev_calpha = None
  for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
    this_is_unmasked = bool(mask[ca_position])
    if this_is_unmasked:
      this_calpha = coords[ca_position]
      if prev_is_unmasked:
        distance = np.linalg.norm(this_calpha - prev_calpha)
        if distance > max_ca_ca_distance:
          # prev_calpha always belongs to the immediately preceding residue,
          # so the offending pair is (i - 1, i), not (i, i + 1).
          raise CaDistanceError(
              'The distance between residues %d and %d is %f > limit %f.' % (
                  i - 1, i, distance, max_ca_ca_distance))
      prev_calpha = this_calpha
    prev_is_unmasked = this_is_unmasked


def sequence_to_onehot(
    sequence: str,
    mapping: Mapping[str, int],
    map_unknown_to_x: bool = False) -> np.ndarray:
  """One-hot encodes an amino acid sequence using a residue -> id mapping.

  Args:
    sequence: An amino acid sequence.
    mapping: Residue letter -> integer id. The ids must cover 0..K-1 with no
      gaps; K becomes the width of the encoding.
    map_unknown_to_x: When True, uppercase letters missing from `mapping` are
      encoded as 'X' (so `mapping` must contain 'X'), and any character that
      is not an uppercase letter raises ValueError. When False, every
      character must be a key of `mapping`.

  Returns:
    An int32 array of shape (len(sequence), K) with one 1 per row.

  Raises:
    ValueError: If the mapping ids have gaps, or (with map_unknown_to_x) if the
      sequence contains a non-uppercase-letter character.
    KeyError: If a needed residue (or 'X') is missing from `mapping`.
  """
  width = max(mapping.values()) + 1
  expected_ids = list(range(width))
  if sorted(set(mapping.values())) != expected_ids:
    raise ValueError('The mapping must have values from 0 to num_unique_aas-1 '
                     'without any gaps. Got: %s' % sorted(mapping.values()))

  encoded = np.zeros((len(sequence), width), dtype=np.int32)

  for row, residue in enumerate(sequence):
    if not map_unknown_to_x:
      column = mapping[residue]
    elif residue.isalpha() and residue.isupper():
      # The 'X' fallback is looked up eagerly, exactly as dict.get evaluates
      # its default argument.
      column = mapping.get(residue, mapping['X'])
    else:
      raise ValueError(f'Invalid character in the sequence: {residue}')
    encoded[row, column] = 1

  return encoded


class TemplateHitFeaturizer(abc.ABC):
  """An abstract base class for turning template hits to template features.

  This base class validates and stores the configuration shared by the
  featurizers; subclasses implement `get_templates`.
  """

  def __init__(
      self,
      mmcif_dir: str,
      max_template_date: str,
      max_hits: int,
      kalign_binary_path: str,
      release_dates_path: Optional[str],
      obsolete_pdbs_path: Optional[str],
      strict_error_check: bool = False):
    """Initializes the Template Search.

    Args:
      mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
        is found by HHSearch, this directory is used to retrieve the template
        data.
      max_template_date: The maximum date permitted for template structures. No
        template with date higher than this date will be returned. In ISO8601
        date format, YYYY-MM-DD.
      max_hits: The maximum number of templates that will be returned.
      kalign_binary_path: The path to a kalign executable used for template
        realignment.
      release_dates_path: An optional path to a file with a mapping from PDB IDs
        to their release dates. Thanks to this we don't have to redundantly
        parse mmCIF files to get that information.
      obsolete_pdbs_path: An optional path to a file containing a mapping from
        obsolete PDB IDs to the PDB IDs of their replacements.
      strict_error_check: If True, then the following will be treated as errors:
        * If any template date is after the max_template_date.
        * If any template has identical PDB ID to the query.
        * If any template is a duplicate of the query.
        * Any feature computation errors.

    Raises:
      ValueError: If `mmcif_dir` contains no `*.cif` files, or if
        `max_template_date` is missing (None) or not in YYYY-MM-DD format.
    """
    self._mmcif_dir = mmcif_dir
    if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')):
      logging.error('Could not find CIFs in %s', self._mmcif_dir)
      raise ValueError(f'Could not find CIFs in {self._mmcif_dir}')

    try:
      self._max_template_date = datetime.datetime.strptime(
          max_template_date, '%Y-%m-%d')
    except (TypeError, ValueError) as e:
      # strptime raises TypeError for a non-string such as None ("not set")
      # and ValueError for a malformed date; both get the same message.
      raise ValueError(
          'max_template_date must be set and have format YYYY-MM-DD.') from e
    self._max_hits = max_hits
    self._kalign_binary_path = kalign_binary_path
    self._strict_error_check = strict_error_check

    if release_dates_path:
      logging.info('Using precomputed release dates %s.', release_dates_path)
      self._release_dates = _parse_release_dates(release_dates_path)
    else:
      self._release_dates = {}

    if obsolete_pdbs_path:
      logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path)
      self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
    else:
      self._obsolete_pdbs = {}

  @abc.abstractmethod
  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence."""


class HhsearchHitFeaturizer(TemplateHitFeaturizer):
  """A class for turning a3m hits from hhsearch to template features."""

  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence.

    Hits are processed best-first by `sum_probs` (or in input order if any hit
    has no `sum_probs`) until `max_hits` of them have produced features.

    Args:
      query_sequence: Amino acid sequence of the query.
      hits: Template hits to featurize.

    Returns:
      A TemplateSearchResult. `features` maps every TEMPLATE_FEATURES key to
      the per-hit values stacked along a new leading (template) axis, or to an
      empty array of the declared dtype when no hit produced features.
      `errors` / `warnings` collect the per-hit messages.
    """
    logging.info('Searching for template for: %s', query_sequence)

    template_features = {name: [] for name in TEMPLATE_FEATURES}
    num_hits = 0
    errors = []
    warnings = []

    if hits and all(hit.sum_probs is not None for hit in hits):
      sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
    else:
      # Comparing None scores would raise TypeError; keep the input order.
      sorted_hits = hits

    for hit in sorted_hits:
      # We got all the templates we wanted, stop processing hits.
      if num_hits >= self._max_hits:
        break

      result = _process_single_hit(
          query_sequence=query_sequence,
          hit=hit,
          mmcif_dir=self._mmcif_dir,
          max_template_date=self._max_template_date,
          release_dates=self._release_dates,
          obsolete_pdbs=self._obsolete_pdbs,
          strict_error_check=self._strict_error_check,
          kalign_binary_path=self._kalign_binary_path)

      if result.error:
        errors.append(result.error)

      # There could be an error even if there are some results, e.g. thrown by
      # other unparsable chains in the same mmCIF file.
      if result.warning:
        warnings.append(result.warning)

      if result.features is None:
        logging.info('Skipped invalid hit %s, error: %s, warning: %s',
                     hit.name, result.error, result.warning)
      else:
        # Increment the hit counter, since we got features out of this hit.
        num_hits += 1
        for k in template_features:
          template_features[k].append(result.features[k])

    for name in template_features:
      if num_hits > 0:
        template_features[name] = np.stack(
            template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
      else:
        # Make sure the feature has correct dtype even if empty.
        template_features[name] = np.array([], dtype=TEMPLATE_FEATURES[name])

    return TemplateSearchResult(
        features=template_features, errors=errors, warnings=warnings)


class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
  """A class for turning a3m hits from hmmsearch to template features."""

  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence.

    Hits are processed best-first by `sum_probs` when every hit has one
    (otherwise in input order). A hit whose `template_sequence` feature was
    already collected is skipped, and processing stops once `max_hits`
    distinct templates have been gathered.

    Args:
      query_sequence: Amino acid sequence of the query.
      hits: Template hits to featurize.

    Returns:
      A TemplateSearchResult. If any template was found, each feature holds
      the per-template values stacked along a new leading (template) axis;
      otherwise a single all-zero placeholder template is returned.
      `errors` / `warnings` collect the per-hit messages.
    """
    logging.info('Searching for template for: %s', query_sequence)

    template_features = {name: [] for name in TEMPLATE_FEATURES}
    already_seen = set()
    errors = []
    warnings = []

    if hits and all(hit.sum_probs is not None for hit in hits):
      sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
    else:
      # Comparing None scores would raise TypeError; keep the input order.
      sorted_hits = hits

    for hit in sorted_hits:
      # We got all the templates we wanted, stop processing hits.
      if len(already_seen) >= self._max_hits:
        break

      result = _process_single_hit(
          query_sequence=query_sequence,
          hit=hit,
          mmcif_dir=self._mmcif_dir,
          max_template_date=self._max_template_date,
          release_dates=self._release_dates,
          obsolete_pdbs=self._obsolete_pdbs,
          strict_error_check=self._strict_error_check,
          kalign_binary_path=self._kalign_binary_path)

      if result.error:
        errors.append(result.error)

      # There could be an error even if there are some results, e.g. thrown by
      # other unparsable chains in the same mmCIF file.
      if result.warning:
        warnings.append(result.warning)

      if result.features is None:
        logging.debug('Skipped invalid hit %s, error: %s, warning: %s',
                      hit.name, result.error, result.warning)
      else:
        already_seen_key = result.features['template_sequence']
        if already_seen_key in already_seen:
          continue
        # Increment the hit counter, since we got features out of this hit.
        already_seen.add(already_seen_key)
        for k in template_features:
          template_features[k].append(result.features[k])

    if already_seen:
      for name in template_features:
        template_features[name] = np.stack(
            template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
    else:
      num_res = len(query_sequence)
      # Construct a single all-zero placeholder template. `residue_constants`
      # is not among this file's imports, so use the module-level tables
      # instead: HHBLITS_AA_TO_ID ids run 0..21 (20 amino acids + X + gap ->
      # 22 one-hot classes) and atom_type_num is 37.
      num_aatypes = max(HHBLITS_AA_TO_ID.values()) + 1
      template_features = {
          'template_aatype': np.zeros(
              (1, num_res, num_aatypes), np.float32),
          'template_all_atom_masks': np.zeros(
              (1, num_res, atom_type_num), np.float32),
          'template_all_atom_positions': np.zeros(
              (1, num_res, atom_type_num, 3), np.float32),
          'template_domain_names': np.array([''.encode()], dtype=object),
          'template_sequence': np.array([''.encode()], dtype=object),
          'template_sum_probs': np.array([0], dtype=np.float32)
      }
    return TemplateSearchResult(
        features=template_features, errors=errors, warnings=warnings)



### Template feature extraction for hits from an HHSearch search of the PDB structure database
# `os` stays at module level on purpose: the featurizer code above calls
# os.path and has no other import of it.
import os
import pickle

from Bio.PDB import PDBList

if __name__ == '__main__':
  # Load the pre-computed template hits.
  # NOTE(review): pickle.load can execute arbitrary code; only load a hit file
  # you generated yourself.
  with open('test_pdb_hits.pkl', 'rb') as file:
    pdb_template_hits = pickle.load(file)

  pdb_template_hits = pdb_template_hits[0:5]  # Use a subset as demo data.

  # Download the mmCIF file of every hit into the template directory.
  # Hit names look like
  # '5UXX_C BaquA.17208.a, BaquA.17842.a; SSGCID, Bartonella quintana, ...';
  # the PDB id is the part before the first '_'. dict.fromkeys removes
  # duplicate ids while keeping the hit order.
  pdb_ids = list(dict.fromkeys(
      hit.name.split()[0].split("_")[0] for hit in pdb_template_hits))

  # Download directory; it is also passed as mmcif_dir to the featurizer.
  template_mmcif_dir = "/home/zheng/test/mmcif"

  pdbl = PDBList()
  print("开始下载mmcif文件")
  for pdb_id in pdb_ids:
    pdbl.retrieve_pdb_file(pdb_code=pdb_id,
                           pdir=template_mmcif_dir,
                           file_format='mmCif')

  print(f"mmCIF file downloaded to: {template_mmcif_dir}")

  max_template_date = "2023-11-27"  # format YYYY-MM-DD

  MAX_TEMPLATE_HITS = 3
  # Absolute path (the original was missing the leading '/', which made it
  # resolve relative to the current working directory).
  kalign_binary_path = "/home/zheng/anaconda3/envs/deep_learning/bin/kalign"

  # HhsearchHitFeaturizer takes exactly the same constructor arguments and can
  # be swapped in here.
  template_featurizer = HmmsearchHitFeaturizer(
      mmcif_dir=template_mmcif_dir,
      max_template_date=max_template_date,
      max_hits=MAX_TEMPLATE_HITS,
      kalign_binary_path=kalign_binary_path,
      release_dates_path=None,
      obsolete_pdbs_path=None)

  print(template_featurizer)

  # Query sequence: concatenate all non-header lines of the FASTA file.
  input_fasta_file = '/home/zheng/test/Q94K49.fasta'
  with open(input_fasta_file) as f:
    input_sequence = "".join(
        line.strip() for line in f if not line.startswith(">"))

  templates_result = template_featurizer.get_templates(
      query_sequence=input_sequence, hits=pdb_template_hits)
  print(f"templates_result.errors: {templates_result.errors}")
  print(f"templates_result.warnings:{templates_result.warnings}")

  print(f"输入序列为:{input_sequence} 长度为:{len(input_sequence)}")

  for k, v in templates_result.features.items():
    print(k)
    print(f"值的类型为:{type(v)}")
    print(f"值的维度为:{v.shape}")
    print(v)

# 你可能感兴趣的:(生物信息学)  [blog footer: "You may also be interested in: Bioinformatics"]