# Download link (Baidu Pan): https://pan.baidu.com/s/1Zbqw5t2fWo9Z9Zep07Y74g
# Extraction code: 1234
import time
import os
from typing import List
from einops import rearrange
import torch
import numpy as np
import sys
import io
from glob import glob
from typing import Union, List
import requests
import warnings
from os.path import splitext, basename
from Bio.PDB import PDBParser, PDBIO
from Bio.SeqUtils import seq1
from Bio import SeqIO
from bisect import bisect_left, bisect_right
import torch
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from einops import rearrange
import os
from einops import rearrange, repeat
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix
from igfold.model.components import TriangleGraphTransformer, IPAEncoder, IPATransformer
from igfold.utils.coordinates import get_ideal_coords
from igfold.model.components.GraphTransformer import GraphTransformer
from invariant_point_attention.invariant_point_attention import IPABlock, exists
# Size of the last axis of the placeholder template-coordinate tensor that
# IgFold.clean_input allocates when no template coordinates are supplied.
ATOM_DIM = 3
def get_ideal_coords(center=False):
    """
    Build a (4, 3) float32 tensor of reference coordinates for one residue.

    Three positions are hard-coded; the fourth is derived from them with
    ``place_fourth_atom``.  The rows of the result are stacked in the order
    N, A, C, B (the names used by the construction below).

    :param center: when True, subtract the mean of the four rows so the
        result is centred on the origin.
    :return: tensor of shape (4, 3) and dtype float32.
    """
    n_atom = torch.tensor([[0, 0, -1.458]], dtype=float)
    a_atom = torch.tensor([[0, 0, 0]], dtype=float)
    b_atom = torch.tensor([[0, 1.426, 0.531]], dtype=float)

    # Internal coordinates (length, planar angle, dihedral) of the derived atom.
    bond_length = torch.tensor(2.460)
    planar_angle = torch.tensor(0.615)
    dihedral_angle = torch.tensor(-2.143)
    c_atom = place_fourth_atom(
        b_atom,
        a_atom,
        n_atom,
        bond_length,
        planar_angle,
        dihedral_angle,
    )

    ideal = torch.cat((n_atom, a_atom, c_atom, b_atom)).float()
    if not center:
        return ideal

    ideal -= ideal.mean(dim=0, keepdim=True)
    return ideal
@dataclass
class IgFoldOutput:
    """
    Container for the values returned by an IgFold forward pass.

    The first four fields are required; every loss term and every embedding
    field is optional and defaults to None.
    """
    # Required fields, filled on every forward pass.
    coords: torch.FloatTensor
    prmsd: torch.FloatTensor
    translations: torch.FloatTensor
    rotations: torch.FloatTensor

    # Loss terms.
    coords_loss: Optional[torch.FloatTensor] = None
    torsion_loss: Optional[torch.FloatTensor] = None
    bondlen_loss: Optional[torch.FloatTensor] = None
    prmsd_loss: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None

    # Intermediate representations.
    bert_hidden: Optional[torch.FloatTensor] = None
    bert_embs: Optional[torch.FloatTensor] = None
    gt_embs: Optional[torch.FloatTensor] = None
    structure_embs: Optional[torch.FloatTensor] = None
def bb_prmsd_l1(
    pdev,
    pred,
    target,
    align_mask=None,
    mask=None,
):
    """
    L1 loss between predicted deviations and the measured deviations.

    ``target`` is superimposed onto ``pred`` with ``do_kabsch``; the norm of
    the difference along the last axis is the regression target for ``pdev``.

    :param pdev: predicted deviations, compared element-wise with the
        measured ones.
    :param pred: predicted coordinates.
    :param target: reference coordinates.
    :param align_mask: optional mask forwarded to ``do_kabsch``.
    :param mask: optional "b l" mask; every entry is repeated four times
        and used to weight the element-wise loss.
    :return: the reduced loss with a leading axis of size 1.
    """
    superposed_target = do_kabsch(
        mobile=target,
        stationary=pred,
        align_mask=align_mask,
    )
    measured_dev = (pred - superposed_target).norm(dim=-1)
    elementwise = F.l1_loss(pdev, measured_dev, reduction='none')

    if not exists(mask):
        reduced = elementwise.mean(-1)
    else:
        expanded_mask = repeat(mask, "b l -> b (l 4)")
        weighted_total = torch.sum(elementwise * expanded_mask, dim=-1)
        reduced = weighted_total / torch.sum(expanded_mask, dim=-1)

    return reduced.mean(-1).unsqueeze(0)
def bond_length_l1(
    pred,
    target,
    mask,
    offsets=(1, 2),
):
    """
    L1 error between predicted and reference inter-point distances.

    For every batch entry and every offset ``o``, the distance between
    point ``i`` and point ``i + o`` is computed in both ``pred`` and
    ``target``.  The mean absolute difference over pairs whose two points
    are both unmasked is divided by ``o`` and appended to the result.

    :param pred: predicted coordinates, indexed as ``pred[batch][point]``.
    :param target: reference coordinates with the same layout as ``pred``.
    :param mask: boolean mask indexed as ``mask[batch][point]``.
    :param offsets: sequence separations to evaluate.  The default is an
        immutable tuple so it cannot be shared and mutated across calls.
    :return: 1-D tensor with one entry per (batch entry, offset) pair,
        ordered batch-major.
    """
    losses = []
    for c, p in enumerate(pred):
        m, t = mask[c], target[c]
        for o in offsets:
            # A pair counts only when both of its end points are unmasked.
            pair_mask = torch.stack([m[:-o], m[o:]]).all(0)
            pred_lens = torch.norm(p[:-o] - p[o:], dim=-1)
            target_lens = torch.norm(t[:-o] - t[o:], dim=-1)
            abs_err = torch.abs(pred_lens[pair_mask] - target_lens[pair_mask])
            losses.append(abs_err.mean() / o)
    return torch.stack(losses)
def do_kabsch(
    mobile,
    stationary,
    align_mask=None,
):
    """
    Superimpose ``mobile`` onto ``stationary`` and return the moved copy.

    The transform is fitted on clones of the inputs and then applied to the
    original, unmodified ``mobile``.

    :param mobile: coordinates to move.
    :param stationary: coordinates to fit onto.
    :param align_mask: optional boolean mask.  Positions where it is False
        are replaced, in the clones only, by the mean of the positions where
        it is True before the transform is fitted.
    :return: ``mobile`` after applying the fitted transform.
    """
    fit_mobile = mobile.clone()
    fit_stationary = stationary.clone()

    if exists(align_mask):
        excluded = ~align_mask
        fit_mobile[excluded] = fit_mobile[align_mask].mean(dim=-2)
        fit_stationary[excluded] = fit_stationary[align_mask].mean(dim=-2)

    _, fitted_xform = kabsch(fit_mobile, fit_stationary)
    return fitted_xform(mobile)
