IgFold 的 Windows 版本应用及原理(无 Rosetta 精修):IgFold 是一种快速预测抗体结构的深度学习方法,其准确率可与 AlphaFold2 媲美。

 模型流程图

[图 1:IgFold 模型流程图]

官方权重下载

链接:https://pan.baidu.com/s/1Zbqw5t2fWo9Z9Zep07Y74g 
提取码:1234 
 

模型推理代码(在代码最下方填写输入序列和结果保存路径):

import time
import os
from typing import List
from einops import rearrange
import torch
import numpy as np
import sys
import io
from glob import glob
from typing import Union, List
import requests
import warnings
from os.path import splitext, basename
from Bio.PDB import PDBParser, PDBIO
from Bio.SeqUtils import seq1
from Bio import SeqIO
from bisect import bisect_left, bisect_right
import torch
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from einops import rearrange
import os
from einops import rearrange, repeat
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix
from igfold.model.components import TriangleGraphTransformer, IPAEncoder, IPATransformer
from igfold.utils.coordinates import get_ideal_coords
from igfold.model.components.GraphTransformer import GraphTransformer
from invariant_point_attention.invariant_point_attention import IPABlock, exists
# Number of Cartesian components per atom coordinate (x, y, z); used as the
# last dimension of the default template coordinates in IgFold.clean_input.
ATOM_DIM = 3

def get_ideal_coords(center=False):
    """Return idealized per-residue frame coordinates as a (4, 3) float32 tensor.

    Rows are ordered N, CA, C, CB. CA sits at the origin, N and CB are fixed
    constants, and C is placed 2.460 Å from N with ``place_fourth_atom``
    (reference atoms CB, CA, N). When ``center`` is true the four points are
    shifted so that their mean is the origin.
    """
    n_atom = torch.tensor([[0, 0, -1.458]], dtype=float)
    ca_atom = torch.tensor([[0, 0, 0]], dtype=float)
    cb_atom = torch.tensor([[0, 1.426, 0.531]], dtype=float)

    length, planar, dihedral = (
        torch.tensor(2.460),
        torch.tensor(0.615),
        torch.tensor(-2.143),
    )
    c_atom = place_fourth_atom(cb_atom, ca_atom, n_atom, length, planar, dihedral)

    ideal = torch.cat([n_atom, ca_atom, c_atom, cb_atom]).float()
    if not center:
        return ideal
    return ideal - ideal.mean(dim=0, keepdim=True)



@dataclass
class IgFoldOutput():
    """
    Output type for the IgFold model.

    Built by ``IgFold.forward``. The loss fields are filled only when the
    input carried ``coords_label``; the embedding fields only when the input
    set ``return_embeddings``.
    """

    # Predicted coordinates, indexed as (batch, residue, atom, xyz) in IgFold.forward.
    coords: torch.FloatTensor
    # Predicted per-atom error estimates, flattened to (batch, residue * 4).
    prmsd: torch.FloatTensor
    # Per-residue frame translations from the structure IPA module.
    translations: torch.FloatTensor
    # Per-residue frame rotation matrices (converted from predicted quaternions).
    rotations: torch.FloatTensor
    # Kabsch-aligned coordinate MSE loss (only with coords_label).
    coords_loss: Optional[torch.FloatTensor] = None
    # Not assigned by IgFold.forward in this file.
    torsion_loss: Optional[torch.FloatTensor] = None
    # Backbone bond-length L1 loss (only with coords_label).
    bondlen_loss: Optional[torch.FloatTensor] = None
    # Error-prediction (prmsd) L1 loss (only with coords_label).
    prmsd_loss: Optional[torch.FloatTensor] = None
    # Sum of the losses above, per batch entry (only with coords_label).
    loss: Optional[torch.FloatTensor] = None
    # Per-chain tuples of all BERT hidden states (only with return_embeddings).
    bert_hidden: Optional[torch.FloatTensor] = None
    # Concatenated last-layer BERT features (only with return_embeddings).
    bert_embs: Optional[torch.FloatTensor] = None
    # Node features after the graph transformer (only with return_embeddings).
    gt_embs: Optional[torch.FloatTensor] = None
    # Node features after the template IPA encoder (only with return_embeddings).
    structure_embs: Optional[torch.FloatTensor] = None
def bb_prmsd_l1(
    pdev,
    pred,
    target,
    align_mask=None,
    mask=None,
):
    """L1 loss of predicted per-atom deviations against measured deviations.

    ``target`` is Kabsch-superimposed onto ``pred`` (only ``align_mask`` atoms
    drive the fit). The per-atom distance between ``pred`` and the
    superimposed target is the measured deviation, and ``pdev`` is scored
    against it element-wise with an L1 loss.

    Args:
        pdev: Predicted deviations, one value per atom.
        pred: Predicted coordinates (the stationary set of the fit).
        target: Reference coordinates (moved onto ``pred``).
        align_mask: Optional atom mask forwarded to ``do_kabsch``.
        mask: Optional residue-level mask (``"b l"``). It is expanded to 4
            atoms per residue and used for a masked mean over atoms.

    Returns:
        A 1-element tensor: the mean over the batch of the per-entry losses.
    """
    superposed = do_kabsch(
        mobile=target,
        stationary=pred,
        align_mask=align_mask,
    )
    true_dev = (pred - superposed).norm(dim=-1)
    per_atom = F.l1_loss(pdev, true_dev, reduction='none')

    if not exists(mask):
        per_entry = per_atom.mean(-1)
    else:
        atom_mask = repeat(mask, "b l -> b (l 4)")
        per_entry = (per_atom * atom_mask).sum(dim=-1) / atom_mask.sum(dim=-1)

    return per_entry.mean(-1).unsqueeze(0)

def bond_length_l1(
    pred,
    target,
    mask,
    offsets=(1, 2),
):
    """L1 error between predicted and target distances of atoms ``o`` apart.

    For every batch entry ``c`` and every offset ``o`` in ``offsets``, the
    distance between atom ``i`` and atom ``i + o`` is measured in both
    ``pred`` and ``target``. The mean absolute difference over pairs whose
    two atoms are both unmasked is divided by ``o``.

    Args:
        pred: Coordinates indexed as ``pred[c][i] -> xyz``.
        target: Reference coordinates with the same layout as ``pred``.
        mask: Boolean mask indexed as ``mask[c][i]``.
        offsets: Atom-index offsets to compare. An immutable tuple by default
            (previously a shared mutable list default).

    Returns:
        1-D tensor with one entry per (batch entry, offset), ordered
        batch-major. An entry is NaN when no atom pair survives the mask.
    """
    losses = []
    for c in range(pred.shape[0]):
        m, p, t = mask[c], pred[c], target[c]
        for o in offsets:
            # A pair counts only when both of its atoms are present.
            pair_mask = torch.stack([m[:-o], m[o:]]).all(0)
            pred_lens = torch.norm(p[:-o] - p[o:], dim=-1)
            target_lens = torch.norm(t[:-o] - t[o:], dim=-1)

            # Longer-range pairs are down-weighted by their offset.
            losses.append(
                torch.abs(pred_lens[pair_mask] - target_lens[pair_mask]).mean() / o)

    return torch.stack(losses)

