设计替换概率张量categorical_probs,其中包含替换为[MASK]的概率以及替换为随机/正确氨基酸(+gap+X)的概率;
从protein['msa']中随机选择待替换的位点,生成protein['bert_mask']特征;
最后通过tf.where用采样结果替换这些位点,得到新的protein['msa'],而原来的protein['msa']则保存为protein['true_msa']。
主要张量操作函数:
import pickle
import tensorflow.compat.v1 as tf
import ml_collections
# String markers standing in for named tensor axes. They appear in the
# per-feature shape lists under CONFIG['data']['eval']['feat']; how they are
# resolved to concrete sizes is not shown in this file.
NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'
# Data and model configuration. The script at the bottom of this file only
# reads data.common.masked_msa and data.eval.masked_msa_replace_fraction.
CONFIG = ml_collections.ConfigDict({
    'data': {
        'common': {
            # Weights of the three non-mask replacement sources used by
            # make_masked_msa; whatever is left of 1.0 goes to the extra
            # [MASK] class.
            'masked_msa': {
                'profile_prob': 0.1,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            'max_extra_msa': 1024,
            # Alternative setting kept for reference:
            # 'msa_cluster_features': True,
            'msa_cluster_features': False,
            'num_recycle': 3,
            'reduce_msa_clusters_by_max_templates': False,
            'resample_msa_in_recycling': True,
            'template_features': [
                'template_all_atom_positions', 'template_sum_probs',
                'template_aatype', 'template_all_atom_masks',
                'template_domain_names'
            ],
            'unsupervised_features': [
                'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
                'num_alignments', 'seq_length', 'between_segment_residues',
                'deletion_matrix'
            ],
            'use_templates': False,
        },
        'eval': {
            # Per-feature shape lists; entries are either one of the NUM_*
            # placeholder strings or None.
            'feat': {
                'aatype': [NUM_RES],
                'all_atom_mask': [NUM_RES, None],
                'all_atom_positions': [NUM_RES, None, None],
                'alt_chi_angles': [NUM_RES, None],
                'atom14_alt_gt_exists': [NUM_RES, None],
                'atom14_alt_gt_positions': [NUM_RES, None, None],
                'atom14_atom_exists': [NUM_RES, None],
                'atom14_atom_is_ambiguous': [NUM_RES, None],
                'atom14_gt_exists': [NUM_RES, None],
                'atom14_gt_positions': [NUM_RES, None, None],
                'atom37_atom_exists': [NUM_RES, None],
                'backbone_affine_mask': [NUM_RES],
                'backbone_affine_tensor': [NUM_RES, None],
                'bert_mask': [NUM_MSA_SEQ, NUM_RES],
                'chi_angles': [NUM_RES, None],
                'chi_mask': [NUM_RES, None],
                'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_row_mask': [NUM_EXTRA_SEQ],
                'is_distillation': [],
                'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
                'msa_mask': [NUM_MSA_SEQ, NUM_RES],
                'msa_row_mask': [NUM_MSA_SEQ],
                'pseudo_beta': [NUM_RES, None],
                'pseudo_beta_mask': [NUM_RES],
                'random_crop_to_size_seed': [None],
                'residue_index': [NUM_RES],
                'residx_atom14_to_atom37': [NUM_RES, None],
                'residx_atom37_to_atom14': [NUM_RES, None],
                'resolution': [],
                'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
                'rigidgroups_group_exists': [NUM_RES, None],
                'rigidgroups_group_is_ambiguous': [NUM_RES, None],
                'rigidgroups_gt_exists': [NUM_RES, None],
                'rigidgroups_gt_frames': [NUM_RES, None, None],
                'seq_length': [],
                'seq_mask': [NUM_RES],
                'target_feat': [NUM_RES, None],
                'template_aatype': [NUM_TEMPLATES, NUM_RES],
                'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
                'template_all_atom_positions': [
                    NUM_TEMPLATES, NUM_RES, None, None],
                'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
                'template_backbone_affine_tensor': [
                    NUM_TEMPLATES, NUM_RES, None],
                'template_mask': [NUM_TEMPLATES],
                'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
                'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
                'template_sum_probs': [NUM_TEMPLATES, None],
                'true_msa': [NUM_MSA_SEQ, NUM_RES]
            },
            'fixed_size': True,
            'subsample_templates': False,  # We want top templates.
            # Passed to make_masked_msa as replace_fraction by the script at
            # the bottom of this file.
            'masked_msa_replace_fraction': 0.15,
            'max_msa_clusters': 512,
            'max_templates': 4,
            'num_ensemble': 1,
        },
    },
    'model': {
        'embeddings_and_evoformer': {
            'evoformer_num_block': 48,
            'evoformer': {
                'msa_row_attention_with_pair_bias': {
                    'dropout_rate': 0.15,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'msa_column_attention': {
                    'dropout_rate': 0.0,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'msa_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'outer_product_mean': {
                    'first': False,
                    'chunk_size': 128,
                    'dropout_rate': 0.0,
                    'num_outer_channel': 32,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_starting_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_ending_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True,
                    'fuse_projection_weights': False,
                },
                'triangle_multiplication_incoming': {
                    'dropout_rate': 0.25,
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True,
                    'fuse_projection_weights': False,
                },
                'pair_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                }
            },
            'extra_msa_channel': 64,
            'extra_msa_stack_num_block': 4,
            'max_relative_feature': 32,
            'msa_channel': 256,
            'pair_channel': 128,
            'prev_pos': {
                'min_bin': 3.25,
                'max_bin': 20.75,
                'num_bins': 15
            },
            'recycle_features': True,
            'recycle_pos': True,
            'seq_channel': 384,
            'template': {
                'attention': {
                    'gating': False,
                    'key_dim': 64,
                    'num_head': 4,
                    'value_dim': 64
                },
                'dgram_features': {
                    'min_bin': 3.25,
                    'max_bin': 50.75,
                    'num_bins': 39
                },
                'embed_torsion_angles': False,
                'enabled': False,
                'template_pair_stack': {
                    'num_block': 2,
                    'triangle_attention_starting_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_attention_ending_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_column',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_multiplication_outgoing': {
                        'dropout_rate': 0.25,
                        'equation': 'ikc,jkc->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'fuse_projection_weights': False,
                    },
                    'triangle_multiplication_incoming': {
                        'dropout_rate': 0.25,
                        'equation': 'kjc,kic->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'fuse_projection_weights': False,
                    },
                    'pair_transition': {
                        'dropout_rate': 0.0,
                        'num_intermediate_factor': 2,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    }
                },
                'max_templates': 4,
                'subbatch_size': 128,
                'use_template_unit_vector': False,
            }
        },
        'global_config': {
            'deterministic': False,
            'multimer_mode': False,
            'subbatch_size': 4,
            'use_remat': False,
            'zero_init': True,
            'eval_dropout': False,
        },
        'heads': {
            'distogram': {
                'first_break': 2.3125,
                'last_break': 21.6875,
                'num_bins': 64,
                'weight': 0.3
            },
            'predicted_aligned_error': {
                # `num_bins - 1` bins uniformly space the
                # [0, max_error_bin A] range.
                # The final bin covers [max_error_bin A, +infty]
                # 31A gives bins with 0.5A width.
                'max_error_bin': 31.,
                'num_bins': 64,
                'num_channels': 128,
                'filter_by_resolution': True,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'weight': 0.0,
            },
            'experimentally_resolved': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'weight': 0.01
            },
            'structure_module': {
                'num_layer': 8,
                'fape': {
                    'clamp_distance': 10.0,
                    'clamp_type': 'relu',
                    'loss_unit_distance': 10.0
                },
                'angle_norm_weight': 0.01,
                'chi_weight': 0.5,
                'clash_overlap_tolerance': 1.5,
                'compute_in_graph_metrics': True,
                'dropout': 0.1,
                'num_channel': 384,
                'num_head': 12,
                'num_layer_in_transition': 3,
                'num_point_qk': 4,
                'num_point_v': 8,
                'num_scalar_qk': 16,
                'num_scalar_v': 16,
                'position_scale': 10.0,
                'sidechain': {
                    'atom_clamp_distance': 10.0,
                    'num_channel': 128,
                    'num_residual_block': 2,
                    'weight_frac': 0.5,
                    'length_scale': 10.,
                },
                'structural_violation_loss_weight': 1.0,
                'violation_tolerance_factor': 12.0,
                'weight': 1.0
            },
            'predicted_lddt': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 50,
                'num_channels': 128,
                'weight': 0.01
            },
            'masked_msa': {
                'num_output': 23,
                'weight': 2.0
            },
        },
        'num_recycle': 3,
        'resample_msa_in_recycling': True
    },
})
def shaped_categorical(probs, epsilon=1e-10):
    """Draw one categorical sample for every position of `probs`.

    Args:
      probs: Tensor whose last axis holds per-class probabilities, e.g.
        [num_seq, num_res, num_classes] as passed by make_masked_msa.
      epsilon: Small constant added to `probs` before taking the log so that
        a zero probability does not become log(0).

    Returns:
      An int32 tensor of sampled class indices whose shape is the shape of
      `probs` without its last axis.
    """
    ds = shape_list(probs)
    num_classes = ds[-1]
    # tf.random.categorical takes 2-D [batch, num_classes] logits, so all
    # leading axes are flattened into one batch axis before sampling.
    logits = tf.reshape(tf.log(probs + epsilon), [-1, num_classes])
    counts = tf.random.categorical(logits, 1, dtype=tf.int32)
    # Undo the flattening: one sampled index per original leading position.
    return tf.reshape(counts, ds[:-1])