def kabsch_mse(
    pred,
    target,
    align_mask=None,
    mask=None,
    clamp=0.,
    sqrt=False,
):
    """
    Mean squared error between ``pred`` and ``target`` after superposition.

    ``target`` is aligned onto a detached copy of ``pred`` with
    ``do_kabsch``, so the alignment itself does not receive gradients from
    ``pred``.

    :param pred: predicted coordinates.
    :param target: reference coordinates.
    :param align_mask: optional mask forwarded to ``do_kabsch``.
    :param mask: optional weights for the reduction over the last remaining
        axis; without it a plain mean is taken.
    :param clamp: when positive, per-point errors are capped at
        ``clamp ** 2`` before the reduction.
    :param sqrt: when True, return the square root of the reduced error.
    :return: the reduced error.
    """
    superposed_target = do_kabsch(
        mobile=target,
        stationary=pred.detach(),
        align_mask=align_mask,
    )
    per_point = F.mse_loss(pred, superposed_target, reduction='none').mean(-1)

    if clamp > 0:
        per_point = torch.clamp(per_point, max=clamp**2)

    if exists(mask):
        reduced = torch.sum(per_point * mask, dim=-1) / torch.sum(mask, dim=-1)
    else:
        reduced = per_point.mean(-1)

    return reduced.sqrt() if sqrt else reduced
@dataclass
class IgFoldInput:
    """
    Container for the arguments of an IgFold forward pass.

    Only ``sequences`` is required.  The template, mask and label fields
    default to None and ``return_embeddings`` defaults to False.
    """
    # One entry per chain: a raw string or an already tokenised tensor.
    sequences: List[Union[torch.LongTensor, str]]

    # Optional template structure and its validity mask.
    template_coords: Optional[torch.FloatTensor] = None
    template_mask: Optional[torch.BoolTensor] = None

    # Optional masks.
    batch_mask: Optional[torch.BoolTensor] = None
    align_mask: Optional[torch.BoolTensor] = None

    # Optional reference coordinates; IgFold.forward computes losses when set.
    coords_label: Optional[torch.FloatTensor] = None

    # When true, IgFold.forward also returns intermediate embeddings.
    return_embeddings: Optional[bool] = False
def kabsch(
    mobile,
    stationary,
    return_translation_rotation=False,
):
    """
    Fit the rigid transform that superimposes ``mobile`` onto ``stationary``.

    Both inputs are "... l d" tensors (``l`` points with ``d`` components).
    The point sets are centred, their covariance matrix is decomposed with
    an SVD, and the rotation is built from the singular vectors.

    :param mobile: coordinates to move.
    :param stationary: coordinates to fit onto.
    :param return_translation_rotation: when True, return the raw transform
        components instead of the moved coordinates.
    :return: ``(mobile_centroid, rotation, stationary_centroid)`` when
        ``return_translation_rotation`` is True, otherwise
        ``(moved_mobile, transform)`` where ``transform`` applies the fitted
        superposition to any coordinates.
    """
    # Work on "... d l" views so the point axis is last.
    X = mobile.transpose(-1, -2)
    Y = stationary.transpose(-1, -2)

    # center X and Y to the origin
    XT, YT = X.mean(dim=-1, keepdim=True), Y.mean(dim=-1, keepdim=True)
    X_ = X - XT
    Y_ = Y - YT

    # calculate covariance matrix
    C = torch.einsum("... x l, ... y l -> ... x y", X_, Y_)

    # Optimal rotation matrix via SVD.  Feature detection replaces the old
    # minor-version check, which sent torch 2.0-2.7 down the legacy branch.
    if hasattr(getattr(torch, "linalg", None), "svd"):
        V, _, W = torch.linalg.svd(C)
    else:
        # Legacy torch.svd returns the right singular vectors untransposed.
        V, _, W = torch.svd(C)
        W = W.transpose(-1, -2)

    # determinant sign for direction correction (evaluated on CPU)
    v_det = torch.det(V.to("cpu")).to(X.device)
    w_det = torch.det(W.to("cpu")).to(X.device)
    d = (v_det * w_det) < 0.0
    if d.any():
        V[d, :] = V[d, :] * (-1)

    # Create Rotation matrix U
    U = torch.matmul(V, W).transpose(-1, -2)
    XT = XT.transpose(-1, -2)
    YT = YT.transpose(-1, -2)

    if return_translation_rotation:
        return XT, U, YT

    def transform(coords):
        # Centre on the mobile centroid, rotate, move to the stationary centroid.
        return torch.einsum(
            "... l d, ... x d -> ... l x",
            coords - XT,
            U,
        ) + YT

    mobile = transform(mobile)
    return mobile, transform
