2021SC@SDUSC Software Engineering Application and Practice 11: DGraphDTA Code Analysis (Part 2)

2021SC@SDUSC

1. Preface

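  The previous post mainly analyzed how protein features are obtained from the contact map. This time we look at the training of the DTA model as a whole.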

2. Source code

  On GitHub: https://github.com/595693085/DGraphDTA

3. Source code analysis

  Import the dataset-creation function and the GNN network built earlier

import torch
import torch.nn as nn
from torch_geometric.data import DataLoader

from gnn import GNNNet
from utils import *
from emetrics import *

from data_process import create_dataset_for_5folds

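  Choose the dataset to train on, the CUDA device, and which of the 5 folds serves as the validation set, with the remaining folds used as the training set (training five times this way and averaging the results reduces the effect of an unlucky choice of validation set)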

# datasets = [['davis', 'kiba'][int(sys.argv[1])]]
datasets = ['davis']
print('datasets:',datasets)
# cuda_name = ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'][int(sys.argv[2])]
cuda_name = 'cuda:0'
print('cuda_name:', cuda_name)
# fold = [0, 1, 2, 3, 4][int(sys.argv[3])]
fold = 0
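The fold selection described above can be sketched as follows. This is only a minimal illustration with made-up indices: `split_folds` and `toy_folds` are hypothetical names for this example and are not part of DGraphDTA.

```python
def split_folds(folds, fold):
    """Use folds[fold] for validation and every other fold for training."""
    valid = list(folds[fold])
    train = [idx for k, part in enumerate(folds) if k != fold for idx in part]
    return train, valid


# Five toy folds holding sample indices (made-up data for illustration).
toy_folds = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

train_idx, valid_idx = split_folds(toy_folds, 0)
print("train:", train_idx)
print("valid:", valid_idx)

# Running all five folds validates on every sample exactly once.
all_valid = sorted(i for f in range(5) for i in split_folds(toy_folds, f)[1])
print("validated once each:", all_valid)
```

Changing the `fold` argument from 0 to 4 gives the five train/validation splits whose results would then be averaged.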

  Specify the learning rate, batch sizes, and number of epochs

TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.001
NUM_EPOCHS = 20

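  Build the data from the given protein sequence and molecule SMILES files, and wrap it in DataLoaders (for convenient batching). This is the key step, and it is analyzed in detail further down.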

for dataset in datasets:
    train_data, valid_data = create_dataset_for_5folds(dataset, fold)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                               collate_fn=collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False,
                                               collate_fn=collate)

  Train the model, printing the epochs at which the validation result improves and the lowest loss so far

    for epoch in range(NUM_EPOCHS):
        # TODO: of the resulting train and valid datasets, one is used for training and the other for prediction
        train(model, device, train_loader, optimizer, epoch + 1)
        print('predicting for valid data')
        G, P = predicting(model, device, valid_loader)
        val = get_mse(G, P)
        print('valid result:', val, best_mse)
        if val < best_mse:
            best_mse = val
            best_epoch = epoch + 1
            torch.save(model.state_dict(), model_file_name)
            print('rmse improved at epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
        else:
            print('No improvement since epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
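The "keep the best epoch" logic of this loop can be isolated in a small sketch. It assumes `best_mse` starts at infinity (the initial value is not shown in the excerpt above), and `mse`, `track_best`, and the numbers in `toy_history` are made up for illustration rather than taken from a real run.

```python
def mse(targets, predictions):
    """Mean squared error between two equal-length sequences."""
    return sum((t - p) ** 2 for t, p in zip(targets, predictions)) / len(targets)


def track_best(valid_mse_per_epoch):
    """Return (best_mse, best_epoch), with epochs counted from 1."""
    best_mse, best_epoch = float("inf"), -1  # assumed starting values
    for epoch, val in enumerate(valid_mse_per_epoch, start=1):
        if val < best_mse:
            # The training script saves the model weights at this point.
            best_mse, best_epoch = val, epoch
    return best_mse, best_epoch


toy_history = [0.9, 0.5, 0.7, 0.4, 0.6]  # made-up validation MSE per epoch
best_mse, best_epoch = track_best(toy_history)
print(best_mse, best_epoch)
```

Because the weights are saved only when the validation MSE improves, the file on disk always corresponds to the best epoch seen so far, not the last one.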

  Next, let's look at the create_dataset_for_5folds method in detail. It reads a file named train_fold_setting1.txt.

def create_dataset_for_5folds(dataset, fold=0):
    # load dataset
    dataset_path = 'data/' + dataset + '/'
    # TODO question 3: the role of train_fold_setting1.txt is unclear; the sequences in it do not match the keys in proteins.txt
    train_fold_origin = json.load(open(dataset_path + 'folds/train_fold_setting1.txt'))
    train_fold_origin = [e for e in train_fold_origin]  # for 5 folds

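The file contents are shown below.

It actually holds indices into the drug-target pairs in Y (the affinity file) that have an affinity value, and these indices are used later for filtering.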

[Image 1: contents of train_fold_setting1.txt]

Here the molecule sequences and protein sequences are loaded

    ligands = json.load(open(dataset_path + 'ligands_can.txt'), object_pairs_hook=OrderedDict)
    proteins = json.load(open(dataset_path + 'proteins.txt'), object_pairs_hook=OrderedDict)

Using Y, filter out the index pairs that have an affinity value and belong to the selected training folds

            rows, cols = np.where(np.isnan(affinity) == False)
            # rows and cols have length 20036
            print(len(train_folds))
            rows, cols = rows[train_folds], cols[train_folds]  # TODO: this should select the entries that are non-NaN and in train_folds?
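The TODO above can be checked on a tiny example. The sketch below reproduces the same two steps with plain Python lists instead of NumPy so that it runs without extra packages; `non_nan_pairs` is a hypothetical helper, and the affinity values and fold indices are made up.

```python
import math


def non_nan_pairs(affinity):
    """List the (row, col) positions of non-NaN entries in row-major order,
    which is the order np.where returns for a 2-D boolean mask."""
    rows, cols = [], []
    for r, row in enumerate(affinity):
        for c, value in enumerate(row):
            if not math.isnan(value):
                rows.append(r)
                cols.append(c)
    return rows, cols


nan = float("nan")
# Toy affinity matrix: rows are drugs, columns are proteins.
affinity = [[5.0, nan, 7.2],
            [nan, 6.1, nan]]

rows, cols = non_nan_pairs(affinity)  # the three measured pairs
train_folds = [0, 2]                  # positions within the non-NaN list

sel_rows = [rows[i] for i in train_folds]
sel_cols = [cols[i] for i in train_folds]
print(list(zip(sel_rows, sel_cols)))
```

So the fold file stores positions within the list of non-NaN pairs, and indexing `rows` and `cols` with it does select entries that are both non-NaN and in the chosen folds.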

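Define a list ls that collects, for each pair, the drug, the protein sequence, the protein key, and the affinity, then save all entries to a .csv file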

            for pair_ind in range(len(rows)):
                if not valid_target(prot_keys[cols[pair_ind]], dataset):  # ensure the contact and aln files exists
                    continue
                ls = []
                a = rows[pair_ind]
                b = len(drugs)
                ls += [drugs[rows[pair_ind]]]
                ls += [prots[cols[pair_ind]]]
                ls += [prot_keys[cols[pair_ind]]]
                ls += [affinity[rows[pair_ind], cols[pair_ind]]]
                train_fold_entries.append(ls)
                valid_train_count += 1
            csv_file = 'data/' + dataset + '_' + 'fold_' + str(fold) + '_' + opt + '.csv'
            data_to_csv(csv_file, train_fold_entries)
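As a rough illustration of what writing these entries out looks like, here is a sketch using the standard csv module. It is not the repository's data_to_csv: `entries_to_csv_text`, the header names, and the toy rows are all assumptions made for this example.

```python
import csv
import io


def entries_to_csv_text(entries):
    """Serialize [drug, protein_sequence, protein_key, affinity] rows to CSV text."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    # Header names are assumptions chosen for this sketch.
    writer.writerow(["drug", "protein_sequence", "protein_key", "affinity"])
    writer.writerows(entries)
    return buffer.getvalue()


# Made-up entries in the same order as the ls list built above.
toy_entries = [
    ["CCO", "MKV", "P1", 5.0],
    ["CCN", "MLA", "P2", 7.2],
]
csv_text = entries_to_csv_text(toy_entries)
print(csv_text)
```

Writing to an in-memory buffer keeps the example free of file paths; passing an opened file to `csv.writer` instead would produce the same rows on disk.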

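Call the methods analyzed in the previous posts to build the molecule graphs and protein graphs and obtain the protein node features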

    # create smile graph
    smile_graph = {}
    for smile in compound_iso_smiles:
        g = smile_to_graph(smile)
        smile_graph[smile] = g
    # print(smile_graph['CN1CCN(C(=O)c2cc3cc(Cl)ccc3[nH]2)CC1']) #for test

    # create target graph
    # print('target_key', len(target_key), len(set(target_key)))
    target_graph = {}
    for key in target_key:
        if not valid_target(key, dataset):  # ensure the contact and aln files exists
            continue
        # TODO: target_to_graph returns the sequence length, the features, and the edge set; can a single g receive all of them?
        g = target_to_graph(key, proteins[key], contac_path, msa_path)
        target_graph[key] = g
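On the TODO about a single g: in Python, a function that returns several values returns one tuple, so one variable can hold all of them and they can be unpacked later. The sketch below uses `toy_target_to_graph`, a made-up stand-in whose features and edges are placeholders, not the real target_to_graph.

```python
def toy_target_to_graph(sequence):
    """Stand-in that returns (size, features, edges) for a sequence."""
    size = len(sequence)
    features = [[ord(ch)] for ch in sequence]         # placeholder node features
    edges = [[i, i + 1] for i in range(size - 1)]     # placeholder chain edges
    return size, features, edges


# A single name receives the whole tuple ...
g = toy_target_to_graph("MKV")
toy_graphs = {"P1": g}

# ... and it can be unpacked wherever the parts are needed.
size, features, edges = toy_graphs["P1"]
print(type(g).__name__, size, edges)
```

Storing the tuple in the dictionary and unpacking it at the point of use keeps the graph-building loop short.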

Initialize the graph network

result_str = ''
USE_CUDA = torch.cuda.is_available()
device = torch.device(cuda_name if USE_CUDA else 'cpu')
model = GNNNet()
model.to(device)
model_st = GNNNet.__name__
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

Start training: the goal is to obtain a GNN that can extract protein and drug features

    for epoch in range(NUM_EPOCHS):
        # TODO: of the resulting train and valid datasets, one is used for training and the other for prediction
        train(model, device, train_loader, optimizer, epoch + 1)
        print('predicting for valid data')
        G, P = predicting(model, device, valid_loader)
        val = get_mse(G, P)
        print('valid result:', val, best_mse)
        if val < best_mse:
            best_mse = val
            best_epoch = epoch + 1
            torch.save(model.state_dict(), model_file_name)
            print('rmse improved at epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
        else:
            print('No improvement since epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)

4. Summary

  That is all for this post. If you find any problems, corrections are welcome in the comments!