def do_kabsch(
    mobile,
    stationary,
    align_mask=None,
):
    """Kabsch-superimpose ``mobile`` onto ``stationary`` and return the moved copy.

    Args:
        mobile: Points to move, ``(..., n_points, 3)``.
        stationary: Reference points, same shape as ``mobile``.
        align_mask: Optional boolean mask over ``(..., n_points)``. Only the
            selected points determine the fitted transform.

    Unselected points are excluded from the fit by replacing them (in a
    copy) with the centroid of the selected points of the same batch entry.
    They then leave the centroid unchanged and contribute nothing to the
    covariance computed by ``kabsch``.

    The centroid is computed per batch entry. The previous boolean-index
    ``mean`` pooled the selected points of *all* batch entries, which
    mis-centred the fit whenever the batch held more than one entry.

    Returns:
        All of ``mobile`` (masked points included) after applying the fitted
        rigid transform.
    """

    def _fill_with_masked_centroid(points):
        # Replace unselected points by the per-entry centroid of selected ones.
        keep = align_mask.unsqueeze(-1)
        # Zero out unselected points first so NaNs there cannot leak in.
        selected = torch.where(keep, points, torch.zeros_like(points))
        count = keep.sum(dim=-2, keepdim=True).to(points.dtype)
        centroid = selected.sum(dim=-2, keepdim=True) / count
        return torch.where(keep, points, centroid)

    mobile_, stationary_ = mobile, stationary
    if exists(align_mask):
        mobile_ = _fill_with_masked_centroid(mobile)
        stationary_ = _fill_with_masked_centroid(stationary)

    _, kabsch_xform = kabsch(mobile_, stationary_)

    return kabsch_xform(mobile)
def kabsch_mse(
    pred,
    target,
    align_mask=None,
    mask=None,
    clamp=0.,
    sqrt=False,
):
    """Per-entry mean squared error after Kabsch-aligning ``target`` onto ``pred``.

    ``pred`` is detached when computing the fit, so no gradient flows through
    the superposition into ``pred``. Squared errors are averaged over the
    coordinate axis. They are then capped at ``clamp ** 2`` when
    ``clamp > 0`` and averaged over atoms, weighted by ``mask`` if given.
    With ``sqrt=True`` the square root (an RMSD) is returned instead.
    """
    superposed = do_kabsch(
        mobile=target,
        stationary=pred.detach(),
        align_mask=align_mask,
    )
    per_atom = F.mse_loss(pred, superposed, reduction='none').mean(-1)

    if clamp > 0:
        per_atom = torch.clamp(per_atom, max=clamp**2)

    if exists(mask):
        result = (per_atom * mask).sum(dim=-1) / mask.sum(dim=-1)
    else:
        result = per_atom.mean(-1)

    return result.sqrt() if sqrt else result
@dataclass
class IgFoldInput():
    """
    Input type for the IgFold model.

    ``IgFold.clean_input`` tokenizes ``sequences`` and fills any missing
    template coords / masks with defaults, writing the results back onto
    this object.
    """

    # One entry per chain: a sequence string (or list of strings), or pre-tokenized ids.
    sequences: List[Union[torch.LongTensor, str]]
    # Template atom coordinates, (batch, residues * 4, 3); zeros (no template) if omitted.
    template_coords: Optional[torch.FloatTensor] = None
    # Per-atom template mask, (batch, residues * 4); all False if omitted.
    template_mask: Optional[torch.BoolTensor] = None
    # Per-atom mask of valid positions, (batch, residues * 4); all True if omitted.
    batch_mask: Optional[torch.BoolTensor] = None
    # Per-atom mask used for Kabsch alignment in the losses; all True if omitted,
    # and AND-ed with batch_mask in clean_input.
    align_mask: Optional[torch.BoolTensor] = None
    # Ground-truth atom coordinates; when set, IgFold.forward computes losses.
    coords_label: Optional[torch.FloatTensor] = None
    # When True, IgFoldOutput also carries BERT / graph / structure embeddings.
    return_embeddings: Optional[bool] = False

def kabsch(
    mobile,
    stationary,
    return_translation_rotation=False,
):
    """Optimal rigid superposition (Kabsch algorithm) of ``mobile`` onto ``stationary``.

    Both inputs are point sets of shape ``(..., n_points, d)``. Leading
    dimensions are treated as a batch.

    Args:
        mobile: Points to be moved.
        stationary: Reference points.
        return_translation_rotation: If true, return the raw transform pieces
            instead of applying them.

    Returns:
        If ``return_translation_rotation`` is true, returns
        ``(mobile_centroid, rotation, stationary_centroid)``:
            * Both centroids have shape ``(..., 1, d)``.
            * ``rotation`` has shape ``(..., d, d)`` and is applied as
              ``einsum("... l d, ... x d -> ... l x", p - mobile_centroid, rotation)
              + stationary_centroid``.

        Otherwise returns ``(aligned_mobile, transform)``, where ``transform``
        applies that mapping to any point set with the same batch shape.
    """
    # Points as columns: (..., d, n_points).
    X = mobile.transpose(-1, -2)
    Y = stationary.transpose(-1, -2)

    # Center X and Y on their centroids.
    XT = X.mean(dim=-1, keepdim=True)
    YT = Y.mean(dim=-1, keepdim=True)
    X_ = X - XT
    Y_ = Y - YT

    # Covariance matrix, (..., d, d).
    C = torch.einsum("... x l, ... y l -> ... x y", X_, Y_)

    # SVD C = V diag(S) W, with W holding the right singular vectors as rows.
    # Feature detection instead of parsing torch.__version__: the old check
    # read only the minor version, sending every torch 2.0-2.7 to torch.svd.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "svd"):
        V, _, W = torch.linalg.svd(C)
    else:
        # torch < 1.8: torch.svd returns the right singular vectors as columns.
        V, _, W = torch.svd(C)
        W = W.transpose(-1, -2)

    # Reflection correction. Determinants are taken on CPU, as before.
    v_det = torch.det(V.to("cpu")).to(X.device)
    w_det = torch.det(W.to("cpu")).to(X.device)
    reflect = (v_det * w_det) < 0.0
    # Flip only the last column of V (singular values are in descending
    # order, so this is the weakest direction) for entries that need it.
    # Negating all of V, as before, gives a proper rotation that
    # anti-aligns the points instead of the best-fitting one.
    last_col_sign = (1.0 - 2.0 * reflect.to(V.dtype))[..., None, None]
    V = torch.cat([V[..., :-1], V[..., -1:] * last_col_sign], dim=-1)

    # Rotation, transposed into the layout used by the einsum in `transform`.
    U = torch.matmul(V, W).transpose(-1, -2)
    XT = XT.transpose(-1, -2)
    YT = YT.transpose(-1, -2)

    if return_translation_rotation:
        return XT, U, YT

    def transform(coords):
        return torch.einsum(
            "... l d, ... x d -> ... l x",
            coords - XT,
            U,
        ) + YT

    return transform(mobile), transform

