多次运用集成函数处理蛋白质特征数据

wrap_ensemble_fn函数实现对蛋白质特征的多次集成函数操作(如多次抽样sample_msa等)

主要函数 

tree.map_structure:对嵌套结构(如字典、列表)中的每个叶子元素应用同一个函数,并返回结构相同的结果。
fn_output_signature = tree.map_structure( tf.TensorSpec.from_tensor, tensors_0)

tf.map_fn:沿张量的第0维逐个取出元素,对每个元素应用一个函数,并将各次结果重新堆叠。
tensors = tf.map_fn(
    lambda x: wrap_ensemble_fn(tensors, x),
    tf.range(num_ensemble),
    parallel_iterations=1,
    fn_output_signature=fn_output_signature)
import copy
import tensorflow.compat.v1 as tf
import tree
import pickle
import numpy as np
import ml_collections
 
# Symbolic placeholders used in the per-feature shape schema (the 'feat'
# entries of CONFIG). make_fixed_size maps each placeholder to a concrete
# size, and random_crop_to_size uses NUM_RES to find the residue axis.
NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'
 
# Full configuration tree. The pipeline code visible in this file only reads
# CONFIG.data (its 'common' and 'eval' sections); the 'model' section is not
# referenced by any code shown here.
CONFIG = ml_collections.ConfigDict({
    'data': {
        # Settings read through `data_config.common`.
        'common': {
            # Probabilities consumed by make_masked_msa; whatever is left of
            # 1.0 after these three is assigned to the [MASK] class.
            'masked_msa': {
                'profile_prob': 0.1,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            'max_extra_msa': 1024,
            'msa_cluster_features': True,
            'num_recycle': 3,
            'reduce_msa_clusters_by_max_templates': False,
            'resample_msa_in_recycling': True,
            'template_features': [
                'template_all_atom_positions', 'template_sum_probs',
                'template_aatype', 'template_all_atom_masks',
                'template_domain_names'
            ],
            'unsupervised_features': [
                'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
                'num_alignments', 'seq_length', 'between_segment_residues',
                'deletion_matrix'
            ],
            'use_templates': False,
        },
        # Settings read through `data_config.eval`.
        'eval': {
            # Shape schema: feature name -> one entry per axis, either one of
            # the NUM_* placeholders or None for an axis that is left as is.
            'feat': {
                'aatype': [NUM_RES],
                'all_atom_mask': [NUM_RES, None],
                'all_atom_positions': [NUM_RES, None, None],
                'alt_chi_angles': [NUM_RES, None],
                'atom14_alt_gt_exists': [NUM_RES, None],
                'atom14_alt_gt_positions': [NUM_RES, None, None],
                'atom14_atom_exists': [NUM_RES, None],
                'atom14_atom_is_ambiguous': [NUM_RES, None],
                'atom14_gt_exists': [NUM_RES, None],
                'atom14_gt_positions': [NUM_RES, None, None],
                'atom37_atom_exists': [NUM_RES, None],
                'backbone_affine_mask': [NUM_RES],
                'backbone_affine_tensor': [NUM_RES, None],
                'bert_mask': [NUM_MSA_SEQ, NUM_RES],
                'chi_angles': [NUM_RES, None],
                'chi_mask': [NUM_RES, None],
                'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_row_mask': [NUM_EXTRA_SEQ],
                'is_distillation': [],
                'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
                'msa_mask': [NUM_MSA_SEQ, NUM_RES],
                'msa_row_mask': [NUM_MSA_SEQ],
                'pseudo_beta': [NUM_RES, None],
                'pseudo_beta_mask': [NUM_RES],
                'random_crop_to_size_seed': [None],
                'residue_index': [NUM_RES],
                'residx_atom14_to_atom37': [NUM_RES, None],
                'residx_atom37_to_atom14': [NUM_RES, None],
                'resolution': [],
                'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
                'rigidgroups_group_exists': [NUM_RES, None],
                'rigidgroups_group_is_ambiguous': [NUM_RES, None],
                'rigidgroups_gt_exists': [NUM_RES, None],
                'rigidgroups_gt_frames': [NUM_RES, None, None],
                'seq_length': [],
                'seq_mask': [NUM_RES],
                'target_feat': [NUM_RES, None],
                'template_aatype': [NUM_TEMPLATES, NUM_RES],
                'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
                'template_all_atom_positions': [
                    NUM_TEMPLATES, NUM_RES, None, None],
                'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
                'template_backbone_affine_tensor': [
                    NUM_TEMPLATES, NUM_RES, None],
                'template_mask': [NUM_TEMPLATES],
                'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
                'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
                'template_sum_probs': [NUM_TEMPLATES, None],
                'true_msa': [NUM_MSA_SEQ, NUM_RES]
            },
            'fixed_size': True,
            'subsample_templates': True,  # We want top templates.
            'masked_msa_replace_fraction': 0.15,
            'max_msa_clusters': 512,
            'max_templates': 4,
            'num_ensemble': 1,
            'crop_size': 100,
        },
    },
    # Not read by the pipeline code visible in this file.
    'model': {
        'embeddings_and_evoformer': {
            'evoformer_num_block': 48,
            'evoformer': {
                'msa_row_attention_with_pair_bias': {
                    'dropout_rate': 0.15,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'msa_column_attention': {
                    'dropout_rate': 0.0,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'msa_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'outer_product_mean': {
                    'first': False,
                    'chunk_size': 128,
                    'dropout_rate': 0.0,
                    'num_outer_channel': 32,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_starting_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_ending_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True,
                    'fuse_projection_weights': False,
                },
                'triangle_multiplication_incoming': {
                    'dropout_rate': 0.25,
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True,
                    'fuse_projection_weights': False,
                },
                'pair_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                }
            },
            'extra_msa_channel': 64,
            'extra_msa_stack_num_block': 4,
            'max_relative_feature': 32,
            'msa_channel': 256,
            'pair_channel': 128,
            'prev_pos': {
                'min_bin': 3.25,
                'max_bin': 20.75,
                'num_bins': 15
            },
            'recycle_features': True,
            'recycle_pos': True,
            'seq_channel': 384,
            'template': {
                'attention': {
                    'gating': False,
                    'key_dim': 64,
                    'num_head': 4,
                    'value_dim': 64
                },
                'dgram_features': {
                    'min_bin': 3.25,
                    'max_bin': 50.75,
                    'num_bins': 39
                },
                'embed_torsion_angles': False,
                'enabled': False,
                'template_pair_stack': {
                    'num_block': 2,
                    'triangle_attention_starting_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_attention_ending_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_column',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_multiplication_outgoing': {
                        'dropout_rate': 0.25,
                        'equation': 'ikc,jkc->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'fuse_projection_weights': False,
                    },
                    'triangle_multiplication_incoming': {
                        'dropout_rate': 0.25,
                        'equation': 'kjc,kic->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'fuse_projection_weights': False,
                    },
                    'pair_transition': {
                        'dropout_rate': 0.0,
                        'num_intermediate_factor': 2,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    }
                },
                'max_templates': 4,
                'subbatch_size': 128,
                'use_template_unit_vector': False,
            }
        },
        'global_config': {
            'deterministic': False,
            'multimer_mode': False,
            'subbatch_size': 4,
            'use_remat': False,
            'zero_init': True,
            'eval_dropout': False,
        },
        'heads': {
            'distogram': {
                'first_break': 2.3125,
                'last_break': 21.6875,
                'num_bins': 64,
                'weight': 0.3
            },
            'predicted_aligned_error': {
                # `num_bins - 1` bins uniformly space the
                # [0, max_error_bin A] range.
                # The final bin covers [max_error_bin A, +infty]
                # 31A gives bins with 0.5A width.
                'max_error_bin': 31.,
                'num_bins': 64,
                'num_channels': 128,
                'filter_by_resolution': True,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'weight': 0.0,
            },
            'experimentally_resolved': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'weight': 0.01
            },
            'structure_module': {
                'num_layer': 8,
                'fape': {
                    'clamp_distance': 10.0,
                    'clamp_type': 'relu',
                    'loss_unit_distance': 10.0
                },
                'angle_norm_weight': 0.01,
                'chi_weight': 0.5,
                'clash_overlap_tolerance': 1.5,
                'compute_in_graph_metrics': True,
                'dropout': 0.1,
                'num_channel': 384,
                'num_head': 12,
                'num_layer_in_transition': 3,
                'num_point_qk': 4,
                'num_point_v': 8,
                'num_scalar_qk': 16,
                'num_scalar_v': 16,
                'position_scale': 10.0,
                'sidechain': {
                    'atom_clamp_distance': 10.0,
                    'num_channel': 128,
                    'num_residual_block': 2,
                    'weight_frac': 0.5,
                    'length_scale': 10.,
                },
                'structural_violation_loss_weight': 1.0,
                'violation_tolerance_factor': 12.0,
                'weight': 1.0
            },
            'predicted_lddt': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 50,
                'num_channels': 128,
                'weight': 0.01
            },
            'masked_msa': {
                'num_output': 23,
                'weight': 2.0
            },
        },
        'num_recycle': 3,
        'resample_msa_in_recycling': True
    },
})
 
 
# Features that sample_msa and crop_extra_msa gather along their first axis.
# Rows not selected by sample_msa are stored under the same name with an
# 'extra_' prefix.
_MSA_FEATURE_NAMES = [
    'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask',
    'true_msa'
]
 
 
class SeedMaker:
  """Hand out consecutive seed values, starting at `initial_seed`.

  Each call returns the current value of `next_seed` and then advances it by
  one, so successive calls never return the same seed.
  """

  def __init__(self, initial_seed=0):
    # `next_seed` always holds the value the next call will return.
    self.next_seed = initial_seed

  def __call__(self):
    try:
      return self.next_seed
    finally:
      # Advance only after the value to return has been picked up.
      self.next_seed += 1
 
 
