python define graph_GraphSAGE 代码解析(一) - unsupervised_train.py

原创文章~转载请注明出处哦。其他部分内容参见以下链接~

GraphSAGE代码详解

example_data:

1. toy-ppi-G.json 图的信息

{

directed: false

graph : {

{name: disjoint_union(,) }

nodes: [

{

test: false

id: 0

features: [ ... ]

val: false

lable: [ ... ]

}

{...}

...

]

links: [

{

test_removed: false

train_removed: false

target:800 #指向的节点id(默认从小节点指向大节点)

source: 0 #从0节点按顺序展示

}

{...}

...

]

}

}

View Code

2. toy-ppi-class_map.json

3. toy-ppi-feats.npy 预训练好得到的features

4. toy-ppi-id_map.json 节点编号与序号的一一对应;数据格式为:{"0": 0, "1": 1,..., "14754": 14754}

5. toy-ppi-walks.txt

从一点出发随机游走到邻居节点的情况,对于每个点取198次(即可能有重复情况)

例如:0    708 表示从0点走到708点。

1. __init__.py

1 from __future__ importprint_function2 #即使在python2.X,使用print就得像python3.X那样加括号使用。

3

4 from __future__ importdivision5 #导入python未来支持的语言特征division(精确除法),

6 #当我们没有在程序中导入该特征时,"/"操作符执行的是截断除法(Truncating Division);

7 #当我们导入精确除法之后,"/"执行的是精确除法, "//"执行截断除除法

2. unsupervised_train.py

1 if __name__ == '__main__':2 tf.app.run()3 #https://blog.csdn.net/fxjzzyo/article/details/80466321

4 #tf.app.run()的作用:通过处理flag解析,然后执行main函数

5 #如果你的代码中的入口函数不叫main(),而是一个其他名字的函数,如test(),则你应该这样写入口tf.app.run(test())

6 #如果你的代码中的入口函数叫main(),则你就可以把入口写成tf.app.run()

1 def main(argv=None):2   print("Loading training data..")3   train_data = load_data(FLAGS.train_prefix, load_walks=True)4   #load_data函数在graphsage.utils中定义

5

6   print("Done loading training data..")7 train(train_data)8   #train函数在该文件中定义def train(train_data, test_data=None)

3. utils.py - func: load_data

(1) 读入id_map, class_map

1 ifisinstance(G.nodes()[0], int):2 def conversion(n): returnint(n)3 else:4 def conversion(n): return n

a. isinstance()函数来判断一个对象是否是一个已知的类型,类似 type()。

isinstance(object, classinfo)

参数

object -- 实例对象。

classinfo -- 可以是直接或间接类名、基本类型或者由它们组成的元组。

返回值

如果对象的类型与参数二的类型(classinfo)相同则返回 True,否则返回 False。

>>>a = 2

>>>isinstance (a,int)

True>>>isinstance (a,str)

False>>> isinstance (a,(str,int,list)) #是元组中的一个返回 True

True

type() 与 isinstance() 区别:

type() 不会认为子类是一种父类类型,不考虑继承关系。

isinstance() 会认为子类是一种父类类型,考虑继承关系。

如果要判断两个类型是否相同推荐使用 isinstance()。

1 classA:2 pass

3

4 classB(A):5 pass

6

7 isinstance(A(), A) #returns True

8 type(A()) == A #returns True

9 isinstance(B(), A) #returns True

10 type(B()) == A #returns False

View Code

b. G.nodes()

例子:

>>> G = nx.path_graph(3)>>>list(G.nodes)

[0,1, 2]>>>list(G)

[0,1, 2]

View Code

获取nodedata:

>>> G.add_node(1, time='5pm')>>> G.nodes[0]['foo'] = 'bar'

>>> list(G.nodes(data=True))

[(0, {'foo': 'bar'}), (1, {'time': '5pm'}), (2, {})]>>>list(G.nodes.data())