class IgFold(pl.LightningModule):
    """
    Structure prediction network built on BERT features.

    BERT hidden states and attention maps are projected to node and edge
    features, refined by a triangle graph transformer, conditioned on
    template coordinates through an IPA encoder, and decoded into
    coordinates by an IPA transformer.  A second IPA encoder predicts a
    per-atom deviation for the decoded structure.
    """

    def __init__(
        self,
        config,
        config_overwrite=None,
    ):
        """
        :param config: mapping that provides "tokenizer", "bert_config",
            "node_dim", "depth", "gt_depth", "gt_heads", "temp_ipa_depth",
            "temp_ipa_heads", "str_ipa_depth", "str_ipa_heads",
            "dev_ipa_depth" and "dev_ipa_heads".  ``forward`` also reads
            "rmsd_clamp" when labels are given.
        :param config_overwrite: optional mapping merged into the saved
            config with ``update``.
        """
        super().__init__()

        # Imported lazily so transformers is only needed when a model is built.
        import transformers

        self.save_hyperparameters()
        config = self.hparams.config
        if exists(config_overwrite):
            config.update(config_overwrite)

        self.tokenizer = config["tokenizer"]
        self.vocab_size = len(self.tokenizer.vocab)
        self.bert_model = transformers.BertModel(config["bert_config"])
        bert_layers = self.bert_model.config.num_hidden_layers
        self.bert_feat_dim = self.bert_model.config.hidden_size
        # Attention maps of all layers are concatenated along the head axis.
        self.bert_attn_dim = bert_layers * self.bert_model.config.num_attention_heads

        self.node_dim = config["node_dim"]

        self.depth = config["depth"]
        self.gt_depth = config["gt_depth"]
        self.gt_heads = config["gt_heads"]

        self.temp_ipa_depth = config["temp_ipa_depth"]
        self.temp_ipa_heads = config["temp_ipa_heads"]

        self.str_ipa_depth = config["str_ipa_depth"]
        self.str_ipa_heads = config["str_ipa_heads"]

        self.dev_ipa_depth = config["dev_ipa_depth"]
        self.dev_ipa_heads = config["dev_ipa_heads"]

        # Structure branch: BERT features -> node / edge representations.
        self.str_node_transform = nn.Sequential(
            nn.Linear(
                self.bert_feat_dim,
                self.node_dim,
            ),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.str_edge_transform = nn.Sequential(
            nn.Linear(
                self.bert_attn_dim,
                self.node_dim,
            ),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )

        self.main_block = TriangleGraphTransformer(
            dim=self.node_dim,
            edge_dim=self.node_dim,
            depth=self.depth,
            tri_dim_hidden=2 * self.node_dim,
            gt_depth=self.gt_depth,
            gt_heads=self.gt_heads,
            gt_dim_head=self.node_dim // 2,
        )
        self.template_ipa = IPAEncoder(
            dim=self.node_dim,
            depth=self.temp_ipa_depth,
            heads=self.temp_ipa_heads,
            require_pairwise_repr=True,
        )

        self.structure_ipa = IPATransformer(
            dim=self.node_dim,
            depth=self.str_ipa_depth,
            heads=self.str_ipa_heads,
            require_pairwise_repr=True,
        )

        # Deviation branch: separate projections of the same BERT features.
        self.dev_node_transform = nn.Sequential(
            nn.Linear(self.bert_feat_dim, self.node_dim),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.dev_edge_transform = nn.Sequential(
            nn.Linear(
                self.bert_attn_dim,
                self.node_dim,
            ),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.dev_ipa = IPAEncoder(
            dim=self.node_dim,
            depth=self.dev_ipa_depth,
            heads=self.dev_ipa_heads,
            require_pairwise_repr=True,
        )
        # Four deviation values per residue.
        self.dev_linear = nn.Linear(
            self.node_dim,
            4,
        )

    def get_tokens(
        self,
        seq,
    ):
        """
        Convert ``seq`` to a token tensor on the model device.

        A string, or a list whose first element is a string, is tokenised
        after separating its characters with spaces.  Any other value is
        assumed to be tokenised already and is only moved to the device.
        """
        if isinstance(seq, str):
            tokens = self.tokenizer.encode(
                " ".join(list(seq)),
                return_tensors="pt",
            )
        elif isinstance(seq, list) and isinstance(seq[0], str):
            seqs = [" ".join(list(s)) for s in seq]
            tokens = self.tokenizer.batch_encode_plus(
                seqs,
                return_tensors="pt",
            )["input_ids"]
        else:
            tokens = seq

        return tokens.to(self.device)

    def get_bert_feats(self, tokens):
        """
        Run BERT on ``tokens`` and return ``(feats, attn, hidden)``.

        ``feats`` is the last hidden state and ``attn`` the attention maps
        of all layers concatenated on axis 1 and rearranged to "b i j d".
        Both have the first and last token positions removed.  ``hidden``
        is the full tuple of hidden states, untrimmed.
        """
        bert_output = self.bert_model(
            tokens,
            output_hidden_states=True,
            output_attentions=True,
        )

        feats = bert_output.hidden_states[-1]
        # Drop the first and last positions
        # (presumably the special start/end tokens - TODO confirm).
        feats = feats[:, 1:-1]

        attn = torch.cat(
            bert_output.attentions,
            dim=1,
        )
        attn = attn[:, :, 1:-1, 1:-1]
        attn = rearrange(
            attn,
            "b d i j -> b i j d",
        )

        hidden = bert_output.hidden_states

        return feats, attn, hidden

    def get_coords_tran_rot(
        self,
        temp_coords,
        batch_size,
        seq_len,
        center=True,
    ):
        """
        Fit the reference residue onto every template residue.

        ``temp_coords`` ("b (l a) d") is split per residue and ``kabsch`` is
        run between the tiled output of ``get_ideal_coords`` and each
        residue.

        :return: ``(translations, rotations)`` with translations rearranged
            to "b l d".
        """
        res_coords = rearrange(
            temp_coords,
            "b (l a) d -> b l a d",
            l=seq_len,
        )
        res_ideal_coords = repeat(
            get_ideal_coords(center=center),
            "a d -> b l a d",
            b=batch_size,
            l=seq_len,
        ).to(self.device)
        _, rotations, translations = kabsch(
            res_ideal_coords,
            res_coords,
            return_translation_rotation=True,
        )
        translations = rearrange(
            translations,
            "b l () d -> b l d",
        )

        return translations, rotations

    def clean_input(
        self,
        input: IgFoldInput,
    ):
        """
        Tokenise the sequences and fill in every missing optional tensor.

        ``input`` is modified in place: sequences are replaced by token
        tensors, absent template coordinates/masks become zeros, absent
        batch/align masks become all-True, and template coordinates are
        zeroed outside the template mask and centred inside it.

        :return: ``(input, batch_size, seq_lens, seq_len)`` where
            ``seq_lens`` is the per-chain token count minus two and
            ``seq_len`` is their sum.
        """
        tokens = [self.get_tokens(s) for s in input.sequences]
        temp_coords = input.template_coords
        temp_mask = input.template_mask
        batch_mask = input.batch_mask
        align_mask = input.align_mask

        batch_size = tokens[0].shape[0]
        # Two token positions per chain are not residues (see get_bert_feats).
        seq_lens = [max(t.shape[1] - 2, 0) for t in tokens]
        seq_len = sum(seq_lens)

        if not exists(temp_coords):
            temp_coords = torch.zeros(
                batch_size,
                4 * seq_len,
                ATOM_DIM,
                device=self.device,
            ).float()
        if not exists(temp_mask):
            temp_mask = torch.zeros(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()
        if not exists(batch_mask):
            batch_mask = torch.ones(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()
        if not exists(align_mask):
            align_mask = torch.ones(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()

        align_mask = align_mask & batch_mask  # Should already be masked by batch_mask anyway
        temp_coords[~temp_mask] = 0.
        # Centre the valid template atoms of every batch entry on the origin.
        for i, (tc, m) in enumerate(zip(temp_coords, temp_mask)):
            temp_coords[i][m] -= tc[m].mean(-2)

        input.sequences = tokens
        input.template_coords = temp_coords
        input.template_mask = temp_mask
        input.batch_mask = batch_mask
        input.align_mask = align_mask

        return input, batch_size, seq_lens, seq_len

    def forward(
        self,
        input: IgFoldInput,
    ):
        """
        Predict coordinates and per-atom deviations for ``input``.

        When ``input.coords_label`` is set, the coordinate, bond-length and
        deviation losses and their sum are also computed; otherwise all
        loss fields of the output are None.  Embeddings are returned only
        when ``input.return_embeddings`` is true.

        :return: an ``IgFoldOutput``.
        """
        input, batch_size, seq_lens, seq_len = self.clean_input(input)
        tokens = input.sequences
        temp_coords = input.template_coords
        temp_mask = input.template_mask
        coords_label = input.coords_label
        batch_mask = input.batch_mask
        align_mask = input.align_mask
        return_embeddings = input.return_embeddings

        # A residue counts as valid only when all four of its atoms are.
        res_batch_mask = rearrange(
            batch_mask,
            "b (l a) -> b l a",
            a=4,
        ).all(-1)
        res_temp_mask = rearrange(
            temp_mask,
            "b (l a) -> b l a",
            a=4,
        ).all(-1)

        ### Model forward pass

        bert_feats, bert_attns, bert_hidden = [], [], []
        for t in tokens:
            f, a, h = self.get_bert_feats(t)
            bert_feats.append(f)
            bert_attns.append(a)
            bert_hidden.append(h)

        bert_feats = torch.cat(bert_feats, dim=1)
        # Per-chain attention maps fill the diagonal blocks of the joint
        # map; the cross-chain blocks stay zero.
        bert_attn = torch.zeros(
            (batch_size, seq_len, seq_len, self.bert_attn_dim),
            device=self.device,
        )
        for i, (a, l) in enumerate(zip(bert_attns, seq_lens)):
            cum_l = sum(seq_lens[:i])
            bert_attn[:, cum_l:cum_l + l, cum_l:cum_l + l, :] = a

        temp_translations, temp_rotations = self.get_coords_tran_rot(
            temp_coords,
            batch_size,
            seq_len,
        )

        str_nodes = self.str_node_transform(bert_feats)
        str_edges = self.str_edge_transform(bert_attn)
        str_nodes, str_edges = self.main_block(
            str_nodes,
            str_edges,
            mask=res_batch_mask,
        )
        gt_embs = str_nodes
        str_nodes = self.template_ipa(
            str_nodes,
            translations=temp_translations,
            rotations=temp_rotations,
            pairwise_repr=str_edges,
            mask=res_temp_mask,
        )
        structure_embs = str_nodes

        ipa_coords, ipa_translations, ipa_quaternions = self.structure_ipa(
            str_nodes,
            translations=None,
            quaternions=None,
            pairwise_repr=str_edges,
            mask=res_batch_mask,
        )
        ipa_rotations = quaternion_to_matrix(ipa_quaternions)

        # The deviation branch sees the predicted frames but is detached
        # from them, so it does not back-propagate into the structure branch.
        dev_nodes = self.dev_node_transform(bert_feats)
        dev_edges = self.dev_edge_transform(bert_attn)
        dev_out_feats = self.dev_ipa(
            dev_nodes,
            translations=ipa_translations.detach(),
            rotations=ipa_rotations.detach(),
            pairwise_repr=dev_edges,
            mask=res_batch_mask,
        )
        dev_pred = F.relu(self.dev_linear(dev_out_feats))
        dev_pred = rearrange(dev_pred, "b l a -> b (l a)", a=4)

        # First three atoms per residue (bond-length loss) and first four
        # atoms per residue (coordinate and deviation losses).
        bb_coords = rearrange(
            ipa_coords[:, :, :3],
            "b l a d -> b (l a) d",
        )
        flat_coords = rearrange(
            ipa_coords[:, :, :4],
            "b l a d -> b (l a) d",
        )

        ### Calculate losses if given labels
        if exists(coords_label):
            rmsd_clamp = self.hparams.config["rmsd_clamp"]
            coords_loss = kabsch_mse(
                flat_coords,
                coords_label,
                align_mask=batch_mask,
                mask=batch_mask,
                clamp=rmsd_clamp,
            )

            bb_coords_label = rearrange(
                rearrange(coords_label, "b (l a) d -> b l a d", a=4)[:, :, :3],
                "b l a d -> b (l a) d")
            bb_batch_mask = rearrange(
                rearrange(batch_mask, "b (l a) -> b l a", a=4)[:, :, :3],
                "b l a -> b (l a)")
            bondlen_loss = bond_length_l1(
                bb_coords,
                bb_coords_label,
                bb_batch_mask,
            )

            # One deviation-loss term per chain, each aligned on that chain only.
            prmsd_loss = []
            cum_seq_lens = np.cumsum([0] + seq_lens)
            for sl_i, sl in enumerate(seq_lens):
                # NOTE(review): align_mask is indexed per atom (4 * seq_len)
                # while cum_seq_lens counts residues - confirm the slice
                # bounds below are intended.
                align_mask_ = align_mask.clone()
                align_mask_[:, :cum_seq_lens[sl_i]] = False
                align_mask_[:, cum_seq_lens[sl_i + 1]:] = False
                res_batch_mask_ = res_batch_mask.clone()
                res_batch_mask_[:, :cum_seq_lens[sl_i]] = False
                res_batch_mask_[:, cum_seq_lens[sl_i + 1]:] = False

                # Skip chains with nothing to align or evaluate.
                if sl == 0 or align_mask_.sum() == 0 or res_batch_mask_.sum(
                ) == 0:
                    continue

                prmsd_loss.append(
                    bb_prmsd_l1(
                        dev_pred,
                        flat_coords.detach(),
                        coords_label,
                        align_mask=align_mask_,
                        mask=res_batch_mask_,
                    ))
            prmsd_loss = sum(prmsd_loss)

            coords_loss, bondlen_loss = list(
                map(
                    lambda l: rearrange(l, "(c b) -> b c", b=batch_size).mean(
                        1),
                    [coords_loss, bondlen_loss],
                ))

            loss = torch.zeros(
                batch_size,
                device=self.device,
            )
            loss += sum([coords_loss, bondlen_loss, prmsd_loss])
        else:
            prmsd_loss, coords_loss, bondlen_loss = None, None, None
            loss = None

        bert_hidden = bert_hidden if return_embeddings else None
        bert_embs = bert_feats if return_embeddings else None
        gt_embs = gt_embs if return_embeddings else None
        structure_embs = structure_embs if return_embeddings else None
        output = IgFoldOutput(
            coords=ipa_coords,
            prmsd=dev_pred,
            translations=ipa_translations,
            rotations=ipa_rotations,
            coords_loss=coords_loss,
            bondlen_loss=bondlen_loss,
            prmsd_loss=prmsd_loss,
            loss=loss,
            bert_hidden=bert_hidden,
            bert_embs=bert_embs,
            gt_embs=gt_embs,
            structure_embs=structure_embs,
        )

        return output
def pdb2fasta(pdb_file, num_chains=None):
    """
    Render every chain of a PDB file as FASTA text.

    Each chain yields a header ``>{pdb_id}:{chain_id}\\t{length}`` followed
    by its one-letter sequence wrapped at 80 characters per line.

    :param pdb_file: path of the PDB file; the id is the file name up to
        its first dot.
    :param num_chains: optional expected chain count.  On a mismatch a
        warning is printed and an empty string is returned.
    :return: the FASTA text.
    """
    pdb_id = basename(pdb_file).split('.')[0]
    structure = PDBParser().get_structure(pdb_id, pdb_file)

    real_num_chains = sum(1 for _ in structure.get_chains())
    if not (num_chains is None or num_chains == real_num_chains):
        print('WARNING: Skipping {}. Expected {} chains, got {}'.format(
            pdb_file, num_chains, real_num_chains))
        return ''

    line_width = 80
    pieces = []
    for chain in structure.get_chains():
        seq = seq1(''.join(residue.resname for residue in chain))
        pieces.append('>{}:{}\t{}\n'.format(pdb_id, chain.id, len(seq)))
        for start in range(0, len(seq), line_width):
            pieces.append(f'{seq[start:start + line_width]}\n')

    return ''.join(pieces)
def get_atom_coord(residue, atom_type):
    """
    Return the coordinates of ``atom_type`` in ``residue``.

    :return: ``residue[atom_type].get_coord()`` when the atom is present,
        otherwise the placeholder ``[0, 0, 0]``.
    """
    if atom_type not in residue:
        return [0, 0, 0]
    return residue[atom_type].get_coord()
def get_cb_or_ca_coord(residue):
    """
    Return the CB coordinates of ``residue``, falling back to CA.

    :return: the coordinates of the first of 'CB', 'CA' that is present,
        or the placeholder ``[0, 0, 0]`` when neither is.
    """
    for atom_name in ('CB', 'CA'):
        if atom_name in residue:
            return residue[atom_name].get_coord()
    return [0, 0, 0]
def place_fourth_atom(
    a_coord: torch.Tensor,
    b_coord: torch.Tensor,
    c_coord: torch.Tensor,
    length: torch.Tensor,
    planar: torch.Tensor,
    dihedral: torch.Tensor,
) -> torch.Tensor:
    """
    Given 3 coords + a length + a planar angle + a dihedral angle, compute a fourth coord

    The new point lies at distance ``length`` from ``c_coord``.  It is
    expressed in an orthonormal frame built from the unit vector pointing
    from ``c_coord`` to ``b_coord`` and the normal of the a-b-c plane.

    :param a_coord: "... 3" tensor of the first reference atoms.
    :param b_coord: "... 3" tensor of the second reference atoms.
    :param c_coord: "... 3" tensor of the atoms the new atom is attached to.
    :param length: distance between ``c_coord`` and the new atom.
    :param planar: planar angle, in radians.
    :param dihedral: dihedral angle, in radians.
    :return: "... 3" tensor with the coordinates of the new atom.
    """
    bc_vec = b_coord - c_coord
    bc_vec = bc_vec / bc_vec.norm(dim=-1, keepdim=True)

    # ``dim=-1`` is required: without it ``cross`` uses the first axis of
    # size 3, which is the batch axis whenever exactly three residues are given.
    ba_vec = (b_coord - a_coord).expand(bc_vec.shape)
    n_vec = torch.cross(ba_vec, bc_vec, dim=-1)
    n_vec = n_vec / n_vec.norm(dim=-1, keepdim=True)

    # Orthonormal frame and the components of the new bond in that frame.
    m_vec = [bc_vec, torch.cross(n_vec, bc_vec, dim=-1), n_vec]
    d_vec = [
        length * torch.cos(planar),
        length * torch.sin(planar) * torch.cos(dihedral),
        -length * torch.sin(planar) * torch.sin(dihedral)
    ]

    d_coord = c_coord + sum([m * d for m, d in zip(m_vec, d_vec)])

    return d_coord
def get_atom_coords_mask(coords):
    """
    Flag the rows of ``coords`` that hold usable coordinates.

    A row is flagged 1 when its components do not sum to zero and none of
    them is NaN; every other row is flagged 0.

    :param coords: 2-D tensor with one coordinate row per atom.
    :return: uint8 tensor with one flag per row.
    """
    nonzero_rows = torch.ByteTensor([int(sum(row) != 0) for row in coords])
    has_nan = torch.isnan(coords).any(dim=1).byte()
    return nonzero_rows & (1 - has_nan)
def place_missing_cb_o(atom_coords):
    """
    Fill in CB and O coordinates from backbone geometry, in place.

    Candidate CB positions are built from (C, N, CA) and candidate O
    positions from (next residue's N, CA, C), where "next" is obtained by
    rolling the N tensor by one position, so the last residue is paired
    with the first one's N.  A stored CB (or O) is overwritten whenever its
    own mask is 0 or any of the atoms its candidate was built from has
    mask 0, as reported by ``get_atom_coords_mask``.

    :param atom_coords: dict with tensors under 'N', 'CA', 'C', 'CB' and
        'O'; the 'CB' and 'O' tensors are modified.
    :return: None.
    """
    n_xyz = atom_coords['N']
    ca_xyz = atom_coords['CA']
    c_xyz = atom_coords['C']

    cb_coords = place_fourth_atom(
        c_xyz,
        n_xyz,
        ca_xyz,
        torch.tensor(1.522),
        torch.tensor(1.927),
        torch.tensor(-2.143),
    )
    o_coords = place_fourth_atom(
        torch.roll(n_xyz, shifts=-1, dims=0),
        ca_xyz,
        c_xyz,
        torch.tensor(1.231),
        torch.tensor(2.108),
        torch.tensor(-3.142),
    )

    # CB: requires N, CA and C of the same residue.
    cb_support = (get_atom_coords_mask(n_xyz) & get_atom_coords_mask(ca_xyz)
                  & get_atom_coords_mask(c_xyz))
    missing_cb = (get_atom_coords_mask(atom_coords['CB']) & cb_support) == 0
    atom_coords['CB'][missing_cb] = cb_coords[missing_cb]

    # O: requires the rolled N together with CA and C.
    next_n_xyz = torch.roll(atom_coords['N'], shifts=-1, dims=0)
    o_support = (get_atom_coords_mask(next_n_xyz)
                 & get_atom_coords_mask(atom_coords['CA'])
                 & get_atom_coords_mask(atom_coords['C']))
    missing_o = (get_atom_coords_mask(atom_coords['O']) & o_support) == 0
    atom_coords['O'][missing_o] = o_coords[missing_o]
def get_atom_coords(pdb_file, fasta_file=None):
    """
    Read N, CA, C, CB and O coordinates for every residue of a PDB file.

    Without ``fasta_file`` the residues are taken in file order.  With it,
    each chain's continuous fragments (see ``get_continuous_ranges``) are
    located inside the matching FASTA sequence, and FASTA positions with no
    structural counterpart are kept as empty placeholders, which become
    ``[0, 0, 0]`` coordinates.  Missing CB / O atoms are then rebuilt with
    ``place_missing_cb_o``.

    :param pdb_file: path of the PDB file.
    :param fasta_file: optional FASTA file used to pad unresolved residues.
        Chains must be named A, B, H or L (A/H map to "H", B/L to "L").
    :return: dict with tensors under 'N', 'CA', 'C', 'CB', 'CBCA' and 'O',
        one row per residue.
    :raises KeyError: if a chain id is not one of A, B, H, L while
        ``fasta_file`` is given.
    :raises ValueError: if a PDB fragment is not found in the remaining
        part of the FASTA sequence.
    """
    p = PDBParser()
    file_name = splitext(basename(pdb_file))[0]
    structure = p.get_structure(
        file_name,
        pdb_file,
    )

    if fasta_file:
        # Loop-invariant mapping from PDB chain id to FASTA chain name.
        chain_dict = {"A": "H", "B": "L", "H": "H", "L": "L"}

        residues = []
        for chain in structure.get_chains():
            pdb_seq = get_pdb_chain_seq(
                pdb_file,
                chain.id,
            )
            fasta_seq = get_fasta_chain_seq(
                fasta_file,
                chain_dict[chain.id],
            )

            chain_residues = list(chain.get_residues())
            continuous_ranges = get_continuous_ranges(chain_residues)

            # Independent placeholders (not one shared list) for positions
            # that have no residue in the structure.
            fasta_residues = [[] for _ in fasta_seq]
            fasta_r = (0, 0)
            for pdb_r in continuous_ranges:
                # Search only past the previous fragment so that repeated
                # subsequences are matched in order.
                fragment = pdb_seq[pdb_r[0]:pdb_r[1]]
                fasta_r_start = fasta_seq[fasta_r[1]:].index(
                    fragment) + fasta_r[1]
                pdb_r_end = len(pdb_seq) if pdb_r[1] is None else pdb_r[1]
                fasta_r_end = pdb_r_end - pdb_r[0] + fasta_r_start
                fasta_r = (fasta_r_start, fasta_r_end)
                fasta_residues[fasta_r[0]:fasta_r[1]] = chain_residues[
                    pdb_r[0]:pdb_r[1]]

            residues += fasta_residues
    else:
        residues = list(structure.get_residues())

    n_coords = torch.tensor([get_atom_coord(r, 'N') for r in residues])
    ca_coords = torch.tensor([get_atom_coord(r, 'CA') for r in residues])
    c_coords = torch.tensor([get_atom_coord(r, 'C') for r in residues])
    cb_coords = torch.tensor([get_atom_coord(r, 'CB') for r in residues])
    cb_ca_coords = torch.tensor([get_cb_or_ca_coord(r) for r in residues])
    o_coords = torch.tensor([get_atom_coord(r, 'O') for r in residues])

    atom_coords = {}
    atom_coords['N'] = n_coords
    atom_coords['CA'] = ca_coords
    atom_coords['C'] = c_coords
    atom_coords['CB'] = cb_coords
    atom_coords['CBCA'] = cb_ca_coords
    atom_coords['O'] = o_coords

    place_missing_cb_o(atom_coords)

    return atom_coords
def get_pdb_chain_seq(
    pdb_file,
    chain_id,
):
    """
    Return the sequence of one chain of a PDB file.

    The file is converted with ``pdb2fasta`` and the records are keyed by
    the part of their id that follows the colon.

    :return: the chain sequence as a string, or None (after printing a
        message) when the chain is not present.
    """
    records = SeqIO.parse(io.StringIO(pdb2fasta(pdb_file)), 'fasta')

    chain_sequences = {}
    for record in records:
        chain_sequences[record.id.split(':')[1]] = str(record.seq)

    if chain_id in chain_sequences:
        return chain_sequences[chain_id]

    print(
        "No such chain in PDB file. Chain must have a chain ID of \"[PDB ID]:{}\""
        .format(chain_id))
    return None
def get_fasta_chain_seq(
    fasta_file,
    chain_id,
):
    """
    Return the sequence of the first FASTA record whose id contains
    ``":" + chain_id``.

    :return: the sequence as a string, or None when no record matches.
    """
    marker = ":{}".format(chain_id)
    matching = (str(record.seq)
                for record in SeqIO.parse(fasta_file, 'fasta')
                if marker in record.id)
    return next(matching, None)
def process_template(
    pdb_file,
    fasta_file,
    ignore_cdrs=None,
    ignore_chain=None,
):
    """
    Load template coordinates and build the matching validity mask.

    :param pdb_file: template PDB file, or None for "no template".
    :param fasta_file: FASTA file forwarded to ``get_atom_coords`` and used
        to find the heavy-chain length when ``ignore_chain`` is set.
    :param ignore_cdrs: which CDR loops to mask out of the template.
        ``None`` masks nothing, ``True`` masks all six loops, a falsy value
        such as ``False`` masks none, a string masks that single loop, and
        a list or tuple masks the loops it names.
    :param ignore_chain: "H" or "L" to mask out that whole chain; any other
        value is ignored.
    :return: ``(temp_coords, temp_mask)``.  Both are None when ``pdb_file``
        is None.  Otherwise ``temp_coords`` has shape (1, 4 * residues, 3)
        with atoms ordered N, CA, C, CB per residue, and ``temp_mask`` is a
        boolean tensor of shape (1, 4 * residues).
    """
    if pdb_file is None:
        return None, None

    atom_coords = get_atom_coords(
        pdb_file,
        fasta_file=fasta_file,
    )
    temp_coords = torch.stack(
        [
            atom_coords['N'], atom_coords['CA'], atom_coords['C'],
            atom_coords['CB']
        ],
        dim=1,
    ).view(-1, 3).unsqueeze(0)

    # Atoms that are NaN or whose components sum to zero count as missing.
    temp_mask = torch.ones(temp_coords.shape[:2]).bool()
    temp_mask[temp_coords.isnan().any(-1)] = False
    temp_mask[temp_coords.sum(-1) == 0] = False

    if ignore_cdrs is not None:
        # isinstance is required here: comparing type(...) with typing.List
        # is never true, which silently replaced a list of names by "all".
        if isinstance(ignore_cdrs, str):
            cdr_names = [ignore_cdrs]
        elif isinstance(ignore_cdrs, (list, tuple)):
            cdr_names = list(ignore_cdrs)
        elif not ignore_cdrs:
            cdr_names = []
        else:
            cdr_names = ["h1", "h2", "h3", "l1", "l2", "l3"]

        for cdr in cdr_names:
            cdr_range = cdr_indices(pdb_file, cdr)
            # Mask the loop plus one flanking residue on each side; a loop
            # starting at index 0 must not wrap to a negative slice start.
            mask_start = max(cdr_range[0] - 1, 0) * 4
            mask_end = (cdr_range[1] + 2) * 4
            temp_mask[:, mask_start:mask_end] = False

    if ignore_chain in ("H", "L"):
        seq_dict = get_fasta_chain_dict(fasta_file)
        hlen = len(seq_dict["H"])
        if ignore_chain == "H":
            temp_mask[:, :hlen * 4] = False
        else:
            temp_mask[:, hlen * 4:] = False

    return temp_coords, temp_mask
def get_continuous_ranges(residues):
    """
    Returns ranges of residues which are continuously connected (peptide bond length 1.2-1.45 Å)

    Consecutive residues belong to the same range when the distance between
    the C atom of one and the N atom of the next lies within 1.2-1.45.

    :param residues: sequence of residues accepted by ``get_atom_coord``.
    :return: list of ``(start, stop)`` index pairs usable as slice bounds;
        the last pair always has ``stop`` set to None.  An empty input
        gives an empty list and a single residue gives ``[(0, None)]``.
    """
    if len(residues) == 0:
        return []

    ranges = []
    start_i = 0
    for res_i in range(len(residues) - 1):
        dist = np.linalg.norm(
            np.array(get_atom_coord(residues[res_i], "C")) -
            np.array(get_atom_coord(residues[res_i + 1], "N")))
        if dist > 1.45 or dist < 1.2:
            # Chain break: close the current range and open a new one.
            ranges.append((start_i, res_i + 1))
            start_i = res_i + 1

    # The trailing range is open-ended.  Appending it unconditionally also
    # covers a single residue, which previously produced no range at all.
    ranges.append((start_i, None))

    return ranges
def get_fasta_chain_dict(fasta_file):
    """
    Read a FASTA file into a dict mapping each record id to its sequence.

    Records are inserted in file order; a repeated id keeps the last
    sequence.
    """
    return {
        record.id: str(record.seq)
        for record in SeqIO.parse(fasta_file, 'fasta')
    }
def exists(x):
    """Return True unless ``x`` is None (falsy values such as 0 count as existing)."""
    if x is None:
        return False
    return True
# One-letter amino-acid code -> index, stored as a string ("0" .. "19").
_aa_dict = {
    'A': '0',
    'C': '1',
    'D': '2',
    'E': '3',
    'F': '4',
    'G': '5',
    'H': '6',
    'I': '7',
    'K': '8',
    'L': '9',
    'M': '10',
    'N': '11',
    'P': '12',
    'Q': '13',
    'R': '14',
    'S': '15',
    'T': '16',
    'V': '17',
    'W': '18',
    'Y': '19'
}

# One-letter amino-acid code -> three-letter residue name; '-' maps to 'GAP'.
# save_PDB uses this table to name the residues it writes.
_aa_1_3_dict = {
    'A': 'ALA',
    'C': 'CYS',
    'D': 'ASP',
    'E': 'GLU',
    'F': 'PHE',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'K': 'LYS',
    'L': 'LEU',
    'M': 'MET',
    'N': 'ASN',
    'P': 'PRO',
    'Q': 'GLN',
    'R': 'ARG',
    'S': 'SER',
    'T': 'THR',
    'V': 'VAL',
    'W': 'TRP',
    'Y': 'TYR',
    '-': 'GAP'
}
def save_PDB(
    out_pdb: str,
    coords: torch.Tensor,
    seq: str,
    chains: List[str] = None,
    error: torch.Tensor = None,
    delim: Union[int, List[int]] = None,
    atoms=('N', 'CA', 'C', 'O', 'CB'),
) -> None:
    """
    Write set of N, CA, C, O, CB coords to PDB file

    :param out_pdb: path of the file to write.
    :param coords: per-residue atom coordinates, indexed as
        ``coords[residue][atom]`` with three components per atom.
    :param seq: one-letter sequence with one character per residue.
    :param chains: chain ids to assign; defaults to ``["H", "L"]``.
    :param error: per-residue value written to the B-factor column;
        defaults to zeros.
    :param delim: residue index (or list of indices) at which the next
        chain starts.  Residue ``r`` belongs to the first chain whose
        delimiter is greater than ``r``.  ``None`` puts every residue in
        the first chain.  Residues past the last delimiter go to the
        following chain, or to the last chain if there is none.
    :param atoms: atom names in the order they appear along the atom axis
        of ``coords``.  'CB' is skipped for glycine.
    :raises KeyError: if ``seq`` contains a code missing from the
        one-letter to three-letter table.
    """
    if chains is None:
        chains = ["H", "L"]

    if delim is None:
        boundaries = []
    elif isinstance(delim, int):
        boundaries = [delim]
    else:
        boundaries = list(delim)

    if error is None:
        error = torch.zeros(len(seq))

    with open(out_pdb, "w") as f:
        k = 0
        for r, residue in enumerate(coords):
            AA = _aa_1_3_dict[seq[r]]

            # First chain whose boundary lies beyond this residue.  Without
            # such a boundary fall back instead of raising IndexError.
            chain_index = next(
                (i for i, bound in enumerate(boundaries) if bound > r),
                None,
            )
            if chain_index is None:
                chain_index = min(len(boundaries), len(chains) - 1)
            chain_id = chains[chain_index]

            for a, atom in enumerate(residue):
                if AA == "GLY" and atoms[a] == "CB":
                    continue
                x, y, z = atom
                # Fixed-column ATOM record: serial in columns 7-11, atom
                # name in 14-15, residue name in 18-20, chain in 22,
                # residue number in 23-26, x/y/z in 31-54, occupancy in
                # 55-60 and B-factor in 61-66.
                f.write(
                    "ATOM  %5d  %-2s  %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n"
                    % (k + 1, atoms[a], AA, chain_id, r + 1, x, y, z, 1,
                       error[r]))
                k += 1
def write_pdb_bfactor(
    in_pdb_file,
    out_pdb_file,
    bfactor,
    b_chain=None,
):
    """
    Copy a PDB file while overwriting the B-factor of every atom.

    All atoms of the n-th processed residue receive ``bfactor[n]``.

    :param in_pdb_file: PDB file to read.
    :param out_pdb_file: PDB file to write (may equal ``in_pdb_file``).
    :param bfactor: indexable sequence with one value per processed residue.
    :param b_chain: when given, only residues of this chain are processed
        and counted; other chains keep their B-factors.
    """
    parser = PDBParser()
    # Parser warnings are captured and discarded.
    with warnings.catch_warnings(record=True):
        structure = parser.get_structure("_", in_pdb_file)

    residue_index = 0
    for chain in structure.get_chains():
        if exists(b_chain) and chain._id != b_chain:
            continue

        for residue in chain.get_residues():
            for atom in residue.get_atoms():
                atom.set_bfactor(bfactor[residue_index])
            residue_index += 1

    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(out_pdb_file)
def cdr_indices(
    chothia_pdb_file,
    cdr,
    offset_heavy=True,
):
    """
    Gets the index of a given CDR loop

    The residue numbers of the matching chain ("H" or "L", from the first
    letter of ``cdr``) are searched for the fixed number range of the loop.
    The result is a pair of positions in that chain's residue list, not
    residue numbers.

    :param chothia_pdb_file: PDB file; the numbering is assumed to follow
        the scheme of the range table below.
    :param cdr: one of h1, h2, h3, l1, l2, l3 (case-insensitive).
    :param offset_heavy: for light-chain loops, add the heavy-chain length
        to both positions.
    :return: ``(cdr_start, cdr_end)``, both inclusive.

    Exits the process with status -1 when the chain is not in the file.
    """
    cdr_chothia_range_dict = {
        "h1": (26, 32),
        "h2": (52, 56),
        "h3": (95, 102),
        "l1": (24, 34),
        "l2": (50, 56),
        "l3": (89, 97)
    }
    cdr = str.lower(cdr)
    assert cdr in cdr_chothia_range_dict.keys()

    first_num, last_num = cdr_chothia_range_dict[cdr]
    chain_id = cdr[0].upper()

    pdb_id = basename(chothia_pdb_file).split('.')[0]
    structure = PDBParser().get_structure(pdb_id, chothia_pdb_file)

    cdr_chain_structure = next(
        (chain for chain in structure.get_chains() if chain.id == chain_id),
        None,
    )
    if cdr_chain_structure is None:
        print("PDB must have a chain with chain id \"[PBD ID]:{}\"".format(
            chain_id))
        sys.exit(-1)

    residue_id_nums = [res.get_id()[1] for res in cdr_chain_structure]

    # Binary search to find the start and end of the CDR loop
    cdr_start = bisect_left(residue_id_nums, first_num)
    cdr_end = bisect_right(residue_id_nums, last_num) - 1

    # Report, without aborting, a sequence / residue-list length mismatch.
    chain_seq = get_pdb_chain_seq(chothia_pdb_file, chain_id=chain_id)
    if len(chain_seq) != len(residue_id_nums):
        print('ERROR in PDB file ' + chothia_pdb_file)
        print('residue id len', len(residue_id_nums))

    if chain_id == "L" and offset_heavy:
        heavy_seq = get_pdb_chain_seq(chothia_pdb_file, chain_id="H")
        heavy_len = len(heavy_seq)
        cdr_start += heavy_len
        cdr_end += heavy_len

    return cdr_start, cdr_end
def process_prediction(
    model_out,
    pdb_file,
    fasta_file,
    skip_pdb=False,
    do_refine=True,
    use_openmm=False,
    do_renum=False,
    use_abnum=False,
):
    """
    Post-process a model output and write it to a PDB file.

    ``model_out.prmsd`` is reshaped in place from "b (l a)" to "b l a" with
    ``a = 4``.  Unless ``skip_pdb`` is set, the coordinates are written to
    ``pdb_file``, optionally refined and renumbered, and the per-residue
    RMS of the predicted deviations is stored in the B-factor column.

    The function may be called repeatedly on the same ``model_out`` (for
    example once for a decoy and once for the final model): an output whose
    ``prmsd`` is already three-dimensional is not reshaped again.

    :param model_out: object with ``prmsd`` and ``coords`` tensors.
    :param pdb_file: path of the PDB file to write.
    :param fasta_file: FASTA file giving the chain sequences and lengths.
    :param skip_pdb: return right after reshaping ``prmsd``.
    :param do_refine: run structure refinement on the written file.
    :param use_openmm: refine with OpenMM instead of PyRosetta (OpenMM is
        also the fallback when PyRosetta cannot be imported).
    :param do_renum: renumber the written file.
    :param use_abnum: renumber with the AbNum server instead of ANARCI.
    :return: ``model_out``.
    """
    prmsd = model_out.prmsd
    if prmsd.dim() == 2:
        # Equivalent to rearrange(prmsd, "b (l a) -> b l a", a=4).
        prmsd = prmsd.reshape(prmsd.shape[0], -1, 4)
    model_out.prmsd = prmsd

    if skip_pdb:
        return model_out

    coords = model_out.coords.squeeze(0).detach()
    res_rmsd = prmsd.square().mean(dim=-1).sqrt().squeeze(0)

    seq_dict = get_fasta_chain_dict(fasta_file)
    full_seq = "".join(list(seq_dict.values()))
    # Cumulative chain lengths: index at which each following chain starts.
    delims = np.cumsum([len(s) for s in seq_dict.values()]).tolist()
    save_PDB(
        pdb_file,
        coords,
        full_seq,
        atoms=['N', 'CA', 'C', 'CB', 'O'],
        error=res_rmsd,
        delim=delims,
    )

    if do_refine:
        if use_openmm:
            from igfold.refine.openmm_ref import refine
        else:
            try:
                from igfold.refine.pyrosetta_ref import refine
            except ImportError as e:
                print(
                    "Warning: PyRosetta not available. Using OpenMM instead.")
                print(e)
                from igfold.refine.openmm_ref import refine

        refine(pdb_file)

    if do_renum:
        if use_abnum:
            from igfold.utils.pdb import renumber_pdb
        else:
            try:
                from igfold.utils.anarci_ import renumber_pdb
            except ImportError as e:
                print(
                    "Warning: ANARCI not available. Provide --use_abnum to renumber with the AbNum server."
                )
                print(e)

                def renumber_pdb(x, y):
                    # No-op stand-in when no renumbering backend is available.
                    return None

        renumber_pdb(
            pdb_file,
            pdb_file,
        )

    write_pdb_bfactor(
        pdb_file,
        pdb_file,
        bfactor=res_rmsd,
    )

    return model_out
def get_sequence_dict(
    sequences,
    pdb_file,
    fasta_file=None,
    ignore_cdrs=None,
    ignore_chain=None,
    template_pdb=None,
    save_decoys=True,
    project_path=r'D:\PDB蛋白质',
    num_models=4,
    try_gpu=True,
):
    """
    Fold the given sequences with an ensemble of IgFold checkpoints.

    Every checkpoint is run on the same input and the model with the
    lowest 0.9-quantile of predicted deviation is kept.  Its structure is
    written to ``pdb_file``.

    :param sequences: dict mapping chain name to amino-acid sequence, or
        None when ``fasta_file`` is given.
    :param pdb_file: path of the PDB file to write.
    :param fasta_file: optional FASTA file with the sequences; it takes
        precedence over ``sequences``.  When omitted, a FASTA file is
        written next to ``pdb_file``.
    :param ignore_cdrs: forwarded to ``process_template``.
    :param ignore_chain: forwarded to ``process_template``.
    :param template_pdb: optional template structure.
    :param save_decoys: also write every model's prediction to
        ``<pdb_file stem>.decoy<i>.pdb``.
    :param project_path: directory holding the downloaded ``*.ckpt``
        weight files.
    :param num_models: maximum number of checkpoints to load, taken in
        sorted file-name order.
    :param try_gpu: use "cuda:0" when CUDA is available.
    :return: the ``IgFoldOutput`` of the selected model.
    :raises FileNotFoundError: if ``project_path`` contains no checkpoint.

    Exits the process when neither ``sequences`` nor ``fasta_file`` is
    given.
    """
    if exists(sequences) and exists(fasta_file):
        print("Both sequences and fasta file provided. Using fasta file.")
        seq_dict = get_fasta_chain_dict(fasta_file)
    elif not exists(sequences) and exists(fasta_file):
        seq_dict = get_fasta_chain_dict(fasta_file)
    elif exists(sequences):
        seq_dict = sequences
    else:
        # sys.exit instead of the site-provided exit(), which is not
        # guaranteed to exist outside interactive sessions.
        sys.exit("Must provide sequences or fasta file.")

    if not exists(fasta_file):
        # Swap only the extension; replacing every ".pdb" occurrence could
        # also rewrite directory names.
        fasta_file = os.path.splitext(pdb_file)[0] + ".fasta"
        with open(fasta_file, "w") as f:
            for chain, seq in seq_dict.items():
                f.write(">{}\n{}\n".format(
                    chain,
                    seq,
                ))

    temp_coords, temp_mask = process_template(
        template_pdb,
        fasta_file,
        ignore_cdrs=ignore_cdrs,
        ignore_chain=ignore_chain,
    )
    model_in = IgFoldInput(
        sequences=list(seq_dict.values()),
        template_coords=temp_coords,
        template_mask=temp_mask,
    )

    # project_path is the directory holding the downloaded weights.
    ckpt_path = os.path.join(
        project_path,
        "*.ckpt",
    )
    model_ckpts = sorted(glob(ckpt_path))[:num_models]
    if not model_ckpts:
        raise FileNotFoundError(
            f"No IgFold checkpoint (*.ckpt) found in {project_path!r}")

    print(f"Loading {len(model_ckpts)} IgFold models...")

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and try_gpu else "cpu")
    print(f"Using device: {device}")

    models = []
    for ckpt_file in model_ckpts:
        print(f"Loading {ckpt_file}...")
        models.append(IgFold.load_from_checkpoint(ckpt_file).eval().to(device))

    print(f"Successfully loaded {len(models)} IgFold models.")

    model_outs, scores = [], []
    with torch.no_grad():
        for i, model in enumerate(models):
            model_out = model(model_in)
            print(model_out.coords.shape)
            if save_decoys:
                decoy_pdb_file = os.path.splitext(
                    pdb_file)[0] + f".decoy{i}.pdb"
                process_prediction(
                    model_out,
                    decoy_pdb_file,
                    fasta_file,
                    do_refine=False,
                    use_openmm=False,
                    do_renum=False,
                    use_abnum=False,
                )

            # Lower predicted deviation means a more confident model.
            scores.append(model_out.prmsd.quantile(0.9))
            model_outs.append(model_out)

    best_model_i = scores.index(min(scores))
    print(best_model_i)
    model_out = model_outs[best_model_i]
    print(model_out.coords.shape)
    process_prediction(
        model_out,
        pdb_file,
        fasta_file,
        skip_pdb=False,
        do_refine=False,
        use_openmm=False,
        do_renum=False,
        use_abnum=False,
    )
    return model_out
# Script entry: fold the heavy/light chain pair below as soon as this file is run.
# pdb_file: path where the predicted structure is saved
# sequences: the chain sequences, keyed by chain name
get_sequence_dict(sequences = {
    "H": "EVQLQQSGAEVVRSGASVKLSCTASGFNIKDYYIHWVKQRPEKGLEWIGWIDPEIGDTEYVPKFQGKATMTADTSSNTAYLQLSSLTSEDTAVYYCNAGHDYDRGRFPYWGQGTLVTVSAAKTTPPSVYPLAPGSAAQTNSMVTLGCLVKGYFPEPVTVTWNSGSLSSGVHTFPAVLQSDLYTLSSSVTVPSSTWPSETVTCNVAHPASSTKVDKKIVPRD",
    "L": "DIVMTQSQKFMSTSVGDRVSITCKASQNVGTAVAWYQQKPGQSPKLMIYSASNRYTGVPDRFTGSGSGTDFTLTISNMQSEDLADYFCQQYSSYPLTFGAGTKLELKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYPKDINVKWKIDGSERQNGVLNSATDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNEC"
},pdb_file = r'D:\PDB蛋白质\test.pdb')