入口
# Entry point of the walkthrough: feature_dict is handed to
# model_runner.process_features together with the run's random seed.
processed_feature_dict = model_runner.process_features(feature_dict, random_seed=model_random_seed)
这一部分(make_data_config)的作用是根据 config 取出需要用到的特征名称列表,并把 eval.crop_size 设置为 num_res
def make_data_config(
    config: ml_collections.ConfigDict,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    """Makes a data config for the input pipeline.

    Args:
        config: Full configuration; only its ``data`` sub-config is read, and
            that sub-config is deep-copied so the caller's object is not
            modified.
        num_res: Value written to ``eval.crop_size`` of the returned config.

    Returns:
        A ``(cfg, feature_names)`` tuple. ``cfg`` is the copied data config
        with ``eval.crop_size`` set to ``num_res``. ``feature_names`` is a new
        list holding ``common.unsupervised_features`` followed, when
        ``common.use_templates`` is truthy, by ``common.template_features``.
    """
    cfg = copy.deepcopy(config.data)

    # Names of the features to use, e.g. 'aatype', 'residue_index', 'msa', ...
    # Copy before extending: an in-place ``+=`` on the config's own list would
    # also append the template names to ``cfg.common.unsupervised_features``.
    feature_names = list(cfg.common.unsupervised_features)
    if cfg.common.use_templates:  # Whether template features are used.
        feature_names.extend(cfg.common.template_features)

    # Write crop_size inside unlocked() so the assignment is accepted even if
    # the config is locked against adding new fields.
    with cfg.unlocked():
        cfg.eval.crop_size = num_res

    return cfg, feature_names
data:
common:
masked_msa:
profile_prob: 0.1
same_prob: 0.1
uniform_prob: 0.1
max_extra_msa: 5120
msa_cluster_features: true
num_recycle: 3
reduce_msa_clusters_by_max_templates: true
resample_msa_in_recycling: true
template_features:
- template_all_atom_positions
- template_sum_probs
- template_aatype
- template_all_atom_masks
- template_domain_names
unsupervised_features:
- aatype
- residue_index
- sequence
- msa
- domain_name
- num_alignments
- seq_length
- between_segment_residues
- deletion_matrix
use_templates: true
eval:
feat:
aatype:
- num residues placeholder
all_atom_mask:
- num residues placeholder
- null
all_atom_positions:
- num residues placeholder
- null
- null
alt_chi_angles:
- num residues placeholder
- null
atom14_alt_gt_exists:
- num residues placeholder
- null
atom14_alt_gt_positions:
- num residues placeholder
- null
- null
atom14_atom_exists:
- num residues placeholder
- null
atom14_atom_is_ambiguous:
- num residues placeholder
- null
atom14_gt_exists:
- num residues placeholder
- null
atom14_gt_positions:
- num residues placeholder
- null
- null
atom37_atom_exists:
- num residues placeholder
- null
backbone_affine_mask:
- num residues placeholder
backbone_affine_tensor:
- num residues placeholder
- null
bert_mask:
- msa placeholder
- num residues placeholder
chi_angles:
- num residues placeholder
- null
chi_mask:
- num residues placeholder
- null
extra_deletion_value:
- extra msa placeholder
- num residues placeholder
extra_has_deletion:
- extra msa placeholder
- num residues placeholder
extra_msa:
- extra msa placeholder
- num residues placeholder
extra_msa_mask:
- extra msa placeholder
- num residues placeholder
extra_msa_row_mask:
- extra msa placeholder
is_distillation: []
msa_feat:
- msa placeholder
- num residues placeholder
- null
msa_mask:
- msa placeholder
- num residues placeholder
msa_row_mask:
- msa placeholder
pseudo_beta:
- num residues placeholder
- null
pseudo_beta_mask:
- num residues placeholder
random_crop_to_size_seed:
- null
residue_index:
- num residues placeholder
residx_atom14_to_atom37:
- num residues placeholder
- null
residx_atom37_to_atom14:
- num residues placeholder
- null
resolution: []
rigidgroups_alt_gt_frames:
- num residues placeholder
- null
- null
rigidgroups_group_exists:
- num residues placeholder
- null
rigidgroups_group_is_ambiguous:
- num residues placeholder
- null
rigidgroups_gt_exists:
- num residues placeholder
- null
rigidgroups_gt_frames:
- num residues placeholder
- null
- null
seq_length: []
seq_mask:
- num residues placeholder
target_feat:
- num residues placeholder
- null
template_aatype:
- num templates placeholder
- num residues placeholder
template_all_atom_masks:
- num templates placeholder
- num residues placeholder
- null
template_all_atom_positions:
- num templates placeholder
- num residues placeholder
- null
- null
template_backbone_affine_mask:
- num templates placeholder
- num residues placeholder
template_backbone_affine_tensor:
- num templates placeholder
- num residues placeholder
- null
template_mask:
- num templates placeholder
template_pseudo_beta:
- num templates placeholder
- num residues placeholder
- null
template_pseudo_beta_mask:
- num templates placeholder
- num residues placeholder
template_sum_probs:
- num templates placeholder
- null
true_msa:
- msa placeholder
- num residues placeholder
fixed_size: true
masked_msa_replace_fraction: 0.15
max_msa_clusters: 512
max_templates: 4
num_ensemble: 1
subsample_templates: false
model:
embeddings_and_evoformer:
evoformer:
msa_column_attention:
dropout_rate: 0.0
gating: true
num_head: 8
orientation: per_column
shared_dropout: true
msa_row_attention_with_pair_bias:
dropout_rate: 0.15
gating: true
num_head: 8
orientation: per_row
shared_dropout: true
msa_transition:
dropout_rate: 0.0
num_intermediate_factor: 4
orientation: per_row
shared_dropout: true
outer_product_mean:
chunk_size: 128
dropout_rate: 0.0
first: false
num_outer_channel: 32
orientation: per_row
shared_dropout: true
pair_transition:
dropout_rate: 0.0
num_intermediate_factor: 4
orientation: per_row
shared_dropout: true
triangle_attention_ending_node:
dropout_rate: 0.25
gating: true
num_head: 4
orientation: per_column
shared_dropout: true
triangle_attention_starting_node:
dropout_rate: 0.25
gating: true
num_head: 4
orientation: per_row
shared_dropout: true
triangle_multiplication_incoming:
dropout_rate: 0.25
equation: kjc,kic->ijc
num_intermediate_channel: 128
orientation: per_row
shared_dropout: true
triangle_multiplication_outgoing:
dropout_rate: 0.25
equation: ikc,jkc->ijc
num_intermediate_channel: 128
orientation: per_row
shared_dropout: true
evoformer_num_block: 48
extra_msa_channel: 64
extra_msa_stack_num_block: 4
max_relative_feature: 32
msa_channel: 256
pair_channel: 128
prev_pos:
max_bin: 20.75
min_bin: 3.25
num_bins: 15
recycle_features: true
recycle_pos: true
seq_channel: 384
template:
attention:
gating: false
key_dim: 64
num_head: 4
value_dim: 64
dgram_features:
max_bin: 50.75
min_bin: 3.25
num_bins: 39
embed_torsion_angles: true
enabled: true
max_templates: 4
subbatch_size: 128
template_pair_stack:
num_block: 2
pair_transition:
dropout_rate: 0.0
num_intermediate_factor: 2
orientation: per_row
shared_dropout: true
triangle_attention_ending_node:
dropout_rate: 0.25
gating: true
key_dim: 64
num_head: 4
orientation: per_column
shared_dropout: true
value_dim: 64
triangle_attention_starting_node:
dropout_rate: 0.25
gating: true
key_dim: 64
num_head: 4
orientation: per_row
shared_dropout: true
value_dim: 64
triangle_multiplication_incoming:
dropout_rate: 0.25
equation: kjc,kic->ijc
num_intermediate_channel: 64
orientation: per_row
shared_dropout: true
triangle_multiplication_outgoing:
dropout_rate: 0.25
equation: ikc,jkc->ijc
num_intermediate_channel: 64
orientation: per_row
shared_dropout: true
use_template_unit_vector: false
global_config:
deterministic: false
multimer_mode: false
subbatch_size: 4
use_remat: false
zero_init: true
heads:
distogram:
first_break: 2.3125
last_break: 21.6875
num_bins: 64
weight: 0.3
experimentally_resolved:
filter_by_resolution: true
max_resolution: 3.0
min_resolution: 0.1
weight: 0.01
masked_msa:
num_output: 23
weight: 2.0
predicted_aligned_error:
filter_by_resolution: true
max_error_bin: 31.0
max_resolution: 3.0
min_resolution: 0.1
num_bins: 64
num_channels: 128
weight: 0.0
predicted_lddt:
filter_by_resolution: true
max_resolution: 3.0
min_resolution: 0.1
num_bins: 50
num_channels: 128
weight: 0.01
structure_module:
angle_norm_weight: 0.01
chi_weight: 0.5
clash_overlap_tolerance: 1.5
compute_in_graph_metrics: true
dropout: 0.1
fape:
clamp_distance: 10.0
clamp_type: relu
loss_unit_distance: 10.0
num_channel: 384
num_head: 12
num_layer: 8
num_layer_in_transition: 3
num_point_qk: 4
num_point_v: 8
num_scalar_qk: 16
num_scalar_v: 16
position_scale: 10.0
sidechain:
atom_clamp_distance: 10.0
length_scale: 10.0
num_channel: 128
num_residual_block: 2
weight_frac: 0.5
structural_violation_loss_weight: 1.0
violation_tolerance_factor: 12.0
weight: 1.0
num_recycle: 3
resample_msa_in_recycling: true
tensor转换
features_metadata = _make_features_metadata(features) # 特征的数据类型
## {'between_segment_residues': (tf.int64, [...]), 'template_aatype': (tf.float32, [...]), 'deletion_matrix': (tf.float32, [...]), 'template_all_atom_positions': (tf.float32, [...]), 'residue_index': (tf.int64, [...]), 'sequence': (tf.string, [...]), 'num_alignments': (tf.int64, [...]), 'domain_name': (tf.string, [...]), 'template_all_atom_masks': (tf.float32, [...]), 'msa': (tf.int64, [...]), 'seq_length': (tf.int64, [...]), 'template_domain_names': (tf.string, [...]), 'template_sum_probs': (tf.float32, [...]), 'aatype': (tf.float32, [...])}
tensor_dict = {k: tf.constant(v) for k, v in np_example.items()
if k in features_metadata}## 取出需要的特征
# Ensures shapes are as expected. Needed for setting size of empty features
# e.g. when no template hits were found. 确保形状符合预期。用于设置空特征的大小,例如未找到模板命中时。
tensor_dict = parse_reshape_logic(tensor_dict, features_metadata)
return tensor_dict