class IgFold(pl.LightningModule):
    """IgFold antibody structure prediction model (PyTorch Lightning module).

    ``forward`` runs this pipeline:

    1. Each chain is embedded separately by a BERT model. The last hidden
       layer gives per-residue node features, and all attention maps
       (every layer x head) give pairwise edge features.
    2. A triangle graph transformer (``main_block``) refines nodes and edges.
    3. ``template_ipa`` mixes in template frames derived from
       ``template_coords`` (masked by ``template_mask``).
    4. ``structure_ipa`` predicts per-residue coordinates and frames.
    5. A separate IPA branch (``dev_*``) predicts 4 non-negative values per
       residue ("prmsd") from the BERT features and the predicted frames.

    When ``coords_label`` is given, coordinate, bond-length and prmsd losses
    are computed as well.
    """

    def __init__(
        self,
        config,
        config_overwrite=None,
    ):
        """Build all sub-networks from ``config``.

        Args:
            config: Dict-like hyper-parameters. Keys read here are:
                ``tokenizer``, ``bert_config``, ``node_dim``, ``depth``,
                ``gt_depth``, ``gt_heads``, ``temp_ipa_depth``,
                ``temp_ipa_heads``, ``str_ipa_depth``, ``str_ipa_heads``,
                ``dev_ipa_depth`` and ``dev_ipa_heads``. ``forward`` also
                reads ``rmsd_clamp`` when computing losses.
            config_overwrite: Optional dict merged into ``config`` (via
                ``update``) before anything is built.
        """
        super().__init__()

        # NOTE(review): local import, presumably to defer the transformers
        # dependency until a model is actually constructed.
        import transformers

        # Lightning: stores the constructor arguments in self.hparams.
        self.save_hyperparameters()
        config = self.hparams.config
        if exists(config_overwrite):
            config.update(config_overwrite)

        self.tokenizer = config["tokenizer"]
        self.vocab_size = len(self.tokenizer.vocab)
        self.bert_model = transformers.BertModel(config["bert_config"])
        bert_layers = self.bert_model.config.num_hidden_layers
        self.bert_feat_dim = self.bert_model.config.hidden_size
        # Edge features concatenate every attention head of every layer.
        self.bert_attn_dim = bert_layers * self.bert_model.config.num_attention_heads

        self.node_dim = config["node_dim"]

        self.depth = config["depth"]
        self.gt_depth = config["gt_depth"]
        self.gt_heads = config["gt_heads"]

        self.temp_ipa_depth = config["temp_ipa_depth"]
        self.temp_ipa_heads = config["temp_ipa_heads"]

        self.str_ipa_depth = config["str_ipa_depth"]
        self.str_ipa_heads = config["str_ipa_heads"]

        self.dev_ipa_depth = config["dev_ipa_depth"]
        self.dev_ipa_heads = config["dev_ipa_heads"]

        # Structure branch: project BERT features to node / edge width.
        self.str_node_transform = nn.Sequential(
            nn.Linear(self.bert_feat_dim, self.node_dim),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.str_edge_transform = nn.Sequential(
            nn.Linear(self.bert_attn_dim, self.node_dim),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )

        self.main_block = TriangleGraphTransformer(
            dim=self.node_dim,
            edge_dim=self.node_dim,
            depth=self.depth,
            tri_dim_hidden=2 * self.node_dim,
            gt_depth=self.gt_depth,
            gt_heads=self.gt_heads,
            gt_dim_head=self.node_dim // 2,
        )
        self.template_ipa = IPAEncoder(
            dim=self.node_dim,
            depth=self.temp_ipa_depth,
            heads=self.temp_ipa_heads,
            require_pairwise_repr=True,
        )

        self.structure_ipa = IPATransformer(
            dim=self.node_dim,
            depth=self.str_ipa_depth,
            heads=self.str_ipa_heads,
            require_pairwise_repr=True,
        )

        # Error-prediction branch: its own projections, IPA encoder and a
        # linear head producing 4 values per residue.
        self.dev_node_transform = nn.Sequential(
            nn.Linear(self.bert_feat_dim, self.node_dim),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.dev_edge_transform = nn.Sequential(
            nn.Linear(self.bert_attn_dim, self.node_dim),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.dev_ipa = IPAEncoder(
            dim=self.node_dim,
            depth=self.dev_ipa_depth,
            heads=self.dev_ipa_heads,
            require_pairwise_repr=True,
        )
        self.dev_linear = nn.Linear(
            self.node_dim,
            4,
        )

    def get_tokens(
        self,
        seq,
    ):
        """Convert a sequence (or list of sequences) to token ids on ``self.device``.

        Inputs are handled as follows:
            * A ``str`` is encoded as a single sequence.
            * A list of ``str`` is batch-encoded.
            * Anything else is assumed to already be token ids and is only
              moved to the model device.

        Residues are space-separated before encoding, so the tokenizer sees
        one residue per word.
        """
        if isinstance(seq, str):
            tokens = self.tokenizer.encode(
                " ".join(list(seq)),
                return_tensors="pt",
            )
        elif isinstance(seq, list) and isinstance(seq[0], str):
            seqs = [" ".join(list(s)) for s in seq]
            tokens = self.tokenizer.batch_encode_plus(
                seqs,
                return_tensors="pt",
            )["input_ids"]
        else:
            tokens = seq

        return tokens.to(self.device)

    def get_bert_feats(self, tokens):
        """Run BERT on ``tokens``; return ``(features, attention, hidden_states)``.

        features:
            Last hidden layer with the first and last token positions removed
            (presumably the tokenizer's special tokens; confirm).
        attention:
            All layers' attention maps concatenated along the head axis,
            trimmed the same way and moved to channels-last
            (``"b d i j -> b i j d"``).
        hidden_states:
            BERT's full, untrimmed tuple of hidden states.
        """
        bert_output = self.bert_model(
            tokens,
            output_hidden_states=True,
            output_attentions=True,
        )

        feats = bert_output.hidden_states[-1]
        feats = feats[:, 1:-1]

        attn = torch.cat(
            bert_output.attentions,
            dim=1,
        )
        attn = attn[:, :, 1:-1, 1:-1]
        attn = rearrange(
            attn,
            "b d i j -> b i j d",
        )

        hidden = bert_output.hidden_states

        return feats, attn, hidden

    def get_coords_tran_rot(
        self,
        temp_coords,
        batch_size,
        seq_len,
        center=True,
    ):
        """Derive per-residue frames by superimposing ideal residue geometry.

        ``temp_coords`` is reshaped to ``(batch, seq_len, atoms, 3)``. The
        ideal N/CA/C/CB coordinates from ``get_ideal_coords`` are then
        Kabsch-aligned onto every residue.

        Returns:
            ``(translations, rotations)``: translations of shape
            ``(batch, seq_len, 3)`` and the matching rotation matrices.
        """
        res_coords = rearrange(
            temp_coords,
            "b (l a) d -> b l a d",
            l=seq_len,
        )
        res_ideal_coords = repeat(
            get_ideal_coords(center=center),
            "a d -> b l a d",
            b=batch_size,
            l=seq_len,
        ).to(self.device)
        _, rotations, translations = kabsch(
            res_ideal_coords,
            res_coords,
            return_translation_rotation=True,
        )
        translations = rearrange(
            translations,
            "b l () d -> b l d",
        )

        return translations, rotations

    def clean_input(
        self,
        input: IgFoldInput,
    ):
        """Tokenize sequences and fill in default template / batch / align masks.

        ``input`` is updated and returned:
            * ``sequences`` is replaced by token tensors.
            * Missing template coords and masks get defaults: no template,
              every atom valid, and every atom used for alignment.
            * Template coordinates are zeroed where masked, and each batch
              entry's unmasked template atoms are centred on their mean.

        A caller-supplied ``template_coords`` tensor is copied first, so the
        in-place masking and centring no longer modify the caller's tensor.

        Returns:
            ``(input, batch_size, seq_lens, seq_len)``. ``seq_lens`` holds the
            per-chain residue counts (token count minus the 2 positions that
            ``get_bert_feats`` strips), and ``seq_len`` is their sum.
        """
        tokens = [self.get_tokens(s) for s in input.sequences]

        temp_coords = input.template_coords
        temp_mask = input.template_mask
        batch_mask = input.batch_mask
        align_mask = input.align_mask

        batch_size = tokens[0].shape[0]
        seq_lens = [max(t.shape[1] - 2, 0) for t in tokens]
        seq_len = sum(seq_lens)

        if not exists(temp_coords):
            temp_coords = torch.zeros(
                batch_size,
                4 * seq_len,
                ATOM_DIM,
                device=self.device,
            ).float()
        else:
            # Work on a copy: the masking / centring below is done in place.
            temp_coords = temp_coords.clone()
        if not exists(temp_mask):
            temp_mask = torch.zeros(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()
        if not exists(batch_mask):
            batch_mask = torch.ones(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()
        if not exists(align_mask):
            align_mask = torch.ones(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()

        align_mask = align_mask & batch_mask  # Should already be masked by batch_mask anyway
        temp_coords[~temp_mask] = 0.
        for i, (tc, m) in enumerate(zip(temp_coords, temp_mask)):
            temp_coords[i][m] -= tc[m].mean(-2)

        input.sequences = tokens
        input.template_coords = temp_coords
        input.template_mask = temp_mask
        input.batch_mask = batch_mask
        input.align_mask = align_mask

        return input, batch_size, seq_lens, seq_len

    def forward(
        self,
        input: IgFoldInput,
    ):
        """Predict structure (and losses, when labels are given) for ``input``.

        Args:
            input: An ``IgFoldInput``; it is normalized in place by
                ``clean_input``.

        Returns:
            An ``IgFoldOutput``. Loss fields are ``None`` unless
            ``input.coords_label`` is set. Embedding fields are ``None``
            unless ``input.return_embeddings`` is true.
        """
        input, batch_size, seq_lens, seq_len = self.clean_input(input)
        tokens = input.sequences
        temp_coords = input.template_coords
        temp_mask = input.template_mask
        coords_label = input.coords_label
        batch_mask = input.batch_mask
        align_mask = input.align_mask
        return_embeddings = input.return_embeddings

        # Collapse atom-level masks (4 atoms per residue) to residue level:
        # a residue counts only if all four of its atoms are set.
        res_batch_mask = rearrange(
            batch_mask,
            "b (l a) -> b l a",
            a=4,
        ).all(-1)
        res_temp_mask = rearrange(
            temp_mask,
            "b (l a) -> b l a",
            a=4,
        ).all(-1)

        ### Model forward pass

        # Embed each chain separately with BERT.
        bert_feats, bert_attns, bert_hidden = [], [], []
        for t in tokens:
            f, a, h = self.get_bert_feats(t)
            bert_feats.append(f)
            bert_attns.append(a)
            bert_hidden.append(h)

        # Node features: chains concatenated along the residue axis.
        bert_feats = torch.cat(bert_feats, dim=1)
        # Edge features are block-diagonal: each chain's attention fills its
        # own block, and inter-chain pairs stay zero.
        bert_attn = torch.zeros(
            (batch_size, seq_len, seq_len, self.bert_attn_dim),
            device=self.device,
        )
        for i, (a, l) in enumerate(zip(bert_attns, seq_lens)):
            cum_l = sum(seq_lens[:i])
            bert_attn[:, cum_l:cum_l + l, cum_l:cum_l + l, :] = a

        temp_translations, temp_rotations = self.get_coords_tran_rot(
            temp_coords,
            batch_size,
            seq_len,
        )

        str_nodes = self.str_node_transform(bert_feats)
        str_edges = self.str_edge_transform(bert_attn)
        str_nodes, str_edges = self.main_block(
            str_nodes,
            str_edges,
            mask=res_batch_mask,
        )
        gt_embs = str_nodes
        str_nodes = self.template_ipa(
            str_nodes,
            translations=temp_translations,
            rotations=temp_rotations,
            pairwise_repr=str_edges,
            mask=res_temp_mask,
        )
        structure_embs = str_nodes

        ipa_coords, ipa_translations, ipa_quaternions = self.structure_ipa(
            str_nodes,
            translations=None,
            quaternions=None,
            pairwise_repr=str_edges,
            mask=res_batch_mask,
        )
        ipa_rotations = quaternion_to_matrix(ipa_quaternions)

        # Error-prediction branch. The predicted frames are detached, so this
        # branch's loss does not backpropagate into the structure module
        # through them.
        dev_nodes = self.dev_node_transform(bert_feats)
        dev_edges = self.dev_edge_transform(bert_attn)
        dev_out_feats = self.dev_ipa(
            dev_nodes,
            translations=ipa_translations.detach(),
            rotations=ipa_rotations.detach(),
            pairwise_repr=dev_edges,
            mask=res_batch_mask,
        )
        # ReLU keeps the predicted deviations non-negative; 4 values per residue.
        dev_pred = F.relu(self.dev_linear(dev_out_feats))
        dev_pred = rearrange(dev_pred, "b l a -> b (l a)", a=4)

        # Flatten the first 3 (bond-length loss) and first 4 (coordinate
        # loss) atoms of every residue into one atom axis.
        bb_coords = rearrange(
            ipa_coords[:, :, :3],
            "b l a d -> b (l a) d",
        )
        flat_coords = rearrange(
            ipa_coords[:, :, :4],
            "b l a d -> b (l a) d",
        )

        ### Calculate losses if given labels
        loss, prmsd_loss, coords_loss, bondlen_loss = None, None, None, None
        if exists(coords_label):
            loss = torch.zeros(
                batch_size,
                device=self.device,
            )
            rmsd_clamp = self.hparams.config["rmsd_clamp"]
            coords_loss = kabsch_mse(
                flat_coords,
                coords_label,
                align_mask=batch_mask,
                mask=batch_mask,
                clamp=rmsd_clamp,
            )

            bb_coords_label = rearrange(
                rearrange(coords_label, "b (l a) d -> b l a d", a=4)[:, :, :3],
                "b l a d -> b (l a) d")
            bb_batch_mask = rearrange(
                rearrange(batch_mask, "b (l a) -> b l a", a=4)[:, :, :3],
                "b l a -> b (l a)")
            bondlen_loss = bond_length_l1(
                bb_coords,
                bb_coords_label,
                bb_batch_mask,
            )

            # prmsd loss is computed per chain, aligning and scoring only
            # that chain's atoms / residues.
            prmsd_loss = []
            cum_seq_lens = np.cumsum([0] + seq_lens)
            # align_mask is atom-level (4 atoms per residue), so its chain
            # bounds must be scaled; residue-level bounds used to select the
            # wrong atoms (and, for the second chain, the wrong chain).
            cum_atom_lens = 4 * cum_seq_lens
            for sl_i, sl in enumerate(seq_lens):
                align_mask_ = align_mask.clone()
                align_mask_[:, :cum_atom_lens[sl_i]] = False
                align_mask_[:, cum_atom_lens[sl_i + 1]:] = False
                res_batch_mask_ = res_batch_mask.clone()
                res_batch_mask_[:, :cum_seq_lens[sl_i]] = False
                res_batch_mask_[:, cum_seq_lens[sl_i + 1]:] = False

                if sl == 0 or align_mask_.sum() == 0 or res_batch_mask_.sum(
                ) == 0:
                    continue

                prmsd_loss.append(
                    bb_prmsd_l1(
                        dev_pred,
                        flat_coords.detach(),
                        coords_label,
                        align_mask=align_mask_,
                        mask=res_batch_mask_,
                    ))
            prmsd_loss = sum(prmsd_loss)

            coords_loss, bondlen_loss = list(
                map(
                    lambda l: rearrange(l, "(c b) -> b c", b=batch_size).mean(
                        1),
                    [coords_loss, bondlen_loss],
                ))

            loss += sum([coords_loss, bondlen_loss, prmsd_loss])

        bert_hidden = bert_hidden if return_embeddings else None
        bert_embs = bert_feats if return_embeddings else None
        gt_embs = gt_embs if return_embeddings else None
        structure_embs = structure_embs if return_embeddings else None
        output = IgFoldOutput(
            coords=ipa_coords,
            prmsd=dev_pred,
            translations=ipa_translations,
            rotations=ipa_rotations,
            coords_loss=coords_loss,
            bondlen_loss=bondlen_loss,
            prmsd_loss=prmsd_loss,
            loss=loss,
            bert_hidden=bert_hidden,
            bert_embs=bert_embs,
            gt_embs=gt_embs,
            structure_embs=structure_embs,
        )

        return output


def pdb2fasta(pdb_file, num_chains=None):
    """Render a PDB file's chains as a FASTA-formatted string.

    Each chain becomes a record headed ``>{pdb_id}:{chain_id}\\t{length}``.
    Its one-letter sequence (from residue names via ``seq1``) is wrapped at
    80 characters per line. If ``num_chains`` is given and does not match
    the structure, a warning is printed and ``''`` is returned.
    """
    pdb_id = basename(pdb_file).split('.')[0]
    structure = PDBParser().get_structure(pdb_id, pdb_file)

    chains = list(structure.get_chains())
    if num_chains is not None and num_chains != len(chains):
        print('WARNING: Skipping {}. Expected {} chains, got {}'.format(
            pdb_file, num_chains, len(chains)))
        return ''

    wrap = 80
    pieces = []
    for chain in chains:
        seq = seq1(''.join(residue.resname for residue in chain))
        pieces.append('>{}:{}\t{}\n'.format(pdb_id, chain.id, len(seq)))
        pieces.extend(f'{seq[start:start + wrap]}\n'
                      for start in range(0, len(seq), wrap))
    return ''.join(pieces)

def get_atom_coord(residue, atom_type):
    """Return the coordinates of ``atom_type`` in ``residue``, or ``[0, 0, 0]`` if absent."""
    if atom_type not in residue:
        return [0, 0, 0]
    return residue[atom_type].get_coord()

def get_cb_or_ca_coord(residue):
    """Return the CB coordinates of ``residue``, falling back to CA, then ``[0, 0, 0]``."""
    for atom_name in ('CB', 'CA'):
        if atom_name in residue:
            return residue[atom_name].get_coord()
    return [0, 0, 0]

def place_fourth_atom(
    a_coord: torch.Tensor,
    b_coord: torch.Tensor,
    c_coord: torch.Tensor,
    length: torch.Tensor,
    planar: torch.Tensor,
    dihedral: torch.Tensor,
) -> torch.Tensor:
    """
    Given 3 coords + a length + a planar angle + a dihedral angle, compute a fourth coord.

    The new point ``d`` lies at distance ``length`` from ``c_coord``. The
    angle b-c-d equals ``planar``, and ``dihedral`` sets the rotation about
    the b-c axis relative to ``a_coord``. Angles are in radians.
    Coordinates are ``(..., 3)``; every leading dimension is an independent
    batch entry.

    Cross products are taken explicitly over the last axis. Without ``dim``,
    ``Tensor.cross`` uses the *first* size-3 dimension, so an input of
    exactly 3 points ((3, 3) tensors) was crossed across points instead of
    across xyz, giving NaN or wrong atoms.
    """
    bc_vec = b_coord - c_coord
    bc_vec = bc_vec / bc_vec.norm(dim=-1, keepdim=True)

    # Unit normal of the a-b-c plane.
    n_vec = torch.cross((b_coord - a_coord).expand(bc_vec.shape), bc_vec, dim=-1)
    n_vec = n_vec / n_vec.norm(dim=-1, keepdim=True)

    # Orthonormal local frame: along b-c, in-plane perpendicular, plane normal.
    m_vec = [bc_vec, torch.cross(n_vec, bc_vec, dim=-1), n_vec]
    d_vec = [
        length * torch.cos(planar),
        length * torch.sin(planar) * torch.cos(dihedral),
        -length * torch.sin(planar) * torch.sin(dihedral)
    ]

    d_coord = c_coord + sum([m * d for m, d in zip(m_vec, d_vec)])

    return d_coord

def get_atom_coords_mask(coords):
    """Return a uint8 presence mask over the rows of ``coords``.

    A row gets 1 when its components do not sum to zero (the all-zero
    placeholder marks a missing atom) and it contains no NaN; otherwise it
    gets 0.
    """
    has_coords = torch.ByteTensor([0 if sum(row) == 0 else 1 for row in coords])
    has_nan = torch.isnan(coords).any(dim=1).byte()
    return has_coords & (1 - has_nan)

def place_missing_cb_o(atom_coords):
    """Fill CB and O in ``atom_coords`` in place with ideal positions.

    Ideal CB is built from C, N and CA. Ideal O is built from the next
    residue's N (``torch.roll`` wraps the last residue around to the first),
    CA and C.

    A residue's CB (resp. O) is overwritten whenever that atom is missing or
    any atom of its reference frame is missing, as judged by
    ``get_atom_coords_mask``.
    """
    next_n = torch.roll(atom_coords['N'], shifts=-1, dims=0)
    ca_c_present = (get_atom_coords_mask(atom_coords['CA'])
                    & get_atom_coords_mask(atom_coords['C']))

    ideal_cb = place_fourth_atom(
        atom_coords['C'],
        atom_coords['N'],
        atom_coords['CA'],
        torch.tensor(1.522),
        torch.tensor(1.927),
        torch.tensor(-2.143),
    )
    ideal_o = place_fourth_atom(
        next_n,
        atom_coords['CA'],
        atom_coords['C'],
        torch.tensor(1.231),
        torch.tensor(2.108),
        torch.tensor(-3.142),
    )

    cb_frame_ok = get_atom_coords_mask(atom_coords['N']) & ca_c_present
    replace_cb = (get_atom_coords_mask(atom_coords['CB']) & cb_frame_ok) == 0
    atom_coords['CB'][replace_cb] = ideal_cb[replace_cb]

    o_frame_ok = get_atom_coords_mask(next_n) & ca_c_present
    replace_o = (get_atom_coords_mask(atom_coords['O']) & o_frame_ok) == 0
    atom_coords['O'][replace_o] = ideal_o[replace_o]


def get_atom_coords(pdb_file, fasta_file=None):
    """Extract N/CA/C/CB/O coordinates per residue from a PDB file.

    Without ``fasta_file``, every residue in the structure is used in file
    order.

    With ``fasta_file``, each chain's residues are placed at their positions
    in the matching FASTA sequence (PDB chain A/H -> FASTA ``H``, B/L ->
    ``L``):
        * Continuous PDB segments (see ``get_continuous_ranges``) are located
          in order in the FASTA sequence with ``str.index``.
        * FASTA positions without a PDB residue get an empty placeholder,
          whose coordinates come out as zeros.

    Returns:
        Dict with keys ``'N'``, ``'CA'``, ``'C'``, ``'CB'``, ``'CBCA'`` (CB,
        or CA when CB is absent) and ``'O'``. Each value is a float32 tensor
        of shape ``(n_residues, 3)``. CB and O are then (re)placed in place
        by ``place_missing_cb_o``.

    Raises:
        KeyError: A chain ID other than A/B/H/L when ``fasta_file`` is given.
        ValueError: A PDB segment that does not occur in the FASTA sequence.
    """
    p = PDBParser()
    file_name = splitext(basename(pdb_file))[0]
    structure = p.get_structure(
        file_name,
        pdb_file,
    )

    if fasta_file:
        chain_dict = {"A": "H", "B": "L", "H": "H", "L": "L"}
        residues = []
        for chain in structure.get_chains():
            pdb_seq = get_pdb_chain_seq(
                pdb_file,
                chain.id,
            )
            fasta_seq = get_fasta_chain_seq(
                fasta_file,
                chain_dict[chain.id],
            )

            chain_residues = list(chain.get_residues())
            continuous_ranges = get_continuous_ranges(chain_residues)

            # One independent empty placeholder per FASTA position.
            fasta_residues = [[] for _ in range(len(fasta_seq))]
            fasta_r = (0, 0)
            for pdb_start, pdb_end in continuous_ranges:
                segment = pdb_seq[pdb_start:pdb_end]
                # Search after the previous segment so segment order is kept.
                fasta_start = fasta_seq[fasta_r[1]:].index(segment) + fasta_r[1]
                seg_end = len(pdb_seq) if pdb_end is None else pdb_end
                fasta_r = (fasta_start, seg_end - pdb_start + fasta_start)
                fasta_residues[fasta_r[0]:fasta_r[1]] = chain_residues[
                    pdb_start:pdb_end]

            residues += fasta_residues
    else:
        residues = list(structure.get_residues())

    def _coords(getter):
        # Explicit float32 (assumed to match Biopython's atom coordinates;
        # confirm). This keeps an all-missing input from inferring an integer
        # tensor and avoids torch's slow list-of-ndarray conversion.
        values = np.array([getter(r) for r in residues], dtype=np.float32)
        return torch.tensor(values.reshape(-1, 3))

    atom_coords = {
        'N': _coords(lambda r: get_atom_coord(r, 'N')),
        'CA': _coords(lambda r: get_atom_coord(r, 'CA')),
        'C': _coords(lambda r: get_atom_coord(r, 'C')),
        'CB': _coords(lambda r: get_atom_coord(r, 'CB')),
        'CBCA': _coords(get_cb_or_ca_coord),
        'O': _coords(lambda r: get_atom_coord(r, 'O')),
    }

    place_missing_cb_o(atom_coords)

    return atom_coords

def get_pdb_chain_seq(
    pdb_file,
    chain_id,
):
    """Return the one-letter sequence of chain ``chain_id`` in ``pdb_file``.

    The sequence comes from ``pdb2fasta`` and is read back with ``SeqIO``;
    records are keyed by the part of the FASTA id after ``':'``. If the chain
    is absent, a message is printed and ``None`` is returned.
    """
    records = SeqIO.parse(io.StringIO(pdb2fasta(pdb_file)), 'fasta')
    sequences = {rec.id.split(':')[1]: str(rec.seq) for rec in records}
    if chain_id in sequences:
        return sequences[chain_id]
    print(
        "No such chain in PDB file. Chain must have a chain ID of \"[PDB ID]:{}\""
        .format(chain_id))
    return None


def get_fasta_chain_seq(
    fasta_file,
    chain_id,
):
    """Return the sequence of the first FASTA record whose id contains ``:<chain_id>``.

    Returns ``None`` when no record matches.
    """
    tag = ":{}".format(chain_id)
    match = next(
        (rec for rec in SeqIO.parse(fasta_file, 'fasta') if tag in rec.id),
        None,
    )
    return None if match is None else str(match.seq)

def process_template(
    pdb_file,
    fasta_file,
    ignore_cdrs=None,
    ignore_chain=None,
):
    """Load template coordinates and a per-atom template mask from a PDB file.

    Args:
        pdb_file: Template PDB path, or ``None`` for "no template".
        fasta_file: FASTA used to map template residues onto the target
            sequence. For ``ignore_chain`` it is also used to look up the
            heavy-chain length under key ``"H"``.
        ignore_cdrs: CDR loops whose template atoms are masked out.
            * ``None`` or ``False`` masks none.
            * A single name (e.g. ``"h3"``) or a list/tuple of names masks
              those loops.
            * Any other value (e.g. ``True``) masks all six.
            Previously ``type(x) == List`` never matched, so any list (even
            ``[]``) masked all six loops.
        ignore_chain: ``"H"`` or ``"L"`` to mask that entire chain.

    Returns:
        ``(temp_coords, temp_mask)``:
            * ``temp_coords`` has shape ``(1, residues * 4, 3)`` in N, CA, C,
              CB order.
            * ``temp_mask`` is a boolean mask of shape ``(1, residues * 4)``;
              atoms with NaN coordinates or coordinates summing to zero are
              masked out.
        Both are ``None`` when ``pdb_file`` is ``None``.
    """
    temp_coords, temp_mask = None, None
    if exists(pdb_file):
        temp_coords = get_atom_coords(
            pdb_file,
            fasta_file=fasta_file,
        )
        temp_coords = torch.stack(
            [
                temp_coords['N'], temp_coords['CA'], temp_coords['C'],
                temp_coords['CB']
            ],
            dim=1,
        ).view(-1, 3).unsqueeze(0)

        temp_mask = torch.ones(temp_coords.shape[:2]).bool()
        temp_mask[temp_coords.isnan().any(-1)] = False
        temp_mask[temp_coords.sum(-1) == 0] = False

        if exists(ignore_cdrs):
            if isinstance(ignore_cdrs, str):
                cdr_names = [ignore_cdrs]
            elif isinstance(ignore_cdrs, (list, tuple)):
                cdr_names = list(ignore_cdrs)
            elif ignore_cdrs == False:  # noqa: E712 -- also treats 0 as "none", as before
                cdr_names = []
            else:
                cdr_names = ["h1", "h2", "h3", "l1", "l2", "l3"]

            for cdr in cdr_names:
                cdr_range = cdr_indices(pdb_file, cdr)
                # Mask the loop plus one flanking residue on each side
                # (4 atoms per residue).
                temp_mask[:, (cdr_range[0] - 1) * 4:(cdr_range[1] + 2) *
                          4] = False
        if exists(ignore_chain) and ignore_chain in ["H", "L"]:
            seq_dict = get_fasta_chain_dict(fasta_file)
            hlen = len(seq_dict["H"])
            if ignore_chain == "H":
                temp_mask[:, :hlen * 4] = False
            elif ignore_chain == "L":
                temp_mask[:, hlen * 4:] = False

    return temp_coords, temp_mask

def get_continuous_ranges(residues):
    """ Returns ranges of residues which are continuously connected (peptide bond length 1.2-1.45 Å).

    Consecutive residues count as bonded when the C(i)-N(i+1) distance lies
    in [1.2, 1.45] Å; missing atoms are treated as the origin via
    ``get_atom_coord``. Each range is a ``(start, end)`` slice pair into
    ``residues``, and the final range's ``end`` is ``None`` (open-ended).

    A single residue yields ``[(0, None)]``. It previously yielded ``[]``,
    so callers silently dropped one-residue chains. An empty input yields
    ``[]``.
    """
    ranges = []
    start_i = 0
    for res_i in range(len(residues) - 1):
        peptide_len = np.linalg.norm(
            np.array(get_atom_coord(residues[res_i], "C")) -
            np.array(get_atom_coord(residues[res_i + 1], "N")))
        if peptide_len > 1.45 or peptide_len < 1.2:
            ranges.append((start_i, res_i + 1))
            start_i = res_i + 1

    if len(residues) > 0:
        ranges.append((start_i, None))

    return ranges

def get_fasta_chain_dict(fasta_file):
    """Map each FASTA record id in ``fasta_file`` to its sequence string (later duplicates win)."""
    return {
        record.id: str(record.seq)
        for record in SeqIO.parse(fasta_file, 'fasta')
    }


def exists(x):
    """Return ``True`` for any value other than ``None`` (falsy values included)."""
    is_missing = x is None
    return not is_missing
_aa_dict = {
    'A': '0',
    'C': '1',
    'D': '2',
    'E': '3',
    'F': '4',
    'G': '5',
    'H': '6',
    'I': '7',
    'K': '8',
    'L': '9',
    'M': '10',
    'N': '11',
    'P': '12',
    'Q': '13',
    'R': '14',
    'S': '15',
    'T': '16',
    'V': '17',
    'W': '18',
    'Y': '19'
}

_aa_1_3_dict = {
    'A': 'ALA',
    'C': 'CYS',
    'D': 'ASP',
    'E': 'GLU',
    'F': 'PHE',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'K': 'LYS',
    'L': 'LEU',
    'M': 'MET',
    'N': 'ASN',
    'P': 'PRO',
    'Q': 'GLN',
    'R': 'ARG',
    'S': 'SER',
    'T': 'THR',
    'V': 'VAL',
    'W': 'TRP',
    'Y': 'TYR',
    '-': 'GAP'
}
def save_PDB(
    out_pdb: str,
    coords: torch.Tensor,
    seq: str,
    chains: List[str] = None,
    error: torch.Tensor = None,
    delim: Union[int, List[int]] = None,
    atoms=('N', 'CA', 'C', 'O', 'CB'),
) -> None:
    """
    Write set of N, CA, C, O, CB coords to PDB file.

    Args:
        out_pdb: Output path (overwritten).
        coords: Indexed as ``coords[residue][atom] -> (x, y, z)``; atom ``a``
            is named ``atoms[a]``.
        seq: One-letter sequence, one character per residue of ``coords``.
        chains: Chain IDs to assign; defaults to ``["H", "L"]``.
        error: Per-residue value written to the B-factor column; defaults
            to zeros.
        delim: Exclusive end index (cumulative residue count) of each chain.
            Residue ``r`` gets ``chains[k]``, where ``k`` is the index of the
            first delimiter greater than ``r``, or ``len(delim)`` if there
            is none. So:
                * ``None`` puts every residue in ``chains[0]`` (it previously
                  always raised ``IndexError``).
                * A single int splits residues between ``chains[0]`` and
                  ``chains[1]``.
        atoms: Atom names in the order of the atom axis of ``coords``. CB is
            skipped for glycine. Now an immutable tuple default; lists are
            still accepted.
    """
    if not exists(chains):
        chains = ["H", "L"]

    chain_ends = [] if delim is None else np.atleast_1d(delim).tolist()

    if not exists(error):
        error = torch.zeros(len(seq))

    with open(out_pdb, "w") as f:
        k = 0
        for r, residue in enumerate(coords):
            AA = _aa_1_3_dict[seq[r]]
            chain_idx = next(
                (i for i, end in enumerate(chain_ends) if end > r),
                len(chain_ends),
            )
            chain_id = chains[chain_idx]
            for a, atom in enumerate(residue):
                if AA == "GLY" and atoms[a] == "CB":
                    continue
                x, y, z = atom
                f.write(
                    "ATOM  %5d  %-2s  %3s %s%4d    %8.3f%8.3f%8.3f  %4.2f  %4.2f\n"
                    % (k + 1, atoms[a], AA, chain_id, r + 1, x, y, z, 1,
                       error[r]))
                k += 1

def write_pdb_bfactor(
    in_pdb_file,
    out_pdb_file,
    bfactor,
    b_chain=None,
):
    """Copy a PDB file, setting per-residue B-factors.

    Every atom of the i-th visited residue gets ``bfactor[i]``. If
    ``b_chain`` is given, only residues of that chain are visited and
    changed. Warnings raised while parsing are captured and discarded.
    """
    parser = PDBParser()
    with warnings.catch_warnings(record=True):
        structure = parser.get_structure(
            "_",
            in_pdb_file,
        )

    residue_idx = 0
    for chain in structure.get_chains():
        if exists(b_chain) and chain._id != b_chain:
            continue
        for residue in chain.get_residues():
            for atom in residue.get_atoms():
                atom.set_bfactor(bfactor[residue_idx])
            residue_idx += 1

    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(out_pdb_file)


def cdr_indices(
    chothia_pdb_file,
    cdr,
    offset_heavy=True,
):
    """Gets the index of a given CDR loop.

    Returns 0-based inclusive ``(start, end)`` indices into the residue list
    of the CDR's chain (``h*`` -> chain H, ``l*`` -> chain L), located by
    binary search over Chothia residue numbers. For light-chain CDRs with
    ``offset_heavy``, the heavy-chain sequence length is added so the
    indices refer to the concatenated H+L sequence.

    If the chain is absent, a message is printed and ``sys.exit(-1)`` is
    called. A mismatch between the chain's sequence length and its residue
    count is printed but not raised.
    """
    chothia_bounds = {
        "h1": (26, 32),
        "h2": (52, 56),
        "h3": (95, 102),
        "l1": (24, 34),
        "l2": (50, 56),
        "l3": (89, 97)
    }

    cdr = str.lower(cdr)
    assert cdr in chothia_bounds.keys()

    first_num, last_num = chothia_bounds[cdr]
    chain_id = cdr[0].upper()

    pdb_id = basename(chothia_pdb_file).split('.')[0]
    structure = PDBParser().get_structure(pdb_id, chothia_pdb_file)
    cdr_chain = next(
        (chain for chain in structure.get_chains() if chain.id == chain_id),
        None,
    )
    if cdr_chain is None:
        print("PDB must have a chain with chain id \"[PBD ID]:{}\"".format(
            chain_id))
        sys.exit(-1)

    numbering = [res.get_id()[1] for res in cdr_chain]

    # Binary search for the first and last residue inside the Chothia range.
    cdr_start = bisect_left(numbering, first_num)
    cdr_end = bisect_right(numbering, last_num) - 1

    chain_seq = get_pdb_chain_seq(chothia_pdb_file, chain_id=chain_id)
    if len(chain_seq) != len(numbering):
        print('ERROR in PDB file ' + chothia_pdb_file)
        print('residue id len', len(numbering))

    if chain_id == "L" and offset_heavy:
        heavy_len = len(get_pdb_chain_seq(chothia_pdb_file, chain_id="H"))
        cdr_start += heavy_len
        cdr_end += heavy_len

    return cdr_start, cdr_end

def process_prediction(
    model_out,
    pdb_file,
    fasta_file,
    skip_pdb=False,
    do_refine=True,
    use_openmm=False,
    do_renum=False,
    use_abnum=False,
):
    """Reshape IgFold's predicted RMSD and optionally write the structure to disk.

    ``model_out.prmsd`` is rearranged from ``(b, l*4)`` to ``(b, l, 4)`` and
    stored back on ``model_out`` (the object is mutated). If ``prmsd`` already
    has 3 dimensions the reshape is skipped. The function is therefore safe to
    call more than once on the same output object, e.g. once for a decoy PDB
    and again for the final PDB.

    Args:
        model_out: model output exposing ``coords`` and ``prmsd`` tensors.
        pdb_file: path the PDB is written to. Refinement, renumbering and the
            B-factor annotation then rewrite this same file in place.
        fasta_file: FASTA file whose chain sequences, concatenated in file
            order, give the residue sequence and the chain break positions.
        skip_pdb: if True, only reshape ``prmsd`` and write nothing.
        do_refine: refine the written PDB. OpenMM is used when ``use_openmm``
            is set; otherwise PyRosetta is tried, falling back to OpenMM if it
            cannot be imported.
        use_openmm: see ``do_refine``.
        do_renum: renumber the written PDB. ``igfold.utils.pdb.renumber_pdb``
            is used when ``use_abnum`` is set, otherwise ANARCI. If ANARCI
            cannot be imported, renumbering is skipped with a warning.
        use_abnum: see ``do_renum``.

    Returns:
        The same ``model_out`` object, with ``prmsd`` in ``(b, l, 4)`` form.
    """
    prmsd = model_out.prmsd
    # Only reshape once: the 2-axis input pattern cannot be applied to a tensor
    # that is already (b, l, 4). Previously, a second call raised here.
    if prmsd.dim() != 3:
        prmsd = rearrange(
            prmsd,
            "b (l a) -> b l a",
            a=4,
        )
    model_out.prmsd = prmsd

    if skip_pdb:
        return model_out

    # squeeze(0) drops the batch axis; presumably batch size is 1 here —
    # TODO confirm for batched inputs.
    coords = model_out.coords.squeeze(0).detach()
    # Per-residue error: root-mean-square over the 4 values of each residue.
    res_rmsd = prmsd.square().mean(dim=-1).sqrt().squeeze(0)

    seq_dict = get_fasta_chain_dict(fasta_file)
    full_seq = "".join(seq_dict.values())
    # Cumulative chain lengths: the index in full_seq where each chain ends.
    delims = np.cumsum([len(s) for s in seq_dict.values()]).tolist()
    save_PDB(
        pdb_file,
        coords,
        full_seq,
        atoms=['N', 'CA', 'C', 'CB', 'O'],
        error=res_rmsd,
        delim=delims,
    )

    if do_refine:
        if use_openmm:
            from igfold.refine.openmm_ref import refine
        else:
            try:
                from igfold.refine.pyrosetta_ref import refine
            except ImportError as e:
                print(
                    "Warning: PyRosetta not available. Using OpenMM instead.")
                print(e)
                from igfold.refine.openmm_ref import refine

        refine(pdb_file)

    if do_renum:
        if use_abnum:
            from igfold.utils.pdb import renumber_pdb
        else:
            try:
                from igfold.utils.anarci_ import renumber_pdb
            except ImportError as e:
                print(
                    "Warning: ANARCI not available. Provide --use_abnum to renumber with the AbNum server."
                )
                print(e)

                def renumber_pdb(in_pdb_file, out_pdb_file):
                    """No-op stand-in used when ANARCI cannot be imported."""
                    return None

        renumber_pdb(
            pdb_file,
            pdb_file,
        )

    # Store the per-residue error estimate in the B-factor column.
    write_pdb_bfactor(
        pdb_file,
        pdb_file,
        bfactor=res_rmsd,
    )

    return model_out

def get_sequence_dict(
    sequences,
    pdb_file,
    fasta_file=None,
    ignore_cdrs=None,
    ignore_chain=None,
    template_pdb=None,
    save_decoys=True,
    ckpt_dir=r'D:\PDB蛋白质',
    num_models=4,
    try_gpu=True,
):
    """Fold an antibody with an ensemble of IgFold checkpoints and save the best.

    Despite its name, this runs the whole prediction pipeline:

    1. Resolve the chain sequences.
    2. Write a FASTA file if needed.
    3. Build the model input (with an optional template).
    4. Load up to ``num_models`` checkpoints.
    5. Predict with each one.
    6. Write the prediction whose 90th-percentile predicted RMSD is lowest to
       ``pdb_file``.

    Args:
        sequences: mapping of chain id (e.g. "H", "L") to amino-acid sequence.
            May be None when ``fasta_file`` is given.
        pdb_file: output path for the selected structure.
        fasta_file: optional FASTA input; takes precedence over ``sequences``.
            When omitted, the sequences are written to ``pdb_file`` with its
            extension replaced by ``.fasta``.
        ignore_cdrs, ignore_chain, template_pdb: forwarded to
            ``process_template``.
        save_decoys: also write every model's prediction to
            ``<pdb stem>.decoy{i}.pdb``.
        ckpt_dir: directory holding the downloaded ``*.ckpt`` weight files.
        num_models: maximum number of checkpoints to use, taken in sorted
            filename order.
        try_gpu: use ``cuda:0`` when CUDA is available.

    Returns:
        The selected model output (its ``prmsd`` reshaped to ``(b, l, 4)``).

    Raises:
        SystemExit: if neither ``sequences`` nor ``fasta_file`` is provided.
        FileNotFoundError: if no checkpoint matches ``ckpt_dir/*.ckpt``.
    """
    import copy

    if exists(fasta_file):
        if exists(sequences):
            print("Both sequences and fasta file provided. Using fasta file.")
        seq_dict = get_fasta_chain_dict(fasta_file)
    elif exists(sequences):
        seq_dict = sequences
    else:
        sys.exit("Must provide sequences or fasta file.")

    if not exists(fasta_file):
        # splitext, not str.replace: replace() would also rewrite ".pdb" inside
        # directory names. It would also return pdb_file unchanged when it has
        # no ".pdb", so the FASTA would later be overwritten by the PDB.
        fasta_file = os.path.splitext(pdb_file)[0] + ".fasta"
        with open(fasta_file, "w") as f:
            for chain, seq in seq_dict.items():
                f.write(">{}\n{}\n".format(
                    chain,
                    seq,
                ))

    temp_coords, temp_mask = process_template(
        template_pdb,
        fasta_file,
        ignore_cdrs=ignore_cdrs,
        ignore_chain=ignore_chain,
    )
    model_in = IgFoldInput(
        sequences=seq_dict.values(),
        template_coords=temp_coords,
        template_mask=temp_mask,
    )

    ckpt_pattern = os.path.join(ckpt_dir, "*.ckpt")
    model_ckpts = sorted(glob(ckpt_pattern))[:num_models]
    if not model_ckpts:
        raise FileNotFoundError(
            f"No IgFold checkpoints found matching {ckpt_pattern}")

    print(f"Loading {len(model_ckpts)} IgFold models...")

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and try_gpu else "cpu")
    print(f"Using device: {device}")

    models = []
    for ckpt_file in model_ckpts:
        print(f"Loading {ckpt_file}...")
        models.append(IgFold.load_from_checkpoint(ckpt_file).eval().to(device))

    print(f"Successfully loaded {len(models)} IgFold models.")

    model_outs, scores = [], []
    with torch.no_grad():
        for i, model in enumerate(models):
            model_out = model(model_in)
            print(model_out.coords.shape)
            if save_decoys:
                decoy_pdb_file = os.path.splitext(
                    pdb_file)[0] + f".decoy{i}.pdb"
                # process_prediction reshapes .prmsd in place. Hand it a
                # shallow copy so the final call below reshapes the original
                # flat tensor exactly once.
                process_prediction(
                    copy.copy(model_out),
                    decoy_pdb_file,
                    fasta_file,
                    do_refine=False,
                    use_openmm=False,
                    do_renum=False,
                    use_abnum=False,
                )

            # Lower 90th-percentile predicted RMSD = more confident model.
            scores.append(model_out.prmsd.quantile(0.9))
            model_outs.append(model_out)

    best_model_i = scores.index(min(scores))
    print(best_model_i)
    model_out = model_outs[best_model_i]
    print(model_out.coords.shape)
    process_prediction(
        model_out,
        pdb_file,
        fasta_file,
        skip_pdb=False,
        do_refine=False,
        use_openmm=False,
        do_renum=False,
        use_abnum=False,
    )

    return model_out

# Script entry point: fill in `sequences` (chain id -> amino-acid sequence)
# and `pdb_file` (where the predicted structure is saved), then run this file.
# The guard keeps model inference from running when the module is imported.
if __name__ == "__main__":
    get_sequence_dict(
        sequences={
            "H": "EVQLQQSGAEVVRSGASVKLSCTASGFNIKDYYIHWVKQRPEKGLEWIGWIDPEIGDTEYVPKFQGKATMTADTSSNTAYLQLSSLTSEDTAVYYCNAGHDYDRGRFPYWGQGTLVTVSAAKTTPPSVYPLAPGSAAQTNSMVTLGCLVKGYFPEPVTVTWNSGSLSSGVHTFPAVLQSDLYTLSSSVTVPSSTWPSETVTCNVAHPASSTKVDKKIVPRD",
            "L": "DIVMTQSQKFMSTSVGDRVSITCKASQNVGTAVAWYQQKPGQSPKLMIYSASNRYTGVPDRFTGSGSGTDFTLTISNMQSEDLADYFCQQYSSYPLTFGAGTKLELKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYPKDINVKWKIDGSERQNGVLNSATDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNEC",
        },
        pdb_file=r'D:\PDB蛋白质\test.pdb',
    )

你可能感兴趣的:(AI制药,深度学习,pytorch,python)