[(0, {'foo': 'bar'}), (1, {'time': '5pm'}), (2, {})]>>> list(G.nodes(data='foo'))

[(0,'bar'), (1, None), (2, None)]>>> list(G.nodes(data='time'))

[(0, None), (1, '5pm'), (2, None)]>>> list(G.nodes(data='time', default='Not Available'))

[(0,'Not Available'), (1, '5pm'), (2, 'Not Available')]

View Code

If some of your nodes have an attribute and the rest are assumed to have a default attribute value you can create a dictionary from node/attribute pairs using the default keyword argument to guarantee the value is never None:

>>> G =nx.Graph()>>>G.add_node(0)>>> G.add_node(1, weight=2)>>> G.add_node(2, weight=3)>>> dict(G.nodes(data='weight', default=1))

{0:1, 1: 2, 2: 3}

View Code

----------------------------

在utils.py中,判断G.nodes()[0] 是否为int型(即不带nodedata)。

若为int型,则将n转为int型;否则直接返回n.

b. conversion() 函数

1 id_map = json.load(open(prefix + "-id_map.json"))2 id_map = {conversion(k): int(v) for k, v in id_map.items()}

前面定义的conversion()函数在id_map这里用到了,把外存中的文件内容读到内存中,用dict类型的id_map存储。

id_map.json文件中数据格式为:{"0": 0, "1": 1,..., "14754": 14754},也即id_map的迭代中k为str类型,v为int型。数据文件中G.nodes()[0] 显然是带nodedata的,也就算一般采用 def conversion(n): return n,返回的n为类型的(就是前面形参k的类型);

但是为什么当G.nodes()[0] 不带nodedata时,要返回int(n)?

c. class_map:  {"0": [.0,1,..], "1": [.0,1,..]...} ?含义?

list(class_map.values()): [ [...], [...], ... ,[...] ]

list(class_map.values())[0]: 表示取第一个[...] =>含义?

ifisinstance(list(class_map.values())[0], list):

def lab_conversion(n): returnn

else:

def lab_conversion(n): return int(n)

(2) Remove node

1 #Remove all nodes that do not have val/test annotations

2 #(necessary because of networkx weirdness with the Reddit data)

3 broken_count =04 for node inG.nodes():5 if not 'val' in G.node[node] or not 'test' inG.node[node]:6 G.remove_node(node)7 broken_count += 1

这里删除的节点是不具有'val','test'属性的节点,而不是'val','test' 属性值为None的节点。

区分开 if not 'val' in G.node[node] 和 if not G.node[n]['val']的不同意义。

broken_count  记录删去的没有val 或者 test的属性的节点的数目。

e. G.edges()

1 for edge inG.edges():2 if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or

3 G.node[edge[0]]['test'] or G.node[edge[1]]['test']):4 G[edge[0]][edge[1]]['train_removed'] =True5 else:6 G[edge[0]][edge[1]]['train_removed'] = False

G.edges() 得到edge_list, [( , ), ( , ), ... ( , )].list中每一个元素是所表示边的两个节点信息。若设置data = True,则会显示边的权重等属性信息。

>>> G = nx.Graph() #or DiGraph, MultiGraph, MultiDiGraph, etc

>>> G.add_path([0,1,2])>>> G.add_edge(2,3,weight=5)>>>G.edges()

[(0,1), (1, 2), (2, 3)]>>> G.edges(data=True) #default edge data is {} (empty dictionary)

[(0, 1, {}), (1, 2, {}), (2, 3, {'weight': 5})]>>> list(G.edges_iter(data='weight', default=1))

[(0,1, 1), (1, 2, 1), (2, 3, 5)]>>> G.edges([0,3])

[(0,1), (3, 2)]>>>G.edges(0)

[(0,1)]

View Code

代码中edge对edges迭代,每次去list中的一个元组,而edge[0], edge[1]则分别表示两个顶点。

若两个顶点中至少有一个的val/test不为空,则将该边的'train_removed'设为True,否则为False.

该操作为保证'train_removed'不为空。