def shape_list(x):
  """Return the dimensions of a tensor, statically where possible.

  Like `x.shape.as_list()` but with tensors instead of `None`s.

  Args:
    x: A tensor (or anything `tf.convert_to_tensor` accepts).

  Returns:
    A list with one entry per axis. An entry is a Python integer when that
    dimension is statically known, otherwise the matching element of
    `tf.shape(x)`. If even the rank is unknown, `tf.shape(x)` itself is
    returned.
  """
  x = tf.convert_to_tensor(x)
  static_shape = x.get_shape()

  # Unknown rank: there is no per-axis static information to mix in.
  if static_shape.dims is None:
    return tf.shape(x)

  known_sizes = static_shape.as_list()
  dynamic_shape = tf.shape(x)
  return [
      dynamic_shape[axis] if size is None else size
      for axis, size in enumerate(known_sizes)
  ]
 
 
def shaped_categorical(probs, epsilon=1e-10):
  """Sample one class index from every distribution along the last axis.

  Args:
    probs: Tensor whose last axis holds per-class probabilities.
    epsilon: Added to `probs` before taking the log.

  Returns:
    An int32 tensor of sampled class indices, shaped like `probs` without
    its last axis.
  """
  dims = shape_list(probs)
  class_count = dims[-1]
  # tf.random.categorical expects a 2-D batch of logits, so flatten every
  # leading axis into one and restore them afterwards.
  flat_logits = tf.reshape(tf.log(probs + epsilon), [-1, class_count])
  draws = tf.random.categorical(flat_logits, 1, dtype=tf.int32)
  return tf.reshape(draws, dims[:-1])
 
 
