论文:graph2vec: Learning Distributed Representations of Graphs
GitHub 代码:https://github.com/benedekrozemberczki/graph2vec
这篇文章对论文做了简单解读:https://blog.csdn.net/qq_39388410/article/details/103895874
下面是代码部分,关键部分我都加了注释;子图特征抽取用的是 Weisfeiler-Lehman 算法。
"""Graph2Vec module."""
import glob
import hashlib
import json
import os

import networkx as nx
import pandas as pd
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from joblib import Parallel, delayed
from tqdm import tqdm
#from param_parser import parameter_parser
# --- Default parameters (tune as needed) ---
input_path = "./dataset/"           # directory containing the graph JSON files
output_path = "./features/ret.csv"  # destination CSV for the embeddings
dimensions = 128                    # embedding vector size
workers = 4                         # number of parallel workers
epochs = 10                         # Doc2Vec training epochs
min_count = 5                       # minimum WL-feature frequency kept by Doc2Vec
wl_iterations = 2                   # Weisfeiler-Lehman relabeling rounds
learning_rate = 0.025               # Doc2Vec initial alpha
down_sampling = 0.0001              # Doc2Vec subsampling threshold
class WeisfeilerLehmanMachine:
    """Weisfeiler-Lehman subtree feature extractor.

    Runs a fixed number of WL relabeling rounds over the graph, starting
    from the initial node labels; every label produced along the way
    (initial labels plus each round's md5 labels) is accumulated in
    ``self.extracted_features`` as a flat list of strings.
    """
    def __init__(self, graph, features, iterations):
        """Store the inputs and immediately run the extraction.

        :param graph: The Nx graph object.
        :param features: Mapping of node -> initial label (e.g. degree).
        :param iterations: Number of WL relabeling rounds.
        """
        self.iterations = iterations
        self.graph = graph
        self.features = features
        self.nodes = self.graph.nodes()
        # Seed the feature document with the initial node labels.
        self.extracted_features = [str(label) for label in features.values()]
        self.do_recursions()
    def do_a_recursion(self):
        """Perform a single WL relabeling round.

        :return new_features: Mapping of node -> md5 digest of its
            neighborhood label string.
        """
        new_features = {}
        for node in self.nodes:
            # Gather the current labels of the node's neighbors.
            neighbor_labels = [self.features[nb] for nb in self.graph.neighbors(node)]
            # Prepend the node's own label; sorting makes the string
            # order-independent, e.g. "3_2_3_3".
            parts = [str(self.features[node])] + sorted(str(lab) for lab in neighbor_labels)
            combined = "_".join(parts)
            # Compress the (possibly long) label string into a fixed-size
            # token, e.g. "aec8787e7f0dd7c71c319a65ba4d670f".
            new_features[node] = hashlib.md5(combined.encode()).hexdigest()
        self.extracted_features.extend(new_features.values())
        return new_features
    def do_recursions(self):
        """Run the configured number of WL rounds, feeding each round's
        labels into the next."""
        for _ in range(self.iterations):
            self.features = self.do_a_recursion()
def dataset_reader(path):
    """Read a graph and its node features from a JSON file.

    :param path: The path to the graph json ({"edges": [...]} with an
        optional "features" mapping of node id -> label).
    :return graph: The graph object.
    :return features: Features hash table (node id -> provided label, or
        the node degree when the file carries no features).
    :return name: Name of the graph (base file name without extension).
    """
    # os.path handles both "/" and "\\" separators. The previous
    # str.strip(".json") call removed any of the characters {., j, s, o, n}
    # from both ends, corrupting names and failing on POSIX paths.
    name = os.path.splitext(os.path.basename(path))[0]
    # Use a context manager so the file handle is closed deterministically.
    with open(path) as infile:
        data = json.load(infile)
    graph = nx.from_edgelist(data["edges"])
    if "features" in data:
        features = data["features"]
    else:
        # networkx >= 2.0 returns a DegreeView, which has no .items();
        # materialize it as a dict first.
        features = dict(nx.degree(graph))
    # JSON object keys are strings; normalize them to int node ids.
    features = {int(k): v for k, v in features.items()}
    return graph, features, name
def feature_extractor(path, rounds):
    """Turn one graph file into a Doc2Vec training document.

    :param path: The path to the graph json.
    :param rounds: Number of WL iterations.
    :return doc: TaggedDocument whose words are all extracted WL features
        and whose single tag is "g_<graph name>".
    """
    graph, features, name = dataset_reader(path)
    wl_machine = WeisfeilerLehmanMachine(graph, features, rounds)
    return TaggedDocument(words=wl_machine.extracted_features, tags=["g_" + name])
def save_embedding(output_path, model, files, dimensions):
    """Save one embedding row per graph to a CSV, sorted by graph id.

    :param output_path: Path to the embedding csv.
    :param model: The embedding model object (documents tagged "g_<id>").
    :param files: The list of graph JSON file paths; each base name must
        be the integer graph id.
    :param dimensions: The embedding dimension parameter (column count).
    """
    out = []
    for f in files:
        # os.path works on both Windows and POSIX paths. The previous
        # split("\\") + strip(".json") combo was Windows-only and also
        # stripped leading/trailing {., j, s, o, n} characters from ids.
        identifier = os.path.splitext(os.path.basename(f))[0]
        out.append([int(identifier)] + list(model.docvecs["g_" + identifier]))
    column_names = ["type"] + ["x_" + str(dim) for dim in range(dimensions)]
    out = pd.DataFrame(out, columns=column_names)
    out = out.sort_values(["type"])
    out.to_csv(output_path, index=None)
def main():
    """Read the graph list, extract WL features, learn the Doc2Vec
    embedding and save it to CSV."""
    graph_files = glob.glob(input_path + "*.json")
    print("\nFeature extraction started.\n")
    # document_collections is a list of TaggedDocument objects, one per
    # graph, e.g. TaggedDocument(words=['1', '2', '3', ...,
    # 'aec8787e7f0dd7c71c319a65ba4d670f', ...], tags=['g_0']):
    # the words are every WL label seen for the graph (initial degrees
    # plus the md5 labels from each WL round) and the tag is the graph id.
    document_collections = Parallel(n_jobs=workers)(
        delayed(feature_extractor)(graph_file, wl_iterations)
        for graph_file in tqdm(graph_files)
    )
    print("\nOptimization started.\n")
    # dm=0 selects PV-DBOW (the variant graph2vec uses); window=0 because
    # word order within a graph document carries no meaning.
    model = Doc2Vec(
        document_collections,
        vector_size=dimensions,
        window=0,
        min_count=min_count,
        dm=0,
        sample=down_sampling,
        workers=workers,
        epochs=epochs,
        alpha=learning_rate,
    )
    save_embedding(output_path, model, graph_files, dimensions)


if __name__ == '__main__':
    main()
运行结果是为每个 graph 输出一个 128 维的向量表示;后面会用 DGL 库提供的不同类别的图做一个简单的图聚类 Demo。