(3) 获取训练数据features并标准化

1 if normalize and not feats isNone:2 from sklearn.preprocessing importStandardScaler3 train_ids = np.array([id_map[n] for n inG.nodes(4 ) if not G.node[n]['val'] and not G.node[n]['test']])5 train_feats =feats[train_ids]6 scaler =StandardScaler()7 scaler.fit(train_feats)8 feats = scaler.transform(feats)

这里if not feats is None 等价于 if feats is not None.

将val,test均为None的node选为训练数据,通过id_map获取其在feature表中的索引值,添加到train_ids数组中。根据索引train_ids,train_fests获取这些nodes的features.

StandardScaler的用法:

Methods:

fit(X[, y]) : Compute the mean and std to be used for later scaling.

transform(X[, y, copy]) : Perform standardization by centering and scaling

fit_transform(X[, y]) : Fit to data, then transform it.

例子:

>>> from sklearn.preprocessing importStandardScaler>>> data = [[0, 0], [0, 0], [1, 1], [1, 1]]>>> scaler =StandardScaler()>>> print(scaler.fit(data))

StandardScaler(copy=True, with_mean=True, with_std=True)>>> print(scaler.mean_)

[0.5 0.5]>>> print(scaler.transform(data))

[[-1. -1.]

[-1. -1.]

[1. 1.]

[1. 1.]]>>> print(scaler.transform([[2, 2]]))

[[3. 3.]]#计算得#均值[0.5, 0.5],#方差:1/4 * [(0 - 0.5)^2 * 2 + (1 - 0.5)^2 * 2] = 1/4 = 0.25#标准差:0.5#对于[2,2] transform 标准化之后: (2 - 0.5) / 0.5 = 3

View Code

(4) Load walks

在unsupervised_train.py的main函数中:

1 train_data = load_data(FLAGS.train_prefix, load_walks=True)

load_walks = True,需要执行utils.py中的load_walks操作。

1 if load_walks: #false by default

2 with open(prefix + "-walks.txt") as fp:3 for line infp:4 walks.append(map(conversion, line.split()))

map(function, iterable, ...)

map() 会根据提供的函数对指定序列做映射。

第一个参数 function 以参数序列中的每一个元素调用 function 函数,返回包含每次 function 函数返回值的新列表。

例子:

>>>def square(x) : #计算平方数

... return x ** 2...>>> map(square, [1,2,3,4,5]) #计算列表各个元素的平方

[1, 4, 9, 16, 25]>>> map(lambda x: x ** 2, [1, 2, 3, 4, 5]) #使用 lambda 匿名函数

[1, 4, 9, 16, 25]#提供了两个列表,对相同位置的列表数据进行相加

>>> map(lambda x, y: x + y, [1, 3, 5, 7, 9], [2, 4, 6, 8, 10])

[3, 7, 11, 15, 19]

View Code

walks初始化为[], 之后append的是游走的节点对的对象。

例子:walks.txt:

0 708

0 3163

0 276

1 def conversion(n): returnn2 walks =[]3 with open("walks.txt") as fp:4 for line infp:5 print(line.split())6 walks.append(map(conversion, line.split()))7 print(walks)8 print(len(walks))

View Code

输出:

['0', '708']

['0', '3163']

['0', '276']

[, , ]3

(5) 函数返回值

1 return G, feats, id_map, walks, class_map

------------------------------------------------------------------------------------

4. unsupervised_train.py - func: train(train_data)

1 def train(train_data, test_data=None):

这里的train_data是上文所述的load_data函数的返回值。

变量含义:

G = train_data[0] #图

features = train_data[1] #训练数据的features

id_map = train_data[2] #"n" : n

context_pairs = train_data[3] if FLAGS.random_context else None #random walk的点对

1 if not features isNone:2 #pad with dummy zero vector

3 features = np.vstack([features, np.zeros((features.shape[1],))])

这里vstack为features添加列一行0向量,用于WX + b中与b相加。