def data_transforms_curry1(f):
  """Turn `f(x, *args, **kwargs)` into a factory of one-argument callables.

  Calling the decorated name with every argument except the first returns a
  callable that takes only that first argument and forwards everything to
  `f`.
  """

  def bind_rest(*args, **kwargs):
    def apply(x):
      return f(x, *args, **kwargs)

    return apply

  return bind_rest
 
 
 
@data_transforms_curry1
def sample_msa(protein, max_seq, keep_extra):
  """Randomly keep up to `max_seq` MSA rows; optionally stash the rest.

  Args:
    protein: Feature dict to sample from; must contain 'msa'.
    max_seq: Number of sequences to keep.
    keep_extra: When true, rows that were not selected are stored under the
      same feature name with an 'extra_' prefix.

  Returns:
    The same `protein` dict, updated in place.
  """
  total_rows = tf.shape(protein['msa'])[0]
  # Row 0 is the query sequence: it stays first, only rows 1.. are shuffled.
  permuted_rest = tf.random_shuffle(tf.range(1, total_rows))
  row_order = tf.concat([[0], permuted_rest], axis=0)
  keep_count = tf.minimum(max_seq, total_rows)
  # Cut the permutation into the first `keep_count` rows (kept) and the
  # remaining `total_rows - keep_count` rows (left over).
  kept_rows, leftover_rows = tf.split(
      row_order, [keep_count, total_rows - keep_count])

  for name in _MSA_FEATURE_NAMES:
    if name not in protein:
      continue
    feature = protein[name]
    if keep_extra:
      # tf.gather picks the rows listed in the index tensor.
      protein['extra_' + name] = tf.gather(feature, leftover_rows)
    protein[name] = tf.gather(feature, kept_rows)

  return protein
 
 
