2021SC@SDUSC
Complete code for the DGraphDTA training stage:
model = GNNNet()
model.to(device)
model_st = GNNNet.__name__
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# set the loss function
loss_fn = nn.MSELoss()
# set the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# dataset and fold index are selected from the command-line arguments (see below)
for dataset in datasets:
    train_data, valid_data = create_dataset_for_5folds(dataset, fold)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                               collate_fn=collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False,
                                               collate_fn=collate)

    best_mse = 1000
    best_test_mse = 1000
    best_epoch = -1
    model_file_name = 'models/model_' + model_st + '_' + dataset + '_' + str(fold) + '.model'

    for epoch in range(NUM_EPOCHS):
        train(model, device, train_loader, optimizer, epoch + 1)
        print('predicting for valid data')
        G, P = predicting(model, device, valid_loader)
        val = get_mse(G, P)
        print('valid result:', val, best_mse)
        if val < best_mse:
            best_mse = val
            best_epoch = epoch + 1
            # checkpoint only when the validation MSE improves
            torch.save(model.state_dict(), model_file_name)
            print('rmse improved at epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
        else:
            print('No improvement since epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
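The loop above calls three helpers that this excerpt does not show: train, predicting, and get_mse. For reference, here is a minimal sketch of what they might look like; it assumes (as in DGraphDTA) that collate yields a (molecule graph batch, protein graph batch) pair and that the molecule batch carries the affinity labels in .y, so treat it as illustrative rather than the post's actual code:

def train(model, device, train_loader, optimizer, epoch):
    # one pass over the training set; loss_fn is the nn.MSELoss() defined above
    model.train()
    for data_mol, data_pro in train_loader:
        data_mol, data_pro = data_mol.to(device), data_pro.to(device)
        optimizer.zero_grad()
        output = model(data_mol, data_pro)                      # predicted affinities
        loss = loss_fn(output, data_mol.y.view(-1, 1).float())  # MSE against the labels
        loss.backward()
        optimizer.step()

def predicting(model, device, loader):
    # returns (ground truth G, predictions P) as flat numpy arrays
    model.eval()
    total_labels, total_preds = torch.Tensor(), torch.Tensor()
    with torch.no_grad():
        for data_mol, data_pro in loader:
            output = model(data_mol.to(device), data_pro.to(device))
            total_preds = torch.cat((total_preds, output.cpu()), 0)
            total_labels = torch.cat((total_labels, data_mol.y.view(-1, 1).cpu()), 0)
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()

def get_mse(y_true, y_pred):
    # mean squared error between two numpy arrays
    return ((y_true - y_pred) ** 2).mean()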
Initialize the model and move it to the device (CUDA in this run):
model = GNNNet()
model.to(device)
model_st = GNNNet.__name__
Set the loss function to mean squared error and the optimizer to Adam:
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)  # arguments: the model's parameters and the learning rate
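The capitalized constants used throughout (TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, LR, NUM_EPOCHS) are defined at the top of the training script. The values below are illustrative placeholders only; check training_5folds.py in the DGraphDTA repository for the settings actually used:

# Illustrative hyperparameter values (assumed, not taken from this post)
TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.001         # learning rate handed to Adam
NUM_EPOCHS = 2000  # upper bound; the best checkpoint is kept by the best-MSE tracking above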
This part splits the dataset into a training set and a validation set and wraps them in DataLoaders; the DataLoader internals are not examined in depth here, but create_dataset_for_5folds is walked through in the data-loading section below.
for dataset in datasets:
    train_data, valid_data = create_dataset_for_5folds(dataset, fold)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                               collate_fn=collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False,
                                               collate_fn=collate)
A note on the loop above: datasets holds either davis or kiba, chosen by a command-line argument (see the snippet below), so the loop splits and wraps whichever dataset was selected.
# select the dataset from the command line: 0 -> davis, 1 -> kiba
datasets = [['davis', 'kiba'][int(sys.argv[1])]]
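The indexing trick is worth spelling out: ['davis', 'kiba'][int(sys.argv[1])] picks one name from the inner list, and the outer brackets wrap it back into a one-element list so the for dataset in datasets loop still works. A tiny standalone demonstration (with a simulated, hypothetical command line):

import sys

# hypothetical call: python training_5folds.py 1
sys.argv = ['training_5folds.py', '1']           # simulate the command line
datasets = [['davis', 'kiba'][int(sys.argv[1])]]
print(datasets)                                  # ['kiba'] -- a one-element list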
The data-loading part:
def create_dataset_for_5folds(dataset, fold=0):
    # load dataset
    dataset_path = 'data/' + dataset + '/'  # dataset root directory
    train_fold_origin = json.load(open(dataset_path + 'folds/train_fold_setting1.txt'))  # the txt contains JSON, so json.load works
    train_fold_origin = [e for e in train_fold_origin]  # for 5 folds; materialize as a list
    ligands = json.load(open(dataset_path + 'ligands_can.txt'), object_pairs_hook=OrderedDict)  # load the ligand SMILES strings
    proteins = json.load(open(dataset_path + 'proteins.txt'), object_pairs_hook=OrderedDict)  # load the protein FASTA sequences

    # load contact maps and alignment (aln) files
    msa_path = 'data/' + dataset + '/aln'
    contac_path = 'data/' + dataset + '/pconsc4'
    msa_list = []
    contact_list = []
    # for each key in the protein dict, look up the matching .aln and contact files and append
    # them to msa_list / contact_list so they stay aligned one-to-one with the FASTA data
    for key in proteins:
        msa_list.append(os.path.join(msa_path, key + '.aln'))
        contact_list.append(os.path.join(contac_path, key + '.npy'))

    # load train, valid and test entries
    train_folds = []
    valid_fold = train_fold_origin[fold]  # one fold
    for i in range(len(train_fold_origin)):  # other folds
        if i != fold:
            train_folds += train_fold_origin[i]

    affinity = pickle.load(open(dataset_path + 'Y', 'rb'), encoding='latin1')
    drugs = []
    prots = []
    prot_keys = []
    drug_smiles = []
    # smiles
    for d in ligands.keys():
        lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]), isomericSmiles=True)  # canonicalize via an RDKit round-trip
        drugs.append(lg)
        drug_smiles.append(ligands[d])
    # seqs
    for t in proteins.keys():
        prots.append(proteins[t])
        prot_keys.append(t)
    if dataset == 'davis':
        affinity = [-np.log10(y / 1e9) for y in affinity]  # convert Kd in nM to pKd
    affinity = np.asarray(affinity)

    opts = ['train', 'valid']
    valid_train_count = 0
    valid_valid_count = 0
    for opt in opts:
        if opt == 'train':
            rows, cols = np.where(np.isnan(affinity) == False)
            rows, cols = rows[train_folds], cols[train_folds]
            train_fold_entries = []
            for pair_ind in range(len(rows)):
                if not valid_target(prot_keys[cols[pair_ind]], dataset):  # ensure the contact and aln files exist
                    continue
                ls = []
                ls += [drugs[rows[pair_ind]]]
                ls += [prots[cols[pair_ind]]]
                ls += [prot_keys[cols[pair_ind]]]
                ls += [affinity[rows[pair_ind], cols[pair_ind]]]
                train_fold_entries.append(ls)
                valid_train_count += 1

            csv_file = 'data/' + dataset + '_' + 'fold_' + str(fold) + '_' + opt + '.csv'
            data_to_csv(csv_file, train_fold_entries)
        elif opt == 'valid':
            rows, cols = np.where(np.isnan(affinity) == False)
            rows, cols = rows[valid_fold], cols[valid_fold]
            valid_fold_entries = []
            for pair_ind in range(len(rows)):
                if not valid_target(prot_keys[cols[pair_ind]], dataset):
                    continue
                ls = []
                ls += [drugs[rows[pair_ind]]]
                ls += [prots[cols[pair_ind]]]
                ls += [prot_keys[cols[pair_ind]]]
                ls += [affinity[rows[pair_ind], cols[pair_ind]]]
                valid_fold_entries.append(ls)
                valid_valid_count += 1

            csv_file = 'data/' + dataset + '_' + 'fold_' + str(fold) + '_' + opt + '.csv'
            data_to_csv(csv_file, valid_fold_entries)

    print('dataset:', dataset)
    # print('len(set(drugs)),len(set(prots)):', len(set(drugs)), len(set(prots)))
    # entries with protein contact and aln files are marked as effective
    print('fold:', fold)
    print('train entries:', len(train_folds), 'effective train entries', valid_train_count)
    print('valid entries:', len(valid_fold), 'effective valid entries', valid_valid_count)

    compound_iso_smiles = drugs
    target_key = prot_keys

    # create smile graph
    smile_graph = {}
    for smile in compound_iso_smiles:
        g = smile_to_graph(smile)
        smile_graph[smile] = g
    # print(smile_graph['CN1CCN(C(=O)c2cc3cc(Cl)ccc3[nH]2)CC1'])  # for test

    # create target graph
    # print('target_key', len(target_key), len(set(target_key)))
    target_graph = {}
    for key in target_key:
        if not valid_target(key, dataset):  # ensure the contact and aln files exist
            continue
        g = target_to_graph(key, proteins[key], contac_path, msa_path)
        target_graph[key] = g

    # count the number of proteins with aln and contact files
    print('effective drugs,effective prot:', len(smile_graph), len(target_graph))
    if len(smile_graph) == 0 or len(target_graph) == 0:
        raise Exception('no protein or drug, run the script for datasets preparation.')

    # 'data/davis_fold_0_train.csv' or 'data/kiba_fold_0_train.csv'
    train_csv = 'data/' + dataset + '_' + 'fold_' + str(fold) + '_' + 'train' + '.csv'
    df_train_fold = pd.read_csv(train_csv)
    train_drugs, train_prot_keys, train_Y = list(df_train_fold['compound_iso_smiles']), list(
        df_train_fold['target_key']), list(df_train_fold['affinity'])
    train_drugs, train_prot_keys, train_Y = np.asarray(train_drugs), np.asarray(train_prot_keys), np.asarray(train_Y)
    train_dataset = DTADataset(root='data', dataset=dataset + '_' + 'train', xd=train_drugs, target_key=train_prot_keys,
                               y=train_Y, smile_graph=smile_graph, target_graph=target_graph)

    df_valid_fold = pd.read_csv('data/' + dataset + '_' + 'fold_' + str(fold) + '_' + 'valid' + '.csv')
    valid_drugs, valid_prots_keys, valid_Y = list(df_valid_fold['compound_iso_smiles']), list(
        df_valid_fold['target_key']), list(df_valid_fold['affinity'])
    valid_drugs, valid_prots_keys, valid_Y = np.asarray(valid_drugs), np.asarray(valid_prots_keys), np.asarray(
        valid_Y)
    valid_dataset = DTADataset(root='data', dataset=dataset + '_' + 'train', xd=valid_drugs,
                               target_key=valid_prots_keys, y=valid_Y, smile_graph=smile_graph,
                               target_graph=target_graph)

    return train_dataset, valid_dataset
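One branch worth a closer look: for davis the raw affinity matrix Y holds Kd values in nanomolar, and the list comprehension converts them to pKd via -log10(Kd / 1e9). A quick worked example with an illustrative value:

import numpy as np

kd_nM = 10000.0               # an illustrative Kd of 10 μM (10000 nM)
pkd = -np.log10(kd_nM / 1e9)  # nM -> M, then take the negative log10
print(pkd)                    # 5.0 -- larger pKd means stronger binding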
Searching train_fold_setting1.txt for ']' (Ctrl+F) shows that it is a two-dimensional array made of five sub-arrays (5×n). Each number in a sub-array is the index of a protein-drug pair, and through that index the corresponding drug-target binding entry can be located.
This makes the code below clear: valid_fold takes the sub-array at index fold as the validation set, and the loop appends the remaining four sub-arrays to the training set.
train_folds = []
valid_fold = train_fold_origin[fold]  # one fold
for i in range(len(train_fold_origin)):  # other folds
    if i != fold:
        train_folds += train_fold_origin[i]
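A toy run makes the split concrete. With made-up pair indices and fold = 0, the first sub-array becomes the validation fold and the other four are concatenated into the training fold:

# made-up fold file contents: five sub-arrays of pair indices
train_fold_origin = [[0, 7], [3, 11], [1, 9], [2, 8], [4, 5]]
fold = 0

train_folds = []
valid_fold = train_fold_origin[fold]      # the held-out fold
for i in range(len(train_fold_origin)):   # the remaining folds
    if i != fold:
        train_folds += train_fold_origin[i]

print(valid_fold)   # [0, 7]
print(train_folds)  # [3, 11, 1, 9, 2, 8, 4, 5]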