1 placeholders =construct_placeholders()2 #def construct_placeholders()定义的placeholders包含:

3 #batch1, batch2, neg_samples, dropout, batch_size

minibatch是EdgeMinibatchIterator的一个实例,转至minibatch.py看class EdgeMinibatchIterator(object)的定义。

5. minibatch.py - class EdgeMinibatchIterator

6. unsupervised_train.py - func train

继续回来看unsupervised_trian.py 中的train函数

变量:

1 adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)2 adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

adj_info记录邻居信息,是一个矩阵,矩阵每一行对应每一个节点的邻居节点编号数组。

(1)选择模型

接下来根据输入参数判断选择6种模型(graphsage_mean,gcn,graphsage_seq,graphsage_maxpool,graphsage_meanpool,n2v)中的哪一种。

以graphsage开头的几种是graphsage的几种变体,由于aggregator不同而不同。可以通过设定SampleAndAggregate()中的aggregator_type进行选择。默认为mean.

其中gcn与graphsage的参数不同在于:

gcn的aggregator中进行列concat的操作,因此其维数是graphsage的二倍。

a. graphsage_maxpool

1 sampler =UniformNeighborSampler(adj_info)

首先看UniformNeighborSampler,该类用于sample节点的邻居,在neigh_samplers.py中。

neigh_samplers.py

1 classUniformNeighborSampler(Layer):2 """

3 Uniformly samples neighbors.4 Assumes that adj lists are padded with random re-sampling5 """

6 def __init__(self, adj_info, **kwargs):7 super(UniformNeighborSampler, self).__init__(**kwargs)8 self.adj_info =adj_info9

10 def_call(self, inputs):11 ids, num_samples =inputs12 adj_lists =tf.nn.embedding_lookup(self.adj_info, ids)13 adj_lists =tf.transpose(tf.random_shuffle(tf.transpose(adj_lists)))14 adj_lists = tf.slice(adj_lists, [0,0], [-1, num_samples])15 return adj_lists

1.  tf.nn.embedding_lookup 用于根据ids在adj_info中找到各个对应位的向量。

2. adj_lists = tf.transpose(tf.random_shuffle(tf.transpose(adj_lists)))

adj_lists = tf.slice(adj_lists, [0,0], [-1, num_samples]) 的过程见下:

id0 id1 id2...   --transpose--> id0 [...]  --shuffle--> id1 [...]  --transpose--> id1 id2 id0 --slice--> id1 id2

[]    []    []                                id1 [...]                     id2 [...]                         []     []    []                  []    []

id2 [...]                     id0 [...]

均匀:shuffle打乱0维的顺序,即打乱行顺序,以此使下面采样可以“均匀”。为了使用shuffle函数,需要在shuffle前后transpose一下。

采样:slice之后,相当于随机挑选了num_samples个样本,并保留了这些样本的全部属性特征。

3. 最后的adj_lists即为均匀采样后的表示邻居信息的矩阵。

---------------------------------------------------

回到unsupervised_train.py 的train()函数.

1 sampler = UniformNeighborSampler(adj_info)

sampler获取均匀采样后的邻居节点信息。

---------------------------------------------------

1 layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),2 SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

其中SAGEInfo在models.py中。

models.py

1 #SAGEInfo is a namedtuple that specifies the parameters

2 #of the recursive GraphSAGE layers

3 SAGEInfo = namedtuple("SAGEInfo",4 ['layer_name', #name of the layer (to get feature embedding etc.)

5 'neigh_sampler', #callable neigh_sampler constructor

6 'num_samples',7 'output_dim' #the output (i.e., hidden) dimension

8 ])

namedtuple命名元组,可以给tuple命名,用法见下:

1 importcollections2

3 MyTupleClass = collections.namedtuple('MyTupleClass',['name', 'age', 'job'])4 obj = MyTupleClass("Tomsom",12,'Cooker')5 print(obj.name)6 print(obj.age)7 print(obj.job)8

9 #Output:

10 #Tomsom

11 #12

12 #Cooker

13 #############################

14

15 Person=collections.namedtuple('Person','name age gender')16 #以空格分开,表示这个namedtuple有三个元素

17

18 print( 'Type of Person:',type(Person))19 Bob=Person(name='Bob',age=30,gender='male')20 print( 'Representation:',Bob)21 Jane=Person(name='Jane',age=29,gender='female')22 print( 'Field by Name:',Jane.name)23 for people in[Bob,Jane]:24 print ("%s is %d years old %s" %people)25

26 #Output:

27 #Type of Person:

28 #Representation: Person(name='Bob', age=30, gender='male')

29 #Field by Name: Jane

30 #Bob is 30 years old male

31 #Jane is 29 years old female

32 #############################

33

34 #在使用namedtyuple的时候要注意其中的名称不能使用Python的关键字,如class def等

35 #不能有重复的元素名称,比如:不能有两个’age age’。如果出现这些情况,程序会报错。

36 #但是,在实际使用的时候可能无法避免这种情况,

37 #比如:可能我们的元素名称是从数据库里读出来的记录,这样很难保证一定不会出现Python关键字。

38 #这种情况下的解决办法是将namedtuple的重命名模式打开,

39 #这样如果遇到Python关键字或者有重复元素名 时,自动进行重命名。

40

41 with_class=collections.namedtuple('Person','name age class gender',rename=True)42 printwith_class._fields43 two_ages=collections.namedtuple('Person','name age gender age',rename=True)44 printtwo_ages._fields45

46 #Output:

47 #('name', 'age', '_2', 'gender')

48 #('name', 'age', 'gender', '_3')

49

50 #使用rename=True的方式打开重命名选项。

51 #可以看到第一个集合中的class被重命名为 ‘_2' ;

52 #第二个集合中重复的age被重命名为 ‘_3'

53 #namedtuple在重命名的时候使用了下划线 _ 加元素所在索引数的方式进行重命名

54 ##############################

55

56 #附两段官方文档代码实例:

57 #1) namedtuple基本用法

58 >>> #Basic example

59 >>> Point = namedtuple('Point', ['x', 'y'])60 >>> p = Point(11, y=22) #instantiate with positional or keyword arguments

61 >>> p[0] + p[1] #indexable like the plain tuple (11, 22)

62 33

63 >>> x, y = p #unpack like a regular tuple

64 >>>x, y65 (11, 22)66 >>> p.x + p.y #fields also accessible by name

67 33

68 >>> p #readable __repr__ with a name=value style

69 Point(x=11, y=22)70

71 #2) namedtuple结合csv和sqlite用法

72 EmployeeRecord = namedtuple('EmployeeRecord', 'name, age, title, department, paygrade')73 importcsv74 for emp in map(EmployeeRecord._make, csv.reader(open("employees.csv", "rb"))):75 print(emp.name, emp.title)76

77 importsqlite378 conn = sqlite3.connect('/companydata')79 cursor =conn.cursor()80 cursor.execute('SELECT name, age, title, department, paygrade FROM employees')81 for emp inmap(EmployeeRecord._make, cursor.fetchall()):82 print(emp.name, emp.title)

View Code

对于FLAGS.dim_1与FLAGS.dim_2,定义为:

1 flags.DEFINE_integer(2 'dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')3 flags.DEFINE_integer(4 'dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')

若GCN,因为有concat操作,故使用2x.

对于FLAGS.samples_1与FLAGS.samples_2,定义为:

1 flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')

2 flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2')

对应论文中的K = 1 ,第一层S1 = 25; K = 2 ,第二层S2 = 10。

----------------------------------------------------------

1 model =SampleAndAggregate(placeholders,2 features,3 adj_info,4 minibatch.deg,5 layer_infos=layer_infos,6 aggregator_type="maxpool",7 model_size=FLAGS.model_size,8 identity_dim=FLAGS.identity_dim,9 logging=True)