@data_transforms_curry1
def make_masked_msa(protein, config, replace_fraction):
  """Create data for BERT-style masking on the raw MSA.

  Args:
    protein: Feature dict; reads 'msa' and 'hhblits_profile' (the profile
      must broadcast against the 22-class one-hot encoding of 'msa').
    config: Object exposing `uniform_prob`, `profile_prob` and `same_prob`.
    replace_fraction: Threshold compared with uniform random samples; a
      position is replaced when its sample falls below it.

  Returns:
    The same `protein` dict with 'bert_mask' (1.0 at replaced positions),
    'true_msa' (the original 'msa') and 'msa' (the corrupted MSA) set.
  """
  # Uniform over the 20 amino acids; the two remaining classes get zero.
  uniform_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)
  # Replacement distribution over the 22 classes: a mix of the uniform
  # distribution, the profile, and the residue that is already there.
  replacement_probs = (
      config.uniform_prob * uniform_aa +
      config.profile_prob * protein['hhblits_profile'] +
      config.same_prob * tf.one_hot(protein['msa'], 22))

  # Put all remaining probability on [MASK], which becomes a new last
  # column. With the CONFIG values in this file the three terms above carry
  # 0.3 and [MASK] carries 0.7, so each padded row sums to 1.
  rank = len(replacement_probs.shape)
  paddings = [[0, 0] for _ in range(rank - 1)] + [[0, 1]]
  mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
  assert mask_prob >= 0.
  replacement_probs = tf.pad(
      replacement_probs, paddings, constant_values=mask_prob)

  msa_shape = shape_list(protein['msa'])
  # Each position is selected for replacement independently.
  mask_position = tf.random.uniform(msa_shape) < replace_fraction

  # Draw a replacement for every position; it is only used where
  # `mask_position` is true, the original residue is kept elsewhere.
  sampled_msa = shaped_categorical(replacement_probs)
  corrupted_msa = tf.where(mask_position, sampled_msa, protein['msa'])

  # Mix real and masked MSA
  protein['bert_mask'] = tf.cast(mask_position, tf.float32)
  protein['true_msa'] = protein['msa']
  protein['msa'] = corrupted_msa

  return protein
 
 
@data_transforms_curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
  """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.

  Args:
    protein: Feature dict; reads 'msa', 'msa_mask', 'extra_msa' and
      'extra_msa_mask'.
    gap_agreement_weight: Weight given to two sequences agreeing on a gap.

  Returns:
    The same `protein` dict with 'extra_cluster_assignment' set: for every
    extra sequence, the int32 index of the sampled sequence it agrees with
    most.
  """
  # Determine how much weight we assign to each agreement.  In theory, we could
  # use a full blosum matrix here, but right now let's just down-weight gap
  # agreement because it could be spurious.
  # Never put weight on agreeing on BERT mask
  # Per-class weights over the 23 classes: the first 21 count fully, the gap
  # class counts `gap_agreement_weight`, and the last (mask) class counts 0.
  class_weights = tf.concat([
      tf.ones(21),
      gap_agreement_weight * tf.ones(1),
      np.zeros(1)], 0)

  # Make agreement score as weighted Hamming distance
  # The masks gain a trailing axis so they broadcast over the class axis.
  sampled_one_hot = (protein['msa_mask'][:, :, None] *
                     tf.one_hot(protein['msa'], 23))
  extra_one_hot = (protein['extra_msa_mask'][:, :, None] *
                   tf.one_hot(protein['extra_msa'], 23))

  sampled_count, res_count, _ = shape_list(sampled_one_hot)
  extra_count, _, _ = shape_list(extra_one_hot)

  # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
  # in an optimized fashion to avoid possible memory or computation blowup.
  # The score is the weighted number of positions where the two sequences
  # carry the same class; residue similarity is not taken into account.
  flat_extra = tf.reshape(extra_one_hot, [extra_count, res_count * 23])
  flat_sampled = tf.reshape(
      sampled_one_hot * class_weights, [sampled_count, res_count * 23])
  agreement = tf.matmul(flat_extra, flat_sampled, transpose_b=True)

  # Assign each sequence in the extra sequences to the closest MSA sample
  protein['extra_cluster_assignment'] = tf.argmax(
      agreement, axis=1, output_type=tf.int32)

  return protein