def shape_list(x):
    """Return list of dimensions of a tensor, statically where possible.

    Like `x.shape.as_list()` but with tensors instead of `None`s.

    Args:
      x: A tensor (or anything `tf.convert_to_tensor` accepts).

    Returns:
      A list with length equal to the rank of the tensor. The n-th element of
      the list is an integer when that dimension is statically known,
      otherwise it is the n-th element of `tf.shape(x)`. If the rank itself
      is unknown, the dynamic shape tensor `tf.shape(x)` is returned instead
      of a list.
    """
    x = tf.convert_to_tensor(x)
    static_shape = x.get_shape()
    # If unknown rank, return dynamic shape
    if static_shape.dims is None:
        return tf.shape(x)
    dynamic_shape = tf.shape(x)
    # Prefer the static Python int; fall back to the dynamic scalar tensor
    # only for dimensions that are not known at graph-construction time.
    return [
        dynamic_shape[i] if dim is None else dim
        for i, dim in enumerate(static_shape.as_list())
    ]
def data_transforms_curry1(f):
    """Supply all arguments but the first.

    Decorator that turns `f(x, *args, **kwargs)` into a two-step call:
    `decorated(*args, **kwargs)` returns a one-argument callable, and calling
    that with `x` evaluates `f(x, *args, **kwargs)`.

    Args:
      f: Function whose first positional argument is supplied last.

    Returns:
      A function accepting every argument of `f` except the first and
      returning a callable that takes that first argument.
    """
    def fc(*args, **kwargs):
        # Capture the trailing arguments now; the first one arrives later.
        return lambda x: f(x, *args, **kwargs)
    return fc
