Disease-Gene Interactions with Graph Neural Networks and Graph Autoencoders:
Learning Graph Autoencoders
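The convolution layers in the PyG library come in many variants, but every layer is built on three core steps: message passing, aggregation, and update.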
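The three steps can be sketched in plain PyTorch. This is only an illustration under simplifying assumptions: the function name `message_passing_layer` is made up for this example, and it uses mean aggregation rather than the symmetric normalization that GCN applies.

```python
import torch

def message_passing_layer(x, edge_index, weight):
    """One simplified GNN layer (mean aggregation, not GCN's exact normalization)."""
    src, dst = edge_index
    # 1. Message: every source node sends its linearly transformed features.
    messages = x[src] @ weight
    # 2. Aggregate: sum incoming messages per destination node, then divide by in-degree.
    summed = torch.zeros(x.size(0), weight.size(1)).index_add_(0, dst, messages)
    in_degree = torch.zeros(x.size(0)).index_add_(0, dst, torch.ones(dst.size(0)))
    aggregated = summed / in_degree.clamp(min=1).unsqueeze(1)
    # 3. Update: apply a non-linearity to obtain the new node representation.
    return torch.relu(aggregated)

# Toy graph: nodes 0 and 2 both point to node 1.
x = torch.eye(3)
edge_index = torch.tensor([[0, 2], [1, 1]])
out = message_passing_layer(x, edge_index, torch.eye(3))
print(out)
```

Node 1 ends up with the average of the features of nodes 0 and 2, while nodes without incoming edges receive a zero vector.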
In pytorch_geometric, a GCN layer can be built with a single line of code:
from torch_geometric.nn import GCNConv
conv = GCNConv(in_channels, out_channels)
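in_channels and out_channels are the sizes of the input and output node representations, respectively. In general, in_channels = X.shape[1], where X is the node feature matrix.
Although GCN is one of the simplest GNNs, it works well in practice, and GCN variants often rank near the top of graph benchmarks; see the leaderboards of the Open Graph Benchmark (OGB). Plain GCN itself probably appears on relatively few of those datasets.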
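First, some background on autoencoders (AE) from machine learning. An autoencoder has two main parts: an encoder and a decoder. Roughly speaking, the AE encodes a high-dimensional vector X into a low-dimensional latent variable h, and the decoder then maps h back to the original dimension. Ideally the decoder recovers the original input approximately or even exactly, which forces the encoder to learn the most informative features.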
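A minimal sketch of this encode/decode idea in PyTorch follows. The class name `TinyAutoEncoder`, the layer sizes, and the random input are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAutoEncoder(nn.Module):
    """Illustrative autoencoder: 8-dimensional input, 2-dimensional latent code."""
    def __init__(self, in_dim=8, latent_dim=2):
        super().__init__()
        self.encoder = nn.Linear(in_dim, latent_dim)   # X -> h
        self.decoder = nn.Linear(latent_dim, in_dim)   # h -> reconstruction of X

    def forward(self, x):
        h = torch.tanh(self.encoder(x))
        return self.decoder(h), h

torch.manual_seed(0)
ae = TinyAutoEncoder()
x = torch.randn(16, 8)
x_hat, h = ae(x)
# Reconstruction error: how far the decoded output is from the input.
loss = F.mse_loss(x_hat, x)
print(h.shape, x_hat.shape, float(loss))
```

Training would minimize `loss` with an optimizer, so that the 2-dimensional code `h` keeps as much information about `x` as possible.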
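The principle is much the same on graphs. In a GAE, an encoder maps the input graph to a lower-dimensional space, and a decoder reconstructs the input graph from the low-dimensional embeddings. In other words, we interpret the decoder output as the reconstructed adjacency matrix $\hat{A}$. The goal is to optimize the model so that the reconstruction loss (the difference between $\hat{A}$ and the original graph input A) is minimized.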
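As a rough sketch of the decoding step, the reconstructed adjacency matrix can be taken as the sigmoid of the inner products between node embeddings, and the reconstruction loss as a binary cross-entropy against the true adjacency. The helper names and the toy graph below are assumptions for illustration; PyG's `GAE.recon_loss` evaluates the loss on positive edges and sampled negative edges instead of the dense matrix used here.

```python
import torch
import torch.nn.functional as F

def reconstruct_adjacency(z):
    """Inner-product decoder: A_hat[i, j] = sigmoid(z_i . z_j)."""
    return torch.sigmoid(z @ z.t())

def dense_reconstruction_loss(z, adj):
    """Binary cross-entropy between the decoded logits and the true adjacency."""
    return F.binary_cross_entropy_with_logits(z @ z.t(), adj)

# Toy path graph 0 - 1 - 2 and random 2-dimensional embeddings.
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
torch.manual_seed(0)
z = torch.randn(3, 2)
a_hat = reconstruct_adjacency(z)
loss = dense_reconstruction_loss(z, adj)
print(a_hat)
print(float(loss))
```

The dense form is only practical for small graphs, which is one reason to work with edge lists and negative sampling on larger ones.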
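We will define a GCN with two graph convolution layers, a ReLU, and a dropout layer to help model performance. The encoder can of course be varied by plugging in other GNN convolution layers.
GCN encoder:
Import libraries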
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GAE
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import train_test_split_edges, negative_sampling, degree
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.transforms as T
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import random
import string
from sklearn import metrics
from torch_geometric.data import Data, download_url, extract_gz
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class GAEncoder(nn.Module):
    def __init__(self, in_channels, hidden_size, out_channels, dropout):
        super(GAEncoder, self).__init__()
        # cached=True stores the normalized adjacency from the first forward pass
        # and reuses it afterwards, even if a different edge_index is passed later.
        self.conv1 = GCNConv(in_channels, hidden_size, cached=True)
        self.conv2 = GCNConv(hidden_size, out_channels, cached=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        out = self.conv2(x, edge_index)
        return out

# in_channels=20, hidden_size=200, out_channels=20, dropout=0.5
model = GAE(GAEncoder(20, 200, 20, 0.5)).to(device)
Inspect the model structure
GAE(
  (encoder): GAEncoder(
    (conv1): GCNConv(20, 200)
    (conv2): GCNConv(200, 20)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): InnerProductDecoder()  ## default decoder: inner-product operator
)
def train(train_data, model, optimizer):
    model.train()
    optimizer.zero_grad()
    # Encode nodes using the message-passing edges of the training split.
    z = model.encode(train_data.x, train_data.edge_index)
    # Reconstruction loss on the positive training edges (negatives are sampled internally).
    loss = model.recon_loss(z, train_data.pos_edge_label_index.to(device))
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def gae_test(test_data, model):
    model.eval()
    z = model.encode(test_data.x, test_data.edge_index)
    # model.test returns a pair (AUC, AP), not a loss.
    return model.test(z, test_data.pos_edge_label_index, test_data.neg_edge_label_index)
That is the network architecture of the GAE model.
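The test function returns two ranking metrics, AUC and AP. To make the first one concrete, here is a small pure-Python sketch of AUC as the fraction of (positive edge, negative edge) pairs in which the positive edge receives the higher score. The function name `pairwise_auc` and the scores are made up for illustration; in practice a library routine such as scikit-learn's `roc_auc_score` would be used.

```python
def pairwise_auc(pos_scores, neg_scores):
    """AUC as the share of (positive, negative) pairs ranked correctly; ties count half."""
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))

# Hypothetical decoder scores for three real edges and three sampled non-edges.
pos = [0.9, 0.8, 0.4]
neg = [0.7, 0.3, 0.2]
auc = pairwise_auc(pos, neg)
print(auc)
```

Because only the ordering of the scores matters, any change to the model that rescales all scores monotonically leaves the AUC untouched.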
url = 'http://snap.stanford.edu/biodata/datasets/10012/files/DG-AssocMiner_miner-disease-gene.tsv.gz'
extract_gz(download_url(url, '.'), '.')
data_path = "./DG-AssocMiner_miner-disease-gene.tsv"
df = pd.read_csv(data_path, sep="\t")
df.head()
Output
  # Disease ID              Disease Name  Gene ID
0     C0036095  Salivary Gland Neoplasms     1462
1     C0036095  Salivary Gland Neoplasms     1612
2     C0036095  Salivary Gland Neoplasms      182
3     C0036095  Salivary Gland Neoplasms     2011
4     C0036095  Salivary Gland Neoplasms     2019
df.shape
(21357, 3)
Load the data
def load_data(data_path, class_node=519):
    # class_node is the ID offset for gene nodes; 519 is the number of unique diseases.
    df = pd.read_csv(data_path, sep='\t')
    dise_id = df['# Disease ID']
    Gene_id = df['Gene ID']
    # Diseases get node IDs 0..518, genes get node IDs starting at class_node.
    dis_mapping = {index_id: int(i) + 0 for i, index_id in enumerate(dise_id.unique())}
    gen_mapping = {index_id: int(i) + class_node for i, index_id in enumerate(Gene_id.unique())}
    src_nodes = [dis_mapping[index] for index in df['# Disease ID']]
    dst_nodes = [gen_mapping[index] for index in df['Gene ID']]
    edge_index = torch.tensor([src_nodes, dst_nodes])
    rev_edge_index = torch.tensor([dst_nodes, src_nodes])
    data = Data()
    data.num_nodes = len(dis_mapping) + len(gen_mapping)
    # Store both directions so that the graph is undirected.
    data.edge_index = torch.cat([edge_index, rev_edge_index], dim=1)
    # No real node features are available, so every node gets a constant 20-dim vector.
    data.x = torch.ones((data.num_nodes, 20))
    return data, gen_mapping, dis_mapping
With the function above, we obtain the undirected graph as a Data object.
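The node numbering used here (diseases first, genes shifted by an offset, every edge stored in both directions) can be illustrated without pandas or PyTorch. In this sketch the offset is derived from the number of diseases rather than passed in as a constant; the function name and the toy IDs `D1`, `G1`, and so on are placeholders, not entries from the dataset.

```python
def build_bipartite_edges(pairs):
    """Map (disease, gene) pairs to integer node IDs and add reverse edges."""
    disease_ids, gene_ids = {}, {}
    for disease, gene in pairs:
        # Assign the next free index the first time an ID is seen.
        disease_ids.setdefault(disease, len(disease_ids))
        gene_ids.setdefault(gene, len(gene_ids))
    offset = len(disease_ids)  # gene node IDs start right after the disease IDs
    forward = [(disease_ids[d], gene_ids[g] + offset) for d, g in pairs]
    backward = [(dst, src) for src, dst in forward]
    return forward + backward, disease_ids, gene_ids, offset

# Toy associations: two diseases, two genes.
pairs = [("D1", "G1"), ("D1", "G2"), ("D2", "G1")]
edges, disease_ids, gene_ids, offset = build_bipartite_edges(pairs)
print(edges)
```

Deriving the offset this way avoids having to keep a hard-coded value in sync with the data file.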
data_object, gene_mapping, dis_mapping = load_data(data_path)
print(data_object)
print("Number of genes:", len(gene_mapping))
print("Number of diseases:", len(dis_mapping))
Output
Data(num_nodes=7813, edge_index=[2, 42714], x=[7813, 20])
Number of genes: 7294
Number of diseases: 519
In pytorch_geometric, the RandomLinkSplit transform creates the training, validation, and test sets.
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.15, is_undirected=True,
                      split_labels=True, add_negative_train_samples=True),
])
train_datasets, val_datasets, test_datasets = transform(data_object)
print("Train Data:", train_datasets)
print("Validation Data:", val_datasets)
print("Test Data:", test_datasets)
Inspect the training, validation, and test splits
Train Data: Data(num_nodes=7813, edge_index=[2, 34174], x=[7813, 20], pos_edge_label=[17087], pos_edge_label_index=[2, 17087], neg_edge_label=[17087], neg_edge_label_index=[2, 17087])
Validation Data: Data(num_nodes=7813, edge_index=[2, 34174], x=[7813, 20], pos_edge_label=[1067], pos_edge_label_index=[2, 1067], neg_edge_label=[1067], neg_edge_label_index=[2, 1067])
Test Data: Data(num_nodes=7813, edge_index=[2, 36308], x=[7813, 20], pos_edge_label=[3203], pos_edge_label_index=[2, 3203], neg_edge_label=[3203], neg_edge_label_index=[2, 3203])
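The sizes printed above can be reproduced with a little arithmetic on the 21357 undirected disease-gene edges. The sketch below assumes that the split fractions are truncated to integers, that validation message passing uses only the training edges, and that test message passing uses training plus validation edges; these assumptions are consistent with the printed shapes but are not taken from the library source. The function name is made up for this example.

```python
def link_split_sizes(num_undirected_edges, num_val, num_test):
    """Recompute the split sizes seen in the printed Data objects."""
    n_val = int(num_val * num_undirected_edges)
    n_test = int(num_test * num_undirected_edges)
    n_train = num_undirected_edges - n_val - n_test
    return {
        "train_pos": n_train,
        "val_pos": n_val,
        "test_pos": n_test,
        # Message-passing edges are stored in both directions, hence the factor 2.
        "train_edge_index": 2 * n_train,
        "val_edge_index": 2 * n_train,
        "test_edge_index": 2 * (n_train + n_val),
    }

sizes = link_split_sizes(21357, num_val=0.05, num_test=0.15)
print(sizes)
```

Each split also carries as many negative edges as positive ones, which matches the equal `pos_edge_label` and `neg_edge_label` sizes in the output.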
Next comes training.
optimizer = optim.Adam(model.parameters(), lr=0.1)
losses = []
test_auc = []
test_ap = []
train_aucs = []
train_aps = []
for epoch in range(1, 50):
    loss = train(train_datasets, model, optimizer)
    losses.append(loss)
    auc, ap = gae_test(test_datasets, model)
    test_auc.append(auc)
    test_ap.append(ap)
    train_auc, train_ap = gae_test(train_datasets, model)
    train_aucs.append(train_auc)
    train_aps.append(train_ap)
    print('Epoch: {:03d}, test AUC: {:.4f}, test AP: {:.4f}, train AUC: {:.4f}, train AP: {:.4f}, loss:{:.4f}'.format(epoch, auc, ap, train_auc, train_ap, loss))
Epoch: 001, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2068
Epoch: 002, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2009
Epoch: 003, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1856
Epoch: 004, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0031
Epoch: 005, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0008
Epoch: 006, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0743
Epoch: 007, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0811
Epoch: 008, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1047
Epoch: 009, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0700
Epoch: 010, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1119
Epoch: 011, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1230
Epoch: 012, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0303
Epoch: 013, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1233
Epoch: 014, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0642
Epoch: 015, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0782
Epoch: 016, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0172
Epoch: 017, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0432
Epoch: 018, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0900
Epoch: 019, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0705
Epoch: 020, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:2.9925
Epoch: 021, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2228
Epoch: 022, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0448
Epoch: 023, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0485
Epoch: 024, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:2.9706
Epoch: 025, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1790
Epoch: 026, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:2.9618
Epoch: 027, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0103
Epoch: 028, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1602
Epoch: 029, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1063
Epoch: 030, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0232
Epoch: 031, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1293
Epoch: 032, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0388
Epoch: 033, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1522
Epoch: 034, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1759
Epoch: 035, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1749
Epoch: 036, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2155
Epoch: 037, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1248
Epoch: 038, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1518
Epoch: 039, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0606
Epoch: 040, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0665
Epoch: 041, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0464
Epoch: 042, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0648
Epoch: 043, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0416
Epoch: 044, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2085
Epoch: 045, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0527
Epoch: 046, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:2.9981
Epoch: 047, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1345
Epoch: 048, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0027
Epoch: 049, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1076
I don't know why the training and test metrics never change, even though the loss does.
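One possible contributing factor, offered as a hypothesis rather than a verified diagnosis: every node receives the same constant feature vector, so the embeddings can only reflect graph structure. In the idealized case of two bias-free GCN layers, all embeddings are positive multiples of one shared vector, which means the ordering of edge scores (and therefore AUC and AP) does not depend on the weights at all. The sketch below checks this on a toy graph with a dense normalized adjacency; the helper names are made up, and the real model has trainable biases and dropout, so the argument holds only approximately there.

```python
import torch

def gcn_norm_dense(adj):
    """Dense GCN normalization: D^{-1/2} (A + I) D^{-1/2}."""
    a = adj + torch.eye(adj.size(0))
    d_inv_sqrt = a.sum(dim=1).pow(-0.5)
    return d_inv_sqrt.unsqueeze(1) * a * d_inv_sqrt.unsqueeze(0)

def bias_free_gcn_scores(a_hat, x, w1, w2):
    """Two bias-free GCN layers followed by inner-product edge scores."""
    h = torch.relu(a_hat @ x @ w1)
    z = a_hat @ h @ w2
    return z @ z.t()

torch.manual_seed(0)
# Toy path graph 1 - 0 - 2 - 3 with identical features on every node.
adj = torch.tensor([[0., 1., 1., 0.],
                    [1., 0., 0., 0.],
                    [1., 0., 0., 1.],
                    [0., 0., 1., 0.]])
a_hat = gcn_norm_dense(adj)
x = torch.ones(4, 5)

# Two unrelated random weight settings stand in for "before" and "after" training.
scores_a = bias_free_gcn_scores(a_hat, x, torch.randn(5, 16), torch.randn(16, 3))
scores_b = bias_free_gcn_scores(a_hat, x, torch.randn(5, 16), torch.randn(16, 3))

# Up to a global scale factor, both score matrices are identical.
same_pattern = torch.allclose(scores_a / scores_a.max(),
                              scores_b / scores_b.max(), atol=1e-4)
print(same_pattern)
```

If this is indeed what happens, giving nodes distinguishable inputs (for example learnable embeddings or one-hot type indicators) would be a natural thing to try; that suggestion is equally untested here.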
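Starting from the autoencoder (AE) in machine learning, we carried the idea over to the graph autoencoder (GAE). There is also the variational autoencoder (VAE), and VGAE is the corresponding variant of GAE. That is something to study next.
The above is based on https://snap.stanford.edu/class/cs224w-2020/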