2021SC@SDUSC
Last time we analyzed how protein features are obtained from contact maps; this time we look at the training of the DTA model as a whole.
The code is on GitHub: https://github.com/595693085/DGraphDTA
First, import the data-generation utilities and the GNN network built earlier (together with the standard-library / PyTorch / NumPy imports that the rest of the script relies on):
import json
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from torch_geometric.data import DataLoader
from gnn import GNNNet
from utils import *
from emetrics import *
from data_process import create_dataset_for_5folds
Next we choose the training dataset, the CUDA device, and which of the 5 folds serves as the validation set (the remaining four folds are used for training). Training five times this way, once per fold, and averaging the results protects against an unluckily chosen validation split distorting the evaluation.
# datasets = [['davis', 'kiba'][int(sys.argv[1])]]
datasets = ['davis']
print('datasets:', datasets)

# cuda_name = ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'][int(sys.argv[2])]
cuda_name = 'cuda:0'
print('cuda_name:', cuda_name)

# fold = [0, 1, 2, 3, 4][int(sys.argv[3])]
fold = 0
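For reference, the commented-out lines show how the original script reads these three settings from the command line, so that each dataset / GPU / fold combination can be launched as its own run. With those lines enabled, an invocation would look roughly like the sketch below (the script name is my assumption, not taken from the post):

import sys

# settings taken from the command line instead of being hard-coded
datasets = [['davis', 'kiba'][int(sys.argv[1])]]
cuda_name = ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'][int(sys.argv[2])]
fold = [0, 1, 2, 3, 4][int(sys.argv[3])]

# e.g. train on davis, on GPU 0, with fold 2 held out for validation:
#   python training_5folds.py 0 0 2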
Then the batch sizes, learning rate, and number of training epochs are specified:
TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.001
NUM_EPOCHS = 20
The protein-sequence and molecule-SMILES files are then turned into dataset objects and wrapped in DataLoaders for batch processing. This is the key step; it is analyzed in detail separately below.
for dataset in datasets:
    train_data, valid_data = create_dataset_for_5folds(dataset, fold)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                               collate_fn=collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False,
                                               collate_fn=collate)
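Each sample here is a pair of graphs (one for the drug molecule, one for the protein), so the custom collate passed as collate_fn cannot be PyTorch's default. A minimal sketch of what such a collate function can look like, assuming every item is a (drug_graph, protein_graph) tuple of torch_geometric Data objects (the actual implementation lives in utils.py and may differ):

from torch_geometric.data import Batch

def collate(data_list):
    # Batch.from_data_list merges many small graphs into one big disconnected graph,
    # keeping a `batch` vector that remembers which nodes belong to which sample.
    batch_mol = Batch.from_data_list([pair[0] for pair in data_list])  # drug graphs
    batch_pro = Batch.from_data_list([pair[1] for pair in data_list])  # protein graphs
    return batch_mol, batch_pro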
The training loop, which prints the epochs where the validation loss improves and keeps track of the lowest MSE, is shown in full further below, after the model has been initialized. Before that, let us dig into how the dataset is built.
We now look at the create_dataset_for_5folds method in detail. It starts by reading a file called train_fold_setting1.txt:
def create_dataset_for_5folds(dataset, fold=0):
    # load dataset
    dataset_path = 'data/' + dataset + '/'
    # TODO (question 3): what exactly is train_fold_setting1.txt for? Its entries
    # do not match the keys in proteins.txt.
    train_fold_origin = json.load(open(dataset_path + 'folds/train_fold_setting1.txt'))
    train_fold_origin = [e for e in train_fold_origin]  # for 5 folds
The file is simply a JSON list containing five lists of integers, which answers the question above: they are the indices of the drug-target (DT) pairs in Y (the affinity file) that have a measured affinity, and they are used later to filter the pairs belonging to the training folds.
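From these five index lists, the fold passed in as the fold argument becomes the validation set and the other four are concatenated into the training indices. A minimal sketch of that step (the variable names follow how they are used in the snippets below):

valid_fold = train_fold_origin[fold]   # the held-out fold used for validation
train_folds = []                       # the remaining four folds, flattened
for i in range(len(train_fold_origin)):
    if i != fold:
        train_folds += train_fold_origin[i]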
Next, the ligand SMILES and the protein sequences are loaded:
ligands = json.load(open(dataset_path + 'ligands_can.txt'), object_pairs_hook=OrderedDict)
proteins = json.load(open(dataset_path + 'proteins.txt'), object_pairs_hook=OrderedDict)
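The loaded SMILES are then converted into canonical isomeric SMILES with RDKit, and the protein keys and sequences are collected into plain lists. A sketch of that step, assuming RDKit is installed (the list names match how they are used further below):

from rdkit import Chem

drugs, prots, prot_keys = [], [], []
for d in ligands.keys():
    # round-trip through RDKit to get a canonical, isomeric SMILES string
    lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]), isomericSmiles=True)
    drugs.append(lg)
for t in proteins.keys():
    prot_keys.append(t)
    prots.append(proteins[t])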
Y (the affinity matrix) is then used to pick out the index pairs that both have a measured affinity and fall into the selected training folds:
rows, cols = np.where(np.isnan(affinity) == False)  # all (drug, target) positions with a measured label
print(len(train_folds))
# train_folds holds positions into this flattened list of labelled pairs, so the fancy
# indexing below keeps exactly the non-NaN pairs that belong to the four training folds;
# rows and cols then have length len(train_folds) (20036 for davis, fold 0)
rows, cols = rows[train_folds], cols[train_folds]
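To make this indexing step concrete, here is a tiny self-contained example with a 2x3 affinity matrix and an invented fold of indices (the numbers are purely illustrative):

import numpy as np

affinity = np.array([[5.0, np.nan, 7.0],
                     [np.nan, 6.0, 8.0]])
rows, cols = np.where(np.isnan(affinity) == False)
# rows = [0 0 1 1], cols = [0 2 1 2]  -> the labelled pairs in row-major order

train_folds = [0, 2, 3]                           # pretend these pairs are in the training folds
rows, cols = rows[train_folds], cols[train_folds]
# rows = [0 1 1], cols = [0 1 2]  -> only the selected labelled pairs remain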
For each selected pair, a list ls is built holding the drug SMILES, the protein sequence, the protein key, and the affinity value; the rows are collected and written out to a .csv file:
train_fold_entries = []
valid_train_count = 0
for pair_ind in range(len(rows)):
    if not valid_target(prot_keys[cols[pair_ind]], dataset):  # ensure the contact and aln files exist
        continue
    ls = []
    ls += [drugs[rows[pair_ind]]]                       # compound SMILES
    ls += [prots[cols[pair_ind]]]                       # protein sequence
    ls += [prot_keys[cols[pair_ind]]]                   # protein key
    ls += [affinity[rows[pair_ind], cols[pair_ind]]]    # measured affinity
    train_fold_entries.append(ls)
    valid_train_count += 1

# opt names the split ('train' here, 'valid' for the held-out fold in the full function)
csv_file = 'data/' + dataset + '_' + 'fold_' + str(fold) + '_' + opt + '.csv'
data_to_csv(csv_file, train_fold_entries)
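data_to_csv itself is defined elsewhere in the repository; judging only from how it is called here, a minimal version could look like the sketch below (the header names are my assumption, not taken from the source):

def data_to_csv(csv_file, datalist):
    # one (smiles, sequence, key, affinity) row per entry
    with open(csv_file, 'w') as f:
        f.write('compound_iso_smiles,target_sequence,target_key,affinity\n')  # assumed header
        for row in datalist:
            f.write(','.join(map(str, row)) + '\n')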
Next, the methods analyzed in earlier posts are called to build the molecule graphs and the protein graphs, i.e. to obtain the node features and edges for each:
# create smile graph
smile_graph = {}
for smile in compound_iso_smiles:
    g = smile_to_graph(smile)
    smile_graph[smile] = g
# print(smile_graph['CN1CCN(C(=O)c2cc3cc(Cl)ccc3[nH]2)CC1'])  # for test
# create target graph
# print('target_key', len(target_key), len(set(target_key)))
target_graph = {}
for key in target_key:
    if not valid_target(key, dataset):  # ensure the contact and aln files exist
        continue
    # target_to_graph returns the sequence length, the node features and the edge list;
    # they come back as a single tuple, so one variable g is enough to receive them
    g = target_to_graph(key, proteins[key], contac_path, msa_path)
    target_graph[key] = g
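In other words, each value stored in smile_graph and target_graph is just a tuple, unpacked later when the PyTorch Geometric Data objects are built. A small illustration, assuming the return orders (size, features, edge_index) described in the earlier posts:

for smile, (c_size, features, edge_index) in smile_graph.items():
    print(smile, c_size)              # e.g. the number of atoms in this molecule
    break
for key, (target_size, target_features, target_edge_index) in target_graph.items():
    print(key, target_size)           # e.g. the number of residues in this protein
    break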
Back in the training script, the graph network is initialized:
result_str = ''
USE_CUDA = torch.cuda.is_available()
device = torch.device(cuda_name if USE_CUDA else 'cpu')
model = GNNNet()
model.to(device)
model_st = GNNNet.__name__
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
Training can now begin; the goal is a GNN that can extract protein and drug features and predict the binding affinity from them:
# best_mse, best_epoch and model_file_name are initialized earlier in the script
for epoch in range(NUM_EPOCHS):
    # train on the training folds, then predict on the held-out validation fold
    train(model, device, train_loader, optimizer, epoch + 1)
    print('predicting for valid data')
    G, P = predicting(model, device, valid_loader)  # G = ground-truth affinities, P = predictions
    val = get_mse(G, P)
    print('valid result:', val, best_mse)
    if val < best_mse:
        best_mse = val
        best_epoch = epoch + 1
        torch.save(model.state_dict(), model_file_name)
        print('rmse improved at epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
    else:
        print('No improvement since epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
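train and predicting are the usual PyTorch helper loops (they are defined in the training script rather than shown in this post). The sketch below captures what they do under the assumption that the model takes a molecule batch and a protein batch and that the labels travel on the molecule batch; it is not the repository's exact code:

def train(model, device, train_loader, optimizer, epoch):
    # one pass over the training folds, minimizing MSE between predicted and true affinity
    model.train()
    for mol_batch, pro_batch in train_loader:
        mol_batch, pro_batch = mol_batch.to(device), pro_batch.to(device)
        optimizer.zero_grad()
        output = model(mol_batch, pro_batch)
        loss = loss_fn(output, mol_batch.y.view(-1, 1).float().to(device))  # loss_fn = nn.MSELoss() above
        loss.backward()
        optimizer.step()

def predicting(model, device, loader):
    # collect ground-truth labels (G) and model predictions (P) over a whole loader
    model.eval()
    total_labels, total_preds = torch.Tensor(), torch.Tensor()
    with torch.no_grad():
        for mol_batch, pro_batch in loader:
            mol_batch, pro_batch = mol_batch.to(device), pro_batch.to(device)
            output = model(mol_batch, pro_batch)
            total_preds = torch.cat((total_preds, output.cpu()), 0)
            total_labels = torch.cat((total_labels, mol_batch.y.view(-1, 1).cpu()), 0)
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()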
That is all for this post. If you find any problems, feel free to point them out in the comments!