@data_transforms_curry1
def summarize_clusters(protein):
  """Produce profile and deletion_matrix_mean within each cluster.

  Reads 'msa', 'msa_mask', 'deletion_matrix', 'extra_msa', 'extra_msa_mask',
  'extra_deletion_matrix' and 'extra_cluster_assignment', and writes
  'cluster_profile' and 'cluster_deletion_mean' into the same `protein` dict,
  which is returned.
  """
  cluster_count = shape_list(protein['msa'])[0]
  assignment = protein['extra_cluster_assignment']

  def sum_per_cluster(values):
    # Row i of the result is the sum of the extra rows assigned to cluster i.
    return tf.math.unsorted_segment_sum(values, assignment, cluster_count)

  extra_mask = protein['extra_msa_mask']
  # The small constant keeps the divisions below away from zero.
  mask_counts = 1e-6 + protein['msa_mask'] + sum_per_cluster(extra_mask)  # Include center

  # Class counts over the extra sequences of each cluster plus the cluster's
  # own sampled sequence, normalised by the per-position mask counts.
  protein['cluster_profile'] = (
      sum_per_cluster(
          extra_mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23)) +
      tf.one_hot(protein['msa'], 23)  # Original sequence
  ) / mask_counts[:, :, None]

  # Same aggregation for the deletion values at each position.
  protein['cluster_deletion_mean'] = (
      sum_per_cluster(extra_mask * protein['extra_deletion_matrix']) +
      protein['deletion_matrix']  # Original sequence
  ) / mask_counts

  return protein
@data_transforms_curry1
def crop_extra_msa(protein, max_extra_msa):
  """Keep a random subset of at most `max_extra_msa` extra MSA sequences.

  The same randomly chosen rows are gathered from every 'extra_*' MSA feature
  present in `protein`, which is updated in place and returned.
  """
  total_rows = tf.shape(protein['extra_msa'])[0]
  keep_count = tf.minimum(max_extra_msa, total_rows)
  chosen_rows = tf.random_shuffle(tf.range(0, total_rows))[:keep_count]
  for name in _MSA_FEATURE_NAMES:
    extra_key = 'extra_' + name
    if extra_key not in protein:
      continue
    protein[extra_key] = tf.gather(protein[extra_key], chosen_rows)
  return protein
@data_transforms_curry1
def make_msa_feat(protein):
  """Create and concatenate MSA features.

  Writes 'msa_feat' and 'target_feat' into `protein`; when
  'extra_deletion_matrix' is present, 'extra_has_deletion' and
  'extra_deletion_value' are written as well. Returns the same dict.
  """

  def squash(deletions):
    # atan(d / 3) * 2 / pi, applied to the deletion values.
    return tf.atan(deletions / 3.) * (2. / np.pi)

  # Whether there is a domain break. Always zero for chains, but keeping
  # for compatibility with domain datasets.
  has_break = tf.clip_by_value(
      tf.cast(protein['between_segment_residues'], tf.float32),
      0, 1)
  target_columns = [
      tf.expand_dims(has_break, axis=-1),
      tf.one_hot(protein['aatype'], 21, axis=-1),  # The original sequence.
  ]

  deletions = protein['deletion_matrix']
  msa_columns = [
      tf.one_hot(protein['msa'], 23, axis=-1),
      tf.expand_dims(tf.clip_by_value(deletions, 0., 1.), axis=-1),
      tf.expand_dims(squash(deletions), axis=-1),
  ]

  # Cluster statistics are only appended when they have been computed.
  if 'cluster_profile' in protein:
    msa_columns.append(protein['cluster_profile'])
    msa_columns.append(
        tf.expand_dims(squash(protein['cluster_deletion_mean']), axis=-1))

  if 'extra_deletion_matrix' in protein:
    extra_deletions = protein['extra_deletion_matrix']
    protein['extra_has_deletion'] = tf.clip_by_value(extra_deletions, 0., 1.)
    protein['extra_deletion_value'] = squash(extra_deletions)

  protein['msa_feat'] = tf.concat(msa_columns, axis=-1)
  protein['target_feat'] = tf.concat(target_columns, axis=-1)
  return protein