@data_transforms_curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    Because of the `data_transforms_curry1` decorator this is called as
    `make_masked_msa(config, replace_fraction)(protein)`.

    Args:
      protein: Dict of feature tensors. Reads 'msa' and 'hhblits_profile';
        the dict is modified in place. NOTE(review): 'msa' is assumed to hold
        integer class indices below 22 and 'hhblits_profile' to broadcast
        against one_hot(msa, 22) - confirm against the data pipeline.
      config: Object exposing `uniform_prob`, `profile_prob` and `same_prob`.
      replace_fraction: Probability with which each MSA position is selected
        for replacement.

    Returns:
      The same `protein` dict with 'bert_mask' (float32, 1.0 at selected
      positions) and 'true_msa' (the original 'msa') added, and 'msa'
      replaced by the partially re-sampled MSA.
    """
    # Add a random amino acid uniformly: 0.05 for each of the first 20
    # classes, zero for the remaining two.
    random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)

    # Per-position replacement distribution over the 22 existing classes:
    # a uniform term, a term following the profile at that position, and a
    # term that keeps the current residue.
    categorical_probs = (
        config.uniform_prob * random_aa +
        config.profile_prob * protein['hhblits_profile'] +
        config.same_prob * tf.one_hot(protein['msa'], 22))

    # Put all remaining probability on [MASK] which is a new column
    pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
    pad_shapes[-1][1] = 1
    # With the CONFIG defaults in this file (0.1 each) this is 0.7, and the
    # three terms above contribute the other 0.3.
    mask_prob = (
        1. - config.profile_prob - config.same_prob - config.uniform_prob)
    assert mask_prob >= 0.
    # Append mask_prob as an extra class on the last axis. After this pad the
    # last axis has 23 entries and (given a normalized profile) sums to 1.0.
    categorical_probs = tf.pad(
        categorical_probs, pad_shapes, constant_values=mask_prob)

    sh = shape_list(protein['msa'])
    # One uniform draw per MSA position; positions whose draw falls below
    # replace_fraction are the ones that get replaced.
    mask_position = tf.random.uniform(sh) < replace_fraction

    # Sample a replacement class for every position. [MASK] is the most
    # likely outcome when mask_prob dominates; the chance of drawing a
    # particular amino acid grows with its weight in the profile.
    bert_msa = shaped_categorical(categorical_probs)
    # Mix real and masked MSA: keep the original residue everywhere except at
    # the selected positions. With the defaults (0.15 and 0.7) roughly
    # 0.15 * 0.7 of all positions end up as [MASK]; the other selected
    # positions hold a sampled residue that may or may not equal the original.
    bert_msa = tf.where(mask_position, bert_msa, protein['msa'])

    protein['bert_mask'] = tf.cast(mask_position, tf.float32)
    protein['true_msa'] = protein['msa']
    protein['msa'] = bert_msa

    return protein
# Demo driver: load a pickled feature dict, apply the masked-MSA transform
# with the settings from CONFIG, and print the resulting keys and shapes.
common_cfg = CONFIG.data.common
eval_cfg = CONFIG.data.eval

# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only point this at a pickle you generated yourself.
with open("Human_HBB_tensor_dict_nonensembled.pkl", 'rb') as f:
    protein = pickle.load(f)
print(protein.keys())

# The decorator on make_masked_msa makes this a two-step call: configure
# first, then apply to the feature dict.
protein = make_masked_msa(common_cfg.masked_msa,
                          eval_cfg.masked_msa_replace_fraction)(protein)

# 'bert_mask' and 'true_msa' are added by the transform above.
print(protein.keys())
print(protein['bert_mask'].shape)
print(protein['msa'].shape)
print(protein['true_msa'].shape)