SampleAndAggregate在models.py中。

class SampleAndAggregate(GeneralizedModel)主要包含的函数有:

1. def __init__(self, placeholders, features, adj, degrees, layer_infos, concat=True, aggregator_type="mean",  model_size="small", identity_dim=0, **kwargs)

2. def sample(self, inputs, layer_infos, batch_size=None)

3. def aggregate(self, samples, input_features, dims, num_samples, support_sizes, batch_size=None,

aggregators=None, name=None, concat=False, model_size="small")

4. def _build(self)

5. def build(self)

6. def _loss(self)

7. def _accuracy(self)

---------------------------------------------------------------

(2) Session

Config

1 config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)2 #参数初始化为False:

3 #tf.app.flags.DEFINE_boolean('log_device_placement', False,

4 #"""Whether to log device placement.""")

5

6 config.gpu_options.allow_growth =True7 #控制GPU资源使用率

8 #使用allow_growth option,刚一开始分配少量的GPU容量,然后按需慢慢的增加,

9 #由于不会释放内存,所以会导致碎片

10

11 #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION

12 #设置每个GPU应该拿出多少容量给进程使用,

13 #per_process_gpu_memory_fraction =0.4代表 40%

14

15 config.allow_soft_placement =True16 #自动选择运行设备

17 #在tf中,通过命令 "with tf.device('/cpu:0'):",允许手动设置操作运行的设备。

18 #如果手动设置的设备不存在或者不可用,就会导致tf程序等待或异常,

19 #为了防止这种情况,可以设置tf.ConfigProto()中参数allow_soft_placement=True,

20 #允许tf自动选择一个存在并且可用的设备来运行操作。

Initialize session

1 #Initialize session

2 sess = tf.Session(config=config)3 merged =tf.summary.merge_all()4 #tf.summary()能够保存训练过程以及参数分布图并在tensorboard显示。

5 #merge_all 可以将所有summary全部保存到磁盘,以便tensorboard显示。

6 #如果没有特殊要求,一般用这一句就可一显示训练时的各种信息了

7

8 summary_writer =tf.summary.FileWriter(log_dir(), sess.graph)9 #指定一个文件用来保存图。

10 #格式:tf.summary.FileWritter(path,sess.graph)

11 #可以调用其add_summary()方法将训练过程数据保存在filewriter指定的文件中

Init variables

1 sess.run(tf.global_variables_initializer(),2 feed_dict={adj_info_ph: minibatch.adj})

---------------------------------------------------------

(4) Train model

1 feed_dict = minibatch.next_minibatch_feed_dict()

next_minibatch_feed_dict() 在minibatch.py的class EdgeMinibatchIterator(object)中定义。

1 defnext_minibatch_feed_dict(self):2 start_idx = self.batch_num *self.batch_size3 self.batch_num += 1

4 end_idx = min(start_idx +self.batch_size, len(self.train_edges))5 batch_edges =self.train_edges[start_idx: end_idx]6 return self.batch_feed_dict(batch_edges)

View Code

函数中获取下个edgeminibatch的起始与终止序号,将batch后的边的信息传给batch_feed_dict(self, batch_edges)函数,更新placeholders中的batch1, batch2, batch_size信息。

1 defbatch_feed_dict(self, batch_edges):2 batch1 =[]3 batch2 =[]4 for node1, node2 inbatch_edges:5 batch1.append(self.id2idx[node1])6 batch2.append(self.id2idx[node2])7

8 feed_dict =dict()9 feed_dict.update({self.placeholders['batch_size']: len(batch_edges)})10 feed_dict.update({self.placeholders['batch1']: batch1})11 feed_dict.update({self.placeholders['batch2']: batch2})12

13 return feed_dict

View Code

也即next_minibatch_feed_dict()返回的是下一个edge minibatch的placeholders信息。

=======================================

             

感谢您的打赏!

(梦想还是要有的,万一您喜欢我的文章呢)

你可能感兴趣的:(python,define,graph)