@data_transforms_curry1
def select_feat(protein, feature_list):
  """Return a new dict with only the entries of `protein` in `feature_list`."""
  selected = {}
  for name, value in protein.items():
    if name in feature_list:
      selected[name] = value
  return selected
@data_transforms_curry1
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
                        subsample_templates=False):
  """Crop randomly to `crop_size`, or keep as is if shorter than that.

  Args:
    protein: Feature dict; reads 'seq_length', 'random_crop_to_size_seed' and,
      when present, 'template_mask'.
    crop_size: Maximum number of residues to keep.
    max_templates: Maximum number of templates to keep.
    shape_schema: Mapping from feature name to its per-axis placeholders.
      Features missing from it are left untouched.
    subsample_templates: When true, templates are randomly permuted and the
      template window starts at a random offset; otherwise it starts at 0.

  Returns:
    The same `protein` dict with the cropped features and an updated
    'seq_length'.
  """
  seq_length = protein['seq_length']
  if 'template_mask' in protein:
    num_templates = tf.cast(
        shape_list(protein['template_mask'])[0], tf.int32)
  else:
    num_templates = tf.constant(0, dtype=tf.int32)
  num_res_crop_size = tf.math.minimum(seq_length, crop_size)

  # Ensures that the cropping of residues and templates happens in the same way
  # across ensembling iterations.
  # Do not use for randomness that should vary in ensembling.
  seed_maker = SeedMaker(initial_seed=protein['random_crop_to_size_seed'])

  # NOTE: the order of the seed_maker() calls below determines which seed each
  # random draw receives, so it must not be rearranged.
  if subsample_templates:
    templates_crop_start = tf.random.stateless_uniform(
        shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32,
        seed=seed_maker())
  else:
    templates_crop_start = 0

  num_templates_crop_size = tf.math.minimum(
      num_templates - templates_crop_start, max_templates)

  num_res_crop_start = tf.random.stateless_uniform(
      shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
      dtype=tf.int32, seed=seed_maker())

  # One shared random permutation of the template indices, obtained by
  # arg-sorting `num_templates` uniform samples; it is applied to every
  # template feature so they stay aligned.
  templates_select_indices = tf.argsort(tf.random.stateless_uniform(
      [num_templates], seed=seed_maker()))

  for k, v in protein.items():
    if k not in shape_schema:
      continue
    axis_placeholders = shape_schema[k]
    # Only template features and features with a residue axis are cropped.
    if 'template' not in k and NUM_RES not in axis_placeholders:
      continue

    is_template_feature = k.startswith('template')

    # randomly permute the templates before cropping them.
    if is_template_feature and subsample_templates:
      v = tf.gather(v, templates_select_indices)

    # Build the tf.slice window one axis at a time, pairing each schema
    # placeholder with the actual size of that axis.
    slice_starts = []
    slice_sizes = []
    for axis, (placeholder, dim) in enumerate(
        zip(axis_placeholders, shape_list(v))):
      if axis == 0 and is_template_feature:
        start = templates_crop_start
        size = num_templates_crop_size
      elif placeholder == NUM_RES:
        start = num_res_crop_start
        size = num_res_crop_size
      else:
        # Any other axis is kept whole.
        start = 0
        size = -1 if dim is None else dim
      slice_starts.append(start)
      slice_sizes.append(size)
    protein[k] = tf.slice(v, slice_starts, slice_sizes)

  protein['seq_length'] = num_res_crop_size
  return protein
