graph2vec: Learning Distributed Representations of Graphs 代码解读

论文地址:graph2vec: Learning Distributed Representations of Graphs(https://arxiv.org/abs/1707.05005)

github代码:https://github.com/benedekrozemberczki/graph2vec

这篇文章对论文做了简单解读:https://blog.csdn.net/qq_39388410/article/details/103895874

下面是代码部分,代码的关键部分我做了注释,子图抽取用的是Weisfeiler-Lehman算法

"""Graph2Vec module."""

import glob
import hashlib
import json
import os

import networkx as nx
import pandas as pd
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from joblib import Parallel, delayed
from tqdm import tqdm
#from param_parser import parameter_parser

# Default parameters (tunable).
input_path: str = "./dataset/"           # directory holding the graph *.json files
output_path: str = "./features/ret.csv"  # CSV file the embeddings are written to
dimensions: int = 128                    # embedding size (Doc2Vec vector_size)
workers: int = 4                         # parallel jobs for extraction and for Doc2Vec
epochs: int = 10                         # Doc2Vec training epochs
min_count: int = 5                       # Doc2Vec min_count for WL labels
wl_iterations: int = 2                   # number of Weisfeiler-Lehman relabelling rounds
learning_rate: float = 0.025             # Doc2Vec initial learning rate (alpha)
down_sampling: float = 0.0001            # Doc2Vec `sample` (down-sampling threshold)


class WeisfeilerLehmanMachine:
    """
    Weisfeiler-Lehman (WL) subtree feature extractor.

    Every node starts with the label given in ``features``. Each WL iteration
    relabels a node with the MD5 hex digest of its current label joined (with
    "_") to the sorted labels of its neighbours. All labels produced along the
    way, the initial ones included, are collected in ``extracted_features``.
    """

    def __init__(self, graph, features, iterations):
        """
        Initialise the extractor and immediately run the WL iterations.

        :param graph: The graph object; only ``nodes()`` and
            ``neighbors(node)`` are used.
        :param features: Mapping from node to its initial label.
        :param iterations: Number of WL iterations to run.
        """
        self.iterations = iterations
        self.graph = graph
        self.features = features
        self.nodes = self.graph.nodes()
        # The initial labels are part of the extracted vocabulary as well.
        self.extracted_features = [str(v) for v in features.values()]
        self.do_recursions()

    def do_a_recursion(self):
        """
        Run a single WL relabelling step.

        Appends the new labels (in node order) to ``extracted_features``.

        :return new_features: Mapping from node to its new hashed label.
        """
        new_features = {}
        for node in self.nodes:
            # Current labels of the neighbours (node degrees when the graph was
            # loaded without a "features" key). Sorting makes the signature
            # independent of neighbour iteration order.
            neighbour_labels = sorted(
                str(self.features[neb]) for neb in self.graph.neighbors(node)
            )
            # Own label first, then the neighbours, e.g. "3_2_3_3".
            signature = "_".join([str(self.features[node])] + neighbour_labels)
            # MD5 only compresses the signature into a fixed-length label,
            # e.g. "aec8787e7f0dd7c71c319a65ba4d670f"; it is not used for security.
            new_features[node] = hashlib.md5(signature.encode()).hexdigest()
        # Extend in place instead of rebuilding the whole list each iteration.
        self.extracted_features.extend(new_features.values())
        return new_features

    def do_recursions(self):
        """
        Run ``self.iterations`` WL steps, feeding each step's labels into the next.
        """
        for _ in range(self.iterations):
            self.features = self.do_a_recursion()

def dataset_reader(path):
    """
    Read one graph and its initial node labels from a JSON file.

    The JSON must contain an ``"edges"`` list. An optional ``"features"``
    mapping (node id -> label) provides the initial labels; when it is
    absent, node degrees are used instead.

    :param path: Path to the graph JSON file, e.g. ``./dataset/0.json``.
    :return graph: The NetworkX graph built from the edge list.
    :return features: Dict mapping int node id to its initial label.
    :return name: The file name without directory and ``.json`` extension.
    """
    # os.path handles the platform's separators (both "/" and "\\" on
    # Windows), and splitext removes only the real extension. By contrast,
    # str.strip(".json") strips any of the characters ". j s o n" from both
    # ends, e.g. the leading "." of "./dataset/0.json".
    name = os.path.splitext(os.path.basename(path))[0]
    # Close the file handle deterministically.
    with open(path, encoding="utf-8") as handle:
        data = json.load(handle)
    graph = nx.from_edgelist(data["edges"])

    if "features" in data:
        features = data["features"]
    else:
        # dict() accepts both NetworkX 1.x's degree dict and the DegreeView
        # (iterable of (node, degree) pairs) returned by NetworkX >= 2, which
        # has no .items() method.
        features = dict(nx.degree(graph))

    # JSON object keys are strings; node ids from the edge list are ints.
    features = {int(k): v for k, v in features.items()}
    return graph, features, name

def feature_extractor(path, rounds):
    """
    Convert one graph file into a Doc2Vec training document.

    The graph's WL labels from every iteration become the document's words,
    and ``"g_<name>"`` becomes its single tag.

    :param path: Path to the graph JSON file.
    :param rounds: Number of WL iterations to run.
    :return doc: A ``TaggedDocument`` describing this graph.
    """
    graph, features, name = dataset_reader(path)
    words = WeisfeilerLehmanMachine(graph, features, rounds).extracted_features
    return TaggedDocument(words=words, tags=["g_" + name])

def save_embedding(output_path, model, files, dimensions):
    """
    Write one CSV row per graph: its numeric id followed by its embedding.

    :param output_path: Path of the CSV file to write.
    :param model: The trained Doc2Vec model. Document vectors are looked up
        by the tag ``"g_<file name without extension>"``.
    :param files: The graph JSON paths. Each file name (without extension)
        must be an integer; it becomes the ``type`` column.
    :param dimensions: The embedding dimension, i.e. the number of ``x_*``
        columns.
    """
    # gensim >= 4 exposes document vectors as ``model.dv``; older releases
    # only provide ``model.docvecs``.
    doc_vectors = model.dv if hasattr(model, "dv") else model.docvecs
    out = []
    for f in files:
        # Platform-aware basename, then drop only the extension.
        # str.strip(".json") would also eat leading "." characters and any
        # trailing j/s/o/n letters.
        identifier = os.path.splitext(os.path.basename(f))[0]
        out.append([int(identifier)] + list(doc_vectors["g_" + identifier]))
    column_names = ["type"] + ["x_" + str(dim) for dim in range(dimensions)]
    out = pd.DataFrame(out, columns=column_names)
    out = out.sort_values(["type"])
    out.to_csv(output_path, index=False)

def main():
    """
    Read the graph list, extract WL features, learn the embeddings and save them.

    The inputs are every ``*.json`` file under ``input_path``. Features are
    extracted in parallel, a Doc2Vec model is trained on the resulting
    documents, and the per-graph vectors are written to ``output_path``.
    All settings come from the module-level parameters.

    :raises FileNotFoundError: If no ``*.json`` file exists under ``input_path``.
    """
    # os.path.join also works when input_path has no trailing separator.
    graphs = glob.glob(os.path.join(input_path, "*.json"))
    if not graphs:
        # Fail early with a clear message instead of letting Doc2Vec fail
        # later on an empty corpus.
        raise FileNotFoundError("No .json graph files found in " + repr(input_path))
    print("\nFeature extraction started.\n")
    # document_collections holds one TaggedDocument per graph. Its words are
    # the WL labels (initial labels first, then the MD5-hashed labels of each
    # iteration), and its tag is "g_<graph name>", e.g.
    #   TaggedDocument(words=['1', '1', '1', '1', '1', '2', ..., '3',
    #                         'aec8787e7f0dd7c71c319a65ba4d670f', ...,
    #                         'b4b3107ce9aae441f8fd4a61ed1fecfb'],
    #                  tags=['g_0'])
    document_collections = Parallel(n_jobs=workers)(
        delayed(feature_extractor)(g, wl_iterations) for g in tqdm(graphs)
    )

    print("\nOptimization started.\n")

    # dm=0 selects gensim's PV-DBOW training mode.
    model = Doc2Vec(document_collections,
                    vector_size=dimensions,
                    window=0,
                    min_count=min_count,
                    dm=0,
                    sample=down_sampling,
                    workers=workers,
                    epochs=epochs,
                    alpha=learning_rate)

    save_embedding(output_path, model, graphs, dimensions)


if __name__ == '__main__':
    main()

输出结果是每个graph的一个128维向量表示(维度由dimensions参数决定)。后面会用DGL库提供的不同类别的图做一个简单的图聚类Demo。

你可能感兴趣的:(图算法,python学习,机器学习)