from torch_geometric.nn import RGCNConv,GraphConv
import torch.nn as nn
import torch
# --- RGCNConv demo ----------------------------------------------------------
# Build a relational GCN layer, print its parameter shapes, and run it once on
# a random multi-relational graph.
num_features = 64  # input feature size per node
hidden_size = 16  # output feature size per node
num_relations = 3  # number of distinct edge (relation) types the layer supports
num_bases = 10  # size of the basis decomposition shared across relations
conv1 = RGCNConv(
    in_channels=num_features,
    out_channels=hidden_size,
    num_relations=num_relations,
    num_bases=num_bases,
)
print('RGCNConv模型里的参数:')  # "Parameters in the RGCNConv model:"
for name, parameter in conv1.named_parameters():
    print(name, ':', parameter.shape)
num_nodes = 53
x = torch.randn(num_nodes, num_features)
num_edges = 111
edge_index = torch.randint(0, num_nodes, (2, num_edges))
# Relation ids must lie in [0, num_relations). They were previously drawn from
# [0, 10) (apparently confused with num_bases), so most edges carried relation
# ids the layer has no weights for. The call did not raise, so those edges
# silently did not contribute as intended.
edge_type = torch.randint(0, num_relations, (num_edges,))
output1 = conv1(x, edge_index, edge_type)
print('output的shape:', output1.shape)  # "shape of output:"
# Recorded output of the RGCNConv cell above:
# RGCNConv模型里的参数:
# weight : torch.Size([10, 64, 16])
# comp : torch.Size([3, 10])
# root : torch.Size([64, 16])
# bias : torch.Size([16])
# output的shape: torch.Size([53, 16])
from torch_geometric.nn import GraphConv
import torch.nn as nn
import torch
# --- GraphConv demo ---------------------------------------------------------
# Build a GraphConv layer mapping input_size-dim node features to
# output_size-dim outputs, print its parameter shapes, and run it once on a
# random graph.
input_size = 16  # input feature size per node
output_size = 5  # output feature size per node
conv2 = GraphConv(input_size, output_size)
print('GraphConv模型里的参数:')  # "Parameters in the GraphConv model:"
for name, parameter in conv2.named_parameters():
    print(name, ':', parameter.shape)
num_nodes = 53
num_edges = 111
input_ = torch.randn([num_nodes, input_size])
# Random directed edges as a (2, num_edges) tensor of node indices.
edge_index = torch.randint(0, num_nodes, (2, num_edges))
output2 = conv2(input_, edge_index)
print('output的shape:', output2.shape)  # "shape of output:"
# Recorded output of the GraphConv cell above:
# GraphConv模型里的参数:
# lin_rel.weight : torch.Size([5, 16])
# lin_rel.bias : torch.Size([5])
# lin_root.weight : torch.Size([5, 16])
# output的shape: torch.Size([53, 5])