@data_transforms_curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
                    num_res, num_templates=0):
  """Pad every feature at its end so its axes have fixed, static sizes.

  Args:
    protein: Feature dict of tensors; every key except
      'extra_cluster_assignment' must appear in `shape_schema`.
    shape_schema: Mapping from feature name to its per-axis placeholders.
    msa_cluster_size: Target size of NUM_MSA_SEQ axes.
    extra_msa_size: Target size of NUM_EXTRA_SEQ axes.
    num_res: Target size of NUM_RES axes.
    num_templates: Target size of NUM_TEMPLATES axes.

  Returns:
    The same `protein` dict with padded features.
  """
  pad_size_map = {
      NUM_RES: num_res,
      NUM_MSA_SEQ: msa_cluster_size,
      NUM_EXTRA_SEQ: extra_msa_size,
      NUM_TEMPLATES: num_templates,
  }

  for k, v in protein.items():
    # Don't transfer this to the accelerator.
    if k == 'extra_cluster_assignment':
      continue

    shape = v.shape.as_list()
    # Per-axis placeholders for this feature.
    schema = shape_schema[k]
    assert len(shape) == len(schema), (
        f'Rank mismatch between shape and shape schema for {k}: '
        f'{shape} vs {schema}')

    # Target size per axis: the configured size for the axis' placeholder,
    # or the current static size when the placeholder is unmapped or its
    # configured size is falsy (None or 0).
    pad_size = []
    for current, placeholder in zip(shape, schema):
      pad_size.append(pad_size_map.get(placeholder, None) or current)

    # Pad only at the end of each axis, by target size minus current size.
    padding = [
        (0, target - tf.shape(v)[axis])
        for axis, target in enumerate(pad_size)
    ]
    if not padding:
      # Rank-0 features have nothing to pad.
      continue

    protein[k] = tf.pad(
        v, padding, name=f'pad_to_fixed_{k}')
    protein[k].set_shape(pad_size)
  return protein
 
 
def _delete_extra_msa(protein):
  """Drop every 'extra_*' MSA feature from `protein` and return it."""
  for k in _MSA_FEATURE_NAMES:
    protein.pop('extra_' + k, None)
  return protein


def _crop_templates(max_templates):
  """Return a transform keeping the first `max_templates` templates."""

  def crop(protein):
    for k, v in protein.items():
      if k.startswith('template_'):
        protein[k] = v[:max_templates]
    return protein

  return crop


def ensembled_map_fns(data_config):
  """Input pipeline functions that can be ensembled and averaged.

  Args:
    data_config: Config object exposing `common` and `eval` sections (for
      example `CONFIG.data`).

  Returns:
    A list of one-argument transforms, each taking and returning a feature
    dict, in the order they must be applied.
  """
  common_cfg = data_config.common
  eval_cfg = data_config.eval

  map_fns = []

  if common_cfg.reduce_msa_clusters_by_max_templates:
    pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
  else:
    pad_msa_clusters = eval_cfg.max_msa_clusters

  max_msa_clusters = pad_msa_clusters
  max_extra_msa = common_cfg.max_extra_msa

  map_fns.append(sample_msa(max_msa_clusters, keep_extra=True))

  if 'masked_msa' in common_cfg:
    # Masked MSA should come *before* MSA clustering so that
    # the clustering and full MSA profile do not leak information about
    # the masked locations and secret corrupted locations.
    map_fns.append(make_masked_msa(common_cfg.masked_msa,
                                   eval_cfg.masked_msa_replace_fraction))

  if common_cfg.msa_cluster_features:
    map_fns.append(nearest_neighbor_clusters())
    map_fns.append(summarize_clusters())

  # Crop after creating the cluster profiles.
  if max_extra_msa:
    map_fns.append(crop_extra_msa(max_extra_msa))
  else:
    # The original referenced an undefined `delete_extra_msa` here, which
    # raised NameError whenever `max_extra_msa` was 0 or None.
    map_fns.append(_delete_extra_msa)

  map_fns.append(make_msa_feat())

  crop_feats = dict(eval_cfg.feat)

  if eval_cfg.fixed_size:
    map_fns.append(select_feat(list(crop_feats)))
    map_fns.append(random_crop_to_size(
        eval_cfg.crop_size,
        eval_cfg.max_templates,
        crop_feats,
        eval_cfg.subsample_templates))
    map_fns.append(make_fixed_size(
        crop_feats,
        pad_msa_clusters,
        common_cfg.max_extra_msa,
        eval_cfg.crop_size,
        eval_cfg.max_templates))
  else:
    # The original referenced an undefined `crop_templates` here, which
    # raised NameError whenever `fixed_size` was false.
    map_fns.append(_crop_templates(eval_cfg.max_templates))

  return map_fns
 
 
@data_transforms_curry1
def compose(x, fs):
  """Apply the functions in `fs` to `x` one after another.

  Each function receives the result of the previous one; the result of the
  last function is returned (or `x` itself when `fs` is empty).
  """
  result = x
  for transform in fs:
    result = transform(result)
  return result
 

### Pull the data-pipeline settings out of CONFIG.
data_config = CONFIG.data 
eval_cfg = data_config.eval
common_cfg = data_config.common
 
# Shape schema as a plain dict: feature name -> per-axis placeholders.
crop_feats = dict(eval_cfg.feat)
#pad_msa_clusters = eval_cfg.max_msa_clusters
 
shape_schema = crop_feats
# Number of ensemble members; scaled by the recycling count further down
# when MSAs are resampled on every recycling step.
num_ensemble = eval_cfg.num_ensemble


def wrap_ensemble_fn(data, i):
  """Function to be mapped over the ensemble dimension.

  Args:
    data: Feature dict; a shallow copy is made so the caller's dict does not
      gain the 'ensemble_index' key.
    i: Ensemble index, stored under 'ensemble_index' before the transforms
      built from the module-level `data_config` are applied.

  Returns:
    The feature dict produced by the composed ensembled transforms.
  """
  features = data.copy()
  pipeline = compose(ensembled_map_fns(data_config))
  features['ensemble_index'] = i
  return pipeline(features)


### Load the input features; they have already been processed by the
### non-ensembled transforms.
# NOTE: pickle.load can execute arbitrary code while loading, so only open
# files from a trusted source.
with open("Human_HBB_tensor_dict_nonensembled.pkl",'rb') as f:
   Human_HBB_tensor = pickle.load(f)
 
# Work on a copy so the loaded features stay available for the report below.
protein = copy.deepcopy(Human_HBB_tensor)
 
# Add the float 'deletion_matrix' feature (cast from 'deletion_matrix_int');
# the transforms fail without it.
protein['deletion_matrix'] = tf.cast(protein['deletion_matrix_int'], dtype=tf.float32) 
 
# First ensemble member; its structure also serves as the output signature
# for tf.map_fn below.
protein_0 = wrap_ensemble_fn(protein, tf.constant(0))

if data_config.common.resample_msa_in_recycling:
  # Separate batch per ensembling & recycling step.
  num_ensemble *= data_config.common.num_recycle + 1
 

if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
  # Describe the structure, dtypes and shapes of one pipeline output.
  fn_output_signature = tree.map_structure(
        tf.TensorSpec.from_tensor, protein_0)
  # The mapped function takes a scalar index but returns a dict of tensors.
  # Because the input and output structures differ, the output structure has
  # to be given explicitly through `fn_output_signature`.
  # The lambda is called while `protein` still refers to the input dict; the
  # name is only rebound once tf.map_fn returns.
  protein = tf.map_fn(
        lambda x: wrap_ensemble_fn(protein, x),
        tf.range(num_ensemble),
        parallel_iterations=1,
        fn_output_signature=fn_output_signature)
else:
  # Single ensemble member: just add a leading axis of size 1.
  protein = tree.map_structure(lambda x: x[None],
                                 protein_0)
 
# Report the features as loaded, before the ensembled transforms.
print(f"ensembled函数处理前:")
print(f"特征数:{len(Human_HBB_tensor)}")
print(f"特征:{Human_HBB_tensor.keys()}")
print(Human_HBB_tensor['aatype'].shape)
#print(Human_HBB_tensor['aatype'])
      
# Report the features after the ensembled transforms; each now carries a
# leading ensemble axis.
print(f"ensembled函数处理后:")
print(f"特征数:{len(protein)}")
print(f"特征:{protein.keys()}")
print(protein['extra_msa'].shape)
print(protein['aatype'].shape)
print(protein['msa_feat'].shape)

# For comparison: the same feature from the single first pass (no ensemble
# axis).
print("protein_0['msa_feat'].shape")
print(protein_0['msa_feat'].shape)

你可能感兴趣的:(生物信息学,tensorflow,python)