gradle run --args="--c Foo/config.json -l"
String[] args1={"-c","Foo/config.json","-l"};
if (cmd.hasOption("l")) {
GPFL system = new GPFL(config, "learn_log");
system.learn();
}
/**
 * Runs the whole learning pipeline.
 *
 * <p>Steps:
 * <ul>
 *   <li>open the graph stored under {@code home};</li>
 *   <li>prepare the rule output files under {@code out};</li>
 *   <li>learn rules for every selected target relation;</li>
 *   <li>order the rule file and report memory and timing statistics.</li>
 * </ul>
 */
public void learn() {
    prepareLearningInputs();
    populateTargets();
    GlobalTimer.programStartTime = System.currentTimeMillis();
    for (String target : targets) {
        learnSingleTarget(target);
    }
    IO.orderRules(out);
    GlobalTimer.reportMaxMemoryUsed();
    GlobalTimer.reportTime();
}

/** Loads the graph and resolves/creates the data and output files used by {@link #learn()}. */
private void prepareLearningInputs() {
    graphFile = new File(home, "databases/graph.db");
    graph = IO.loadGraph(graphFile);
    trainFile = new File(home, "data/annotated_train.txt");
    validFile = new File(home, "data/annotated_valid.txt");
    ruleFile = IO.createEmptyFile(new File(out, "rules.txt"));
    ruleIndexHome = new File(out, "index");
    ruleIndexHome.mkdir();
}

/**
 * Learns the rules of one target relation.
 *
 * <p>Runs generalization, then the essential-rule generator (if enabled), then specialization.
 * Everything happens inside a single graph transaction, and the target's rule index file is
 * ordered at the end.
 */
private void learnSingleTarget(String target) {
    // Characters that are awkward in file names are replaced by '_'.
    String indexFileName = target.replaceAll("[:/<>]", "_") + ".txt";
    File ruleIndexFile = IO.createEmptyFile(new File(ruleIndexHome, indexFileName));
    Settings.TARGET = target;
    Context context = new Context();
    String banner = MessageFormat.format("\n# ({0}\\{1}) Start Learning Rules for Target: {2}",
            globalTargetCounter++, targets.size(), target);
    Logger.println(banner, 1);
    try (Transaction tx = graph.beginTx()) {
        Set trainPairs = IO.readPair(graph, trainFile, target);
        Settings.TARGET_FUNCTIONAL = IO.isTargetFunctional(trainPairs);
        Set validPairs = IO.readPair(graph, validFile, target);
        Logger.println(MessageFormat.format("# Train Size: {0}", trainPairs.size()), 1);
        generalization(trainPairs, context);
        // Checked after generalization, exactly as in the original flow.
        boolean essentialEnabled = Settings.ESSENTIAL_TIME != -1 && Settings.INS_DEPTH != 0;
        if (essentialEnabled) {
            EssentialRuleGenerator.generateEssentialRules(trainPairs, validPairs, context, graph, ruleIndexFile, ruleFile);
        }
        specialization(context, trainPairs, validPairs, ruleIndexFile);
        IO.orderRuleIndexFile(ruleIndexFile);
        tx.success();
    }
}
graph = IO.loadGraph(graphFile);
/**
 * Opens the embedded Neo4j database stored at {@code graphFile}.
 *
 * <p>Also registers a JVM shutdown hook that shuts the database down, and logs basic statistics:
 * <ul>
 *   <li>number of relationship types, relationships and nodes;</li>
 *   <li>relationships per type ("Instance Density");</li>
 *   <li>relationships per node ("Degree").</li>
 * </ul>
 *
 * @param graphFile directory of the embedded Neo4j store
 * @return the opened database service
 */
public static GraphDatabaseService loadGraph(File graphFile) {
    Logger.println("\n# Load Neo4J Graph from: " + graphFile.getPath(), 1);
    GraphDatabaseService graph = new GraphDatabaseFactory()
            .newEmbeddedDatabase( graphFile );
    // Close the store cleanly when the JVM exits.
    Runtime.getRuntime().addShutdownHook( new Thread( graph::shutdown ));
    DecimalFormat format = new DecimalFormat("####.###");
    try(Transaction tx = graph.beginTx()) {
        long relationshipTypes = graph.getAllRelationshipTypes().stream().count();
        long relationships = graph.getAllRelationships().stream().count();
        long nodes = graph.getAllNodes().stream().count();
        // Guard the ratios: an empty graph would otherwise log infinity/NaN symbols.
        double instanceDensity = relationshipTypes == 0 ? 0d : (double) relationships / relationshipTypes;
        double degree = nodes == 0 ? 0d : (double) relationships / nodes;
        Logger.println(MessageFormat.format("# Relationship Types: {0} | Relationships: {1} " +
                        "| Nodes: {2} | Instance Density: {3} | Degree: {4}",
                relationshipTypes,
                relationships,
                nodes,
                format.format(instanceDensity),
                format.format(degree)), 1);
        tx.success();
    }
    return graph;
}
GraphDatabaseService graph = new GraphDatabaseFactory()
.newEmbeddedDatabase( graphFile );
指定文件夹路径,创建相应的文件
trainFile = new File(home, "data/annotated_train.txt");
validFile = new File(home, "data/annotated_valid.txt");
ruleFile = IO.createEmptyFile(new File(out, "rules.txt"));
ruleIndexHome = new File(out, "index");
ruleIndexHome.mkdir();
/**
 * Determines the target relations to learn rules for and stores them in {@code targets}.
 *
 * <p>Candidates are the targets listed in the test file when {@code testFile} is set. Otherwise
 * they are every relationship type in the graph. The candidates are then narrowed by the
 * configuration in {@code args}:
 * <ul>
 *   <li>A non-empty {@code target_relation} array keeps exactly those relations. The process
 *       exits with status -1 if one of them is not a candidate.</li>
 *   <li>Otherwise (array missing, not an array, or empty), a positive
 *       {@code randomly_selected_relations} keeps that many randomly chosen candidates.</li>
 * </ul>
 */
public void populateTargets() {
    if (testFile != null) {
        // Collect the target relation types that occur in the test file.
        targets = IO.readTargets(testFile);
    } else {
        try (Transaction tx = graph.beginTx()) {
            for (RelationshipType type : graph.getAllRelationshipTypes())
                targets.add(type.name());
            tx.success();
        }
    }

    // Select the targets to learn.
    JSONArray requested = null;
    try {
        requested = args.getJSONArray("target_relation");
    } catch (JSONException ignored) {
        // Missing or not an array: fall through to the random selection below.
    }
    if (requested != null && !requested.isEmpty()) {
        Set<String> selectedTargets = new HashSet<>();
        for (Object o : requested) {
            String target = (String) o;
            if (!targets.contains(target)) {
                System.err.println("# Selected Targets do not exist in the test file.");
                System.exit(-1);
            }
            selectedTargets.add(target);
        }
        targets = selectedTargets;
        return;
    }

    try {
        int randomSelect = args.getInt("randomly_selected_relations");
        // Only positive counts select; a negative count would make subList throw.
        if (randomSelect > 0) {
            List<String> targetList = new ArrayList<>(targets);
            Collections.shuffle(targetList);
            targets = new HashSet<>(targetList.subList(0, Math.min(randomSelect, targetList.size())));
        }
    } catch (JSONException ignored) {
        // Key absent: keep all candidate targets.
    }
}
以第一个target,publishes为例
# (1\12) Start Learning Rules for Target: PUBLISHES
for (String target : targets) {
File ruleIndexFile = IO.createEmptyFile(new File(ruleIndexHome
, target.replaceAll("[:/<>]", "_") + ".txt"));
Settings.TARGET = target;
Context context = new Context();
Logger.println(MessageFormat.format("\n# ({0}\\{1}) Start Learning Rules for Target: {2}",
globalTargetCounter++, targets.size(), target), 1);
try (Transaction tx = graph.beginTx()) {
Set trainPairs = IO.readPair(graph, trainFile, target);
Settings.TARGET_FUNCTIONAL = IO.isTargetFunctional(trainPairs);
Set validPairs = IO.readPair(graph, validFile, target);
Logger.println(MessageFormat.format("# Train Size: {0}", trainPairs.size()), 1);
generalization(trainPairs, context);
if(Settings.ESSENTIAL_TIME != -1 && Settings.INS_DEPTH != 0)
EssentialRuleGenerator.generateEssentialRules(trainPairs, validPairs, context, graph, ruleIndexFile, ruleFile);
specialization(context, trainPairs, validPairs, ruleIndexFile);
IO.orderRuleIndexFile(ruleIndexFile);
tx.success();
}
}
Foo\i3c3\index\PUBLISHES.txt
File ruleIndexFile = IO.createEmptyFile(new File(ruleIndexHome
, target.replaceAll("[:/<>]", "_") + ".txt"));
1841条数据,其中有440条数据与Publishes相关
Set trainPairs = IO.readPair(graph, trainFile, target);
/**
 * Reads the instances of {@code target} from a tab-separated annotation file.
 *
 * <p>Each non-blank line has the form {@code relationId \t headId \t type \t tailId}. Lines whose
 * type equals {@code target} become {@link Pair}s:
 * <ul>
 *   <li>If the relation id is not -1, the pair is built from that relationship in the graph,
 *       including node names taken from the {@code Settings.NEO4J_IDENTIFIER} property.</li>
 *   <li>Otherwise only the head/tail node ids are kept.</li>
 * </ul>
 *
 * @param graph  database the relationship ids refer to
 * @param in     annotation file to read
 * @param target relation type to keep
 * @return the matching pairs (empty if none)
 * @throws IllegalArgumentException if a non-blank line has fewer than four fields
 */
public static Set<Pair> readPair(GraphDatabaseService graph, File in, String target) {
    Set<Pair> pairs = new HashSet<>();
    try(Transaction tx = graph.beginTx()) {
        // Iterate over the file line by line.
        try(LineIterator l = FileUtils.lineIterator(in)) {
            while(l.hasNext()) {
                String line = l.nextLine();
                // Skip blank lines (e.g. a trailing newline) instead of failing in parseLong.
                if (line.trim().isEmpty())
                    continue;
                String[] words = line.split("\t");
                if (words.length < 4)
                    throw new IllegalArgumentException("Malformed line in " + in.getPath()
                            + " (expected 4 tab-separated fields): " + line);
                long relationId = Long.parseLong(words[0]); // e.g. 735
                long headId = Long.parseLong(words[1]);     // e.g. 407
                String type = words[2];                     // e.g. PUBLISHES
                long tailId = Long.parseLong(words[3]);     // e.g. 266
                if(type.equals(target)) {
                    if(relationId != -1) {
                        // e.g. (407)-[PUBLISHES,735]->(266)
                        Relationship rel = graph.getRelationshipById(relationId);
                        pairs.add(new Pair(rel.getStartNodeId(), rel.getEndNodeId(), rel.getId()
                                , rel, rel.getType()
                                , (String) rel.getStartNode().getProperty(Settings.NEO4J_IDENTIFIER)
                                , (String) rel.getEndNode().getProperty(Settings.NEO4J_IDENTIFIER)
                                , rel.getType().name()));
                    } else {
                        pairs.add(new Pair(headId, tailId));
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(-1);
        }
        tx.success();
    }
    return pairs;
}
按行遍历每一条数据
FileUtils使用介绍_码路编的博客-CSDN博客_fileutils
try(LineIterator l = FileUtils.lineIterator(in)) {//它创建了一个迭代器,可以按行遍历给定的文件
while(l.hasNext()) {
以第一条数据为例relationId 是735,headId 是407,type是PUBLISHES,tailid是266
String[] words = l.nextLine().split("\t");
long relationId = Long.parseLong(words[0]);//735
long headId = Long.parseLong(words[1]);//407
String type = words[2];//PUBLISHES
long tailId = Long.parseLong(words[3]);//266
将相关信息保存到Pairs当中
if(type.equals(target)) {
if(relationId != -1) {
Relationship rel = graph.getRelationshipById(relationId);//以第一条数据为例(407)-[PUBLISHES,735]->(266)
pairs.add(new Pair(rel.getStartNodeId(), rel.getEndNodeId(), rel.getId()
, rel, rel.getType()
, (String) rel.getStartNode().getProperty(Settings.NEO4J_IDENTIFIER)
, (String) rel.getEndNode().getProperty(Settings.NEO4J_IDENTIFIER)
, rel.getType().name()));
} else {
pairs.add(new Pair(headId, tailId));
}
}
Settings.NEO4J_IDENTIFIER="name"
Relationship rel = graph.getRelationshipById(relationId);
pairs.add(new Pair(rel.getStartNodeId(), rel.getEndNodeId(), rel.getId()
, rel, rel.getType()
, (String) rel.getStartNode().getProperty(Settings.NEO4J_IDENTIFIER)
, (String) rel.getEndNode().getProperty(Settings.NEO4J_IDENTIFIER)
, rel.getType().name()));
(String) rel.getStartNode().getProperty(Settings.NEO4J_IDENTIFIER)
, (String) rel.getEndNode().getProperty(Settings.NEO4J_IDENTIFIER)
Settings.TARGET_FUNCTIONAL = IO.isTargetFunctional(trainPairs);
/**
 * Decides whether the target relation is (nearly) functional.
 *
 * <p>The relation counts as functional when at least 90% of the distinct subjects in
 * {@code trainPairs} are linked to exactly one object.
 *
 * @param trainPairs training instances of the target relation
 * @return true if the 90% threshold is met; false for an empty training set
 */
public static boolean isTargetFunctional(Set<Pair> trainPairs) {
    Multimap<Long, Long> subToObjs = MultimapBuilder.hashKeys().hashSetValues().build();
    for (Pair trainPair : trainPairs) {
        subToObjs.put(trainPair.subId, trainPair.objId);
    }
    // No subjects at all: nothing can be functional (previously 0/0 = NaN, also false).
    if (subToObjs.isEmpty())
        return false;
    int functionalCount = 0;
    for (Long sub : subToObjs.keySet()) {
        if (subToObjs.get(sub).size() == 1)
            functionalCount++;
    }
    return ((double) functionalCount / subToObjs.keySet().size()) >= 0.9d;
}
subToObjs.put(trainPair.subId, trainPair.objId);
functionalCount++;
return ((double) functionalCount / subToObjs.keySet().size()) >= 0.9d;
generalization(trainPairs, context);
/**
 * Generalization phase: samples paths around training pairs and turns them into rule templates.
 *
 * <p>{@code Settings.THREAD_NUMBER} producer threads repeatedly pick a random training pair,
 * traverse the graph around it, and push the resulting templates into a bounded queue. A single
 * consumer drains that queue into {@code context}. Coverage, path count, saturation, timing and
 * template statistics are logged afterwards.
 *
 * @param trainPairs training instances of the current target; if empty, generalization is skipped
 * @param context    receives the generated templates
 */
public void generalization(Set trainPairs, Context context) {
    if (trainPairs.isEmpty()) {
        // Producers cannot sample from an empty set (Random.nextInt(0) throws), and the consumer
        // would then wait for rules that never arrive -- forever when GEN_TIME is 0.
        Logger.println("# No training instances for the target; generalization skipped.", 1);
        return;
    }
    long s = System.currentTimeMillis();
    BlockingQueue ruleQueue = new LinkedBlockingDeque<>(Settings.BATCH_SIZE * 2);
    // Written concurrently by all producers, so it must be a thread-safe set.
    Set visitedTrainPairs = java.util.concurrent.ConcurrentHashMap.newKeySet();
    RuleProducer[] producers = new RuleProducer[Settings.THREAD_NUMBER];
    // The consumer and the producers start their own threads in their constructors.
    RuleConsumer consumer = new RuleConsumer(0, ruleQueue, context);
    for (int i = 0; i < producers.length; i++) {
        producers[i] = new RuleProducer(i, ruleQueue, trainPairs, visitedTrainPairs, graph, consumer);
    }
    try {
        for (RuleProducer producer : producers) {
            producer.join();
        }
        consumer.join();
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        e.printStackTrace();
        System.exit(-1);
    }
    DecimalFormat percent = new DecimalFormat("###.##");
    Logger.println(MessageFormat.format("# Visited/Training Instances: {0}/{1} | Ratio: {2}%" +
                    " | Sampled Paths: {3} | Saturation: {4}%"
            , visitedTrainPairs.size()
            , trainPairs.size()
            , percent.format(((double) visitedTrainPairs.size() / trainPairs.size()) * 100f)
            , consumer.getPathCount()
            , percent.format(consumer.getSaturation() * 100f))
    );
    GlobalTimer.updateTemplateGenStats(Helpers.timerAndMemory(s, "# Generalization"));
    Logger.println(Context.analyzeRuleComposition("# Generated Templates"
            , context.getAbstractRules()), 1);
}
BlockingQueue
什么是BlockingQueue?一次性说清了_java令人头秃的博客-CSDN博客_blockingqueue
使用无界（unbounded）BlockingQueue 设计生产者-消费者模型时，最重要的是：消费者应该能够像生产者向队列添加消息一样快地消费消息。注意，这里的 ruleQueue 是有界的（容量为 Settings.BATCH_SIZE * 2），队列满时生产者的 offer() 会等待。
BlockingQueue ruleQueue = new LinkedBlockingDeque<>(Settings.BATCH_SIZE * 2);
RuleProducer[] producers = new RuleProducer[Settings.THREAD_NUMBER];
RuleConsumer consumer = new RuleConsumer(0, ruleQueue, context);
/**
 * Creates the consumer thread named {@code RuleConsumer-<id>}, records the generalization start
 * time, and starts the thread right away.
 */
RuleConsumer(int id, BlockingQueue ruleQueue, Context context) {
    super("RuleConsumer-" + id);
    this.context = context;
    this.ruleQueue = ruleQueue;
    this.id = id;
    // Recorded before start(): run() consults GlobalTimer.stopGen(), which measures from it.
    GlobalTimer.setGenStartTime(System.currentTimeMillis());
    start();
}
/**
 * Drains rule templates from {@code ruleQueue} into {@code context} until generalization is done.
 *
 * <p>A rule is added to the current batch only if its length is within the limit:
 * {@code Settings.CAR_DEPTH} for closed rules, {@code Settings.INS_DEPTH} otherwise.
 *
 * <p>Every {@code Settings.BATCH_SIZE} sampled paths, the batch is merged into the context:
 * <ul>
 *   <li>rules already known to the context count as overlaps;</li>
 *   <li>new rules are registered;</li>
 *   <li>{@code saturation} becomes overlaps / batch size.</li>
 * </ul>
 *
 * <p>The loop stops when any of the following holds:
 * <ul>
 *   <li>saturation reaches {@code Settings.SATURATION};</li>
 *   <li>the generation time budget is exhausted;</li>
 *   <li>a whole batch produced no usable rule.</li>
 * </ul>
 */
@Override
public void run() {
    Set<Rule> currentBatch = new HashSet<>();
    // pathCount value at which the last checkpoint ran. Without it, the checkpoint would be
    // re-run on every empty poll while pathCount sits on a multiple of BATCH_SIZE. The
    // just-cleared batch would then look empty and stop generalization prematurely.
    long lastCheckpoint = -1;
    do {
        if (pathCount % Settings.BATCH_SIZE == 0 && pathCount != lastCheckpoint) {
            lastCheckpoint = pathCount;
            // A full batch of sampled paths without a single rule within the length limits.
            if (currentBatch.isEmpty() && pathCount != 0)
                break;
            int overlaps = 0;
            for (Rule rule : currentBatch) {
                if (context.getAbstractRules().contains(rule))
                    overlaps++;
                else
                    context.updateFreqAndIndex(rule);
            }
            saturation = currentBatch.isEmpty() ? 0d : (double) overlaps / currentBatch.size();
            currentBatch.clear();
        }
        Rule rule;
        try {
            // Wait briefly instead of busy-spinning when producers are slower than the consumer.
            rule = ruleQueue.poll(10, java.util.concurrent.TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            break;
        }
        if (rule != null) {
            if (rule.isClosed() ? rule.length() <= Settings.CAR_DEPTH : rule.length() <= Settings.INS_DEPTH)
                currentBatch.add(rule);
            pathCount++;
        }
    } while (saturation < Settings.SATURATION && !GlobalTimer.stopGen());
    // Register whatever was collected since the last checkpoint.
    for (Rule rule : currentBatch) {
        context.updateFreqAndIndex(rule);
    }
}
值得注意的是，getAbstractRules() 获取的是变量 ruleFrequency 中的规则，ruleFrequency 同时包含 abstract（closed）规则和 instantiated（非 closed）规则。
context.getAbstractRules().contains(rule)
同一条规则可能被重复生成，此时会累加它的频次。这里 Settings.BATCH_SIZE = 30000。
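也就是说，只有当 pathCount 是 BATCH_SIZE（30000）的整数倍时，才会对 currentBatch 做一次批处理：逐条检查其中的规则是否已在 ruleFrequency 中，已存在则 overlaps++，否则调用 updateFreqAndIndex 把规则加入 ruleFrequency，最后清空 currentBatch。两次批处理之间，poll() 每次只从队列中取出一条规则，满足长度条件的规则会被不断累积到 currentBatch 中（currentBatch 是 Set，重复的规则只保留一条）。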
if(pathCount % Settings.BATCH_SIZE == 0) {
如果 ruleFrequency 中已经包含该 rule，则 overlaps++
if(context.getAbstractRules().contains(rule))
overlaps++;
值得注意的是，ruleFrequency 没有大小限制，只有 ruleQueue 有大小限制
context.updateFreqAndIndex(rule);
/**
 * Records one more occurrence of {@code rule}.
 *
 * <p>A rule seen for the first time gets frequency 1 and is also stored in {@code indexRule}
 * under the next index. The method is synchronized, so the check-then-put sequence cannot
 * interleave with other synchronized methods of this object.
 */
public synchronized void updateFreqAndIndex(Rule rule) {
    if (!ruleFrequency.containsKey(rule)) {
        ruleFrequency.put(rule, 1);
        indexRule.put(index++, rule);
        return;
    }
    ruleFrequency.put(rule, ruleFrequency.get(rule) + 1);
}
currentBatch是当前得到的规则集合
如果当前集合没有规则,则saturation =0,否则的话saturation =overlaps / currentBatch.size()
saturation = currentBatch.isEmpty() ? 0d : (double) overlaps / currentBatch.size();
从ruleQueue中获取规则,相当于是消费者
Rule rule = ruleQueue.poll();
if(rule != null) {
if(rule.isClosed() ? rule.length() <= Settings.CAR_DEPTH : rule.length() <= Settings.INS_DEPTH)
currentBatch.add(rule);
pathCount++;
}
如果规则是closed则用阈值Settings.CAR_DEPTH进行判断,非closed则用INS_DEPTH进行判断
值得注意的是，closed 规则在文章里也称作 closed abstract rules，非 closed 的规则在文章里称作 instantiated rules
if(rule.isClosed() ? rule.length() <= Settings.CAR_DEPTH : rule.length() <= Settings.INS_DEPTH)
只有满足阈值才添加到 currentBatch
currentBatch.add(rule);
while (saturation < Settings.SATURATION && !GlobalTimer.stopGen());
判断饱和度（saturation）是否达到阈值
saturation < Settings.SATURATION
GlobalTimer 类的 stopGen() 判断生成时间是否达到阈值
GlobalTimer.stopGen()
/**
 * Reports whether the generation time budget {@code Settings.GEN_TIME} (in seconds) has been
 * exceeded since {@code genStartTime}.
 *
 * <p>A budget of 0 means "no limit", in which case this always returns false.
 */
public static boolean stopGen() {
    return Settings.GEN_TIME != 0
            && (System.currentTimeMillis() - genStartTime) / 1000d > Settings.GEN_TIME;
}
for (int i = 0; i < producers.length; i++) {
producers[i] = new RuleProducer(i, ruleQueue, trainPairs, visitedTrainPairs, graph, consumer);
}
/**
 * Creates and immediately starts a producer thread named {@code RuleProducer-<id>}.
 *
 * <p>The training pairs are copied into a list so that run() can pick one by random index.
 * {@code visitedTrainPairs} is stored by reference, so it is shared with the caller and the
 * other producers.
 */
RuleProducer(int id, BlockingQueue ruleQueue, Set trainPairs, Set visitedTrainPairs
        , GraphDatabaseService graph, Thread consumer) {
    super("RuleProducer-" + id);
    this.id = id;
    this.graph = graph;
    this.consumer = consumer;
    this.ruleQueue = ruleQueue;
    this.visitedTrainPairs = visitedTrainPairs;
    this.trainPairs = new ArrayList<>(trainPairs);
    start();
}
/**
 * Samples paths and offers the resulting rule templates to the consumer while it is alive.
 *
 * <p>Each round:
 * <ol>
 *   <li>pick a random training pair and mark it as visited;</li>
 *   <li>traverse the graph around it;</li>
 *   <li>turn each path into a template and hand it to the bounded queue.</li>
 * </ol>
 * The offer is retried every 100 ms, so a dead consumer is noticed promptly.
 */
@Override
public void run() {
    // Nothing to sample from: Random.nextInt(0) would throw IllegalArgumentException.
    if (trainPairs.isEmpty())
        return;
    Random rand = new Random();
    try(Transaction tx = graph.beginTx()) {
        while(consumer.isAlive()) {
            Pair pair = trainPairs.get(rand.nextInt(trainPairs.size()));
            addVisitedPair(pair);
            Traverser traverser = GraphOps.buildStandardTraverser(graph, pair, Settings.RANDOM_WALKERS);
            for (Path path : traverser) {
                Rule rule = Context.createTemplate(path, pair);
                while (consumer.isAlive()) {
                    if (ruleQueue.offer(rule, 100, TimeUnit.MILLISECONDS))
                        break;
                }
                if (!consumer.isAlive())
                    break;
            }
        }
        tx.success();
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        e.printStackTrace();
        System.exit(-1);
    }
}
判断线程是否存活isAlive()方法_墨白hu的博客-CSDN博客_isalive
consumer.isAlive()
Pair pair = trainPairs.get(rand.nextInt(trainPairs.size()));
addVisitedPair(pair);
/**
 * Marks {@code pair} as sampled.
 *
 * <p>The same set instance is passed to every producer, so the lock must be taken on the set
 * itself. A {@code synchronized} method would only lock this producer, which gives no mutual
 * exclusion between producers writing to the shared set.
 */
private void addVisitedPair(Pair pair) {
    synchronized (visitedTrainPairs) {
        visitedTrainPairs.add(pair);
    }
}
Traverser traverser = GraphOps.buildStandardTraverser(graph, pair, Settings.RANDOM_WALKERS);
/**
 * Builds the traverser used to sample paths around one training pair.
 *
 * <p>Configuration:
 * <ul>
 *   <li>uniqueness: NODE_PATH;</li>
 *   <li>order: pre-order breadth-first;</li>
 *   <li>expansion: the random-walk expander with {@code randomWalkers};</li>
 *   <li>evaluation: {@code toDepthNoTrivial} bounded by {@code Settings.DEPTH}.</li>
 * </ul>
 * Both the subject and the object node of {@code pair} are passed as start nodes.
 */
public static Traverser buildStandardTraverser(GraphDatabaseService graph, Pair pair, int randomWalkers){
    Node head = graph.getNodeById(pair.subId);
    Node tail = graph.getNodeById(pair.objId);
    return graph.traversalDescription()
            .uniqueness(Uniqueness.NODE_PATH)
            .order(BranchingPolicy.PreorderBFS())
            .expand(standardRandomWalker(randomWalkers))
            .evaluator(toDepthNoTrivial(Settings.DEPTH, pair))
            .traverse(head, tail);
}
假设当前获取的数据为[413,514]
Node startNode = graph.getNodeById(pair.subId);
Node endNode = graph.getNodeById(pair.objId);
traverser = graph.traversalDescription()
.uniqueness(Uniqueness.NODE_PATH)
.order(BranchingPolicy.PreorderBFS())
.expand(standardRandomWalker(randomWalkers))
.evaluator(toDepthNoTrivial(Settings.DEPTH, pair))
.traverse(startNode, endNode);
.traversalDescription()
Neo4j之图数据遍历总结 - 墨天轮 (modb.pro)
Neo4j:入门基础(八)之Traversal API_Dawn_www的博客-CSDN博客
Neo4j3.5学习笔记——Traversal遍历之黑客帝国_蹦蹦恰Amy的博客-CSDN博客
Neo4j3.5学习笔记——Traversal遍历之在遍历查询中的唯一路径_蹦蹦恰Amy的博客-CSDN博客
.uniqueness(Uniqueness.NODE_PATH)
用java构建neo4j数据库 - it610.com
广度优先遍历也分先序(PreOrder)和后序(PostOrder)（Neo4j 只提供这两种，没有中序）
在循环的时候一个traverse包含了一次广搜的内容
.order(BranchingPolicy.PreorderBFS())
Neo4j 遍历框架_承接各种编程私活的博客-CSDN博客
一个BranchSelector是用来定义如何选择遍历下一个分支。这被用来实现遍历顺序。遍历框架提供了一些基本的顺序实现:
- Traversal.preorderDepthFirst()：深度优先，在访问子节点之前访问每一个节点。
- Traversal.postorderDepthFirst()：深度优先，在访问子节点之后访问每一个节点。
- Traversal.preorderBreadthFirst()：广度优先，在访问子节点之前访问每一个节点。
- Traversal.postorderBreadthFirst()：广度优先，在访问子节点之后访问每一个节点。

.expand(standardRandomWalker(randomWalkers))
java - neo4j中的随机后置遍历 | Neo4j (lmlphp.com)
Neo4j中的定向路径(Directed paths in Neo4j)_电脑培训 (656463.com)
standardRandomWalker(randomWalkers)
/**
 * Neo4j traversal-framework interface, quoted here for reference. It decides which
 * relationships a traversal follows from a given path.
 *
 * NOTE(review): the generic type parameters (e.g. the element type of the Iterable returned by
 * expand) appear to have been lost in this copy; confirm against the Neo4j source.
 */
public interface PathExpander
{
    /**
     * Returns relationships for a {@link Path}, most commonly from the
     * {@link Path#endNode()}.
     *
     * @param path the path to expand (most commonly the end node).
     * @param state the state of this branch in the current traversal.
     * {@link BranchState#getState()} returns the state and
     * {@link BranchState#setState(Object)} optionally sets the state for
     * the children of this branch. If state isn't altered the children
     * of this path will see the state of the parent.
     * @return the relationships to return for the {@code path}.
     */
    Iterable expand( Path path, BranchState state );
    /**
     * Returns a new instance with the exact expansion logic, but reversed.
     *
     * @return a reversed {@link PathExpander}.
     */
    PathExpander reverse();
}
在遍历traverser的时候会先执行next()方法,再执行Iterable中的expand()方法
执行1次for循环,生成一条规则。执行完当前traverser里的遍历,多个for循环,就要重新获取startnode与endnode,生成新的traverser
for (Path path : traverser) {
/**
 * Returns the next branch in pre-order breadth-first order, or null once the traversal is
 * exhausted.
 *
 * <p>NOTE(review): this appears to be quoted from Neo4j's traversal framework (a pre-order BFS
 * branch selector); confirm against the Neo4j source.
 *
 * <p>How it works:
 * <ul>
 *   <li>Children of {@code current} are produced one at a time by
 *       {@code current.next(expander, metadata)}.</li>
 *   <li>Each child is queued, so it is expanded later, and is also returned immediately
 *       (it is visited before its own children).</li>
 *   <li>When {@code current} has no more children, the oldest queued branch becomes the new
 *       {@code current}.</li>
 * </ul>
 */
@Override
public TraversalBranch next( TraversalContext metadata )
{
    TraversalBranch result = null;
    while ( result == null )
    {
        // Ask the current branch for its next child (expanded via the configured expander).
        TraversalBranch next = current.next( expander, metadata );
        if ( next != null )
        {
            // Remember the child so its own children are visited later (breadth-first),
            // and hand it out right now (pre-order).
            queue.add( next );
            result = next;
        }
        else
        {
            // Current branch exhausted: continue with the oldest queued branch.
            current = queue.poll();
            if ( current == null )
            {
                // Queue empty: the traversal is finished.
                return null;
            }
        }
    }
    return result;
}
/**
 * Creates the expander used for random-walk path sampling.
 *
 * <p>For each path it considers every relationship of the path's end node, as returned by
 * {@code getRelationships()} with no filter. All of them are followed when there are fewer than
 * {@code randomWalkers}, or when {@code randomWalkers} is 0. Otherwise exactly
 * {@code randomWalkers} distinct relationships are sampled at random.
 *
 * @param randomWalkers maximum number of relationships followed per expansion; 0 = follow all
 * @return an expander that ignores the branch state
 */
public static PathExpander standardRandomWalker(int randomWalkers) {
    return new PathExpander() {
        @Override
        public Iterable expand(Path path, BranchState state) {
            List candidates = Lists.newArrayList( path.endNode().getRelationships() );
            if ( candidates.size() < randomWalkers || randomWalkers == 0 ) return candidates;
            Set results = Sets.newHashSet();
            // Reuse the per-thread generator instead of allocating a new Random per expansion.
            java.util.Random rand = java.util.concurrent.ThreadLocalRandom.current();
            for ( int i = 0; i < randomWalkers; i++ ) {
                // Sample without replacement: each pick is removed from the candidate list.
                int choice = rand.nextInt( candidates.size() );
                results.add( candidates.remove( choice ) );
            }
            return results;
        }

        @Override
        public PathExpander reverse() {
            // Expansion uses all relationships of the end node regardless of direction and
            // ignores the branch state, so the reversed expander is this same expander.
            // Returning null made any caller of reverse() fail with a NullPointerException.
            return this;
        }
    };
}
在这里是找到包含path节点的数据,可以是头也可以是尾
第二次,找到包含514的数据
第三次,找到包含249的数据
List candidates = Lists.newArrayList( path.endNode().getRelationships() );
当 candidates 的数量小于 randomWalkers，或者 randomWalkers 为 0 时，直接返回全部 candidates
if ( candidates.size() < randomWalkers || randomWalkers == 0 ) return candidates;
Random rand = new Random();
for ( int i = 0; i < randomWalkers; i++ ) {
int choice = rand.nextInt( candidates.size() );
results.add( candidates.get( choice ) );
candidates.remove( choice );
}
return results;
关于neo4j:基于关系属性的最长路径 | 码农家园 (codenong.com)
.evaluator(toDepthNoTrivial(Settings.DEPTH, pair))
/**
 * Builds the evaluator that decides which sampled paths become rule bodies.
 *
 * <p>The checks run in this order for every path:
 * <ol>
 *   <li>A length-1 path whose single relationship has the target type and runs from the pair's
 *       object to its subject (the inverse of the target fact) is included, then pruned.</li>
 *   <li>The zero-length path (just a start node) is excluded, but traversal continues.</li>
 *   <li>A closed length-1 path of the target type is excluded and pruned. This is, for example,
 *       the training fact itself.</li>
 *   <li>Any other closed path is included and pruned if it starts at the pair's subject;
 *       otherwise it is excluded and pruned.</li>
 *   <li>A path that returns to its own start node is excluded and pruned.</li>
 *   <li>Any remaining path is included while its length is at most {@code depth}, and
 *       traversal continues only while its length is below {@code depth}. A path longer than
 *       {@code depth} therefore always yields EXCLUDE_AND_PRUNE.</li>
 * </ol>
 *
 * @param depth maximum path length that may be included
 * @param pair  the training pair the traversal was started from
 */
public static PathEvaluator toDepthNoTrivial(final int depth, Pair pair) {
    return new PathEvaluator.Adapter()
    {
        @Override
        public Evaluation evaluate(Path path, BranchState state)
        {
            // Did the path start at the pair's subject (rather than its object)?
            boolean fromSource = pair.subId == path.startNode().getId();
            // Does the path end at the pair's other endpoint?
            boolean closed = pathIsClosed( path, pair );
            // Type check of the most recent relationship only; it is used below solely in the
            // pathLength == 1 branches, where it is also the first relationship.
            boolean hasTargetRelation = false;
            int pathLength = path.length();
            if ( path.lastRelationship() != null ) {
                Relationship relation = path.lastRelationship();
                hasTargetRelation = relation.getType().equals(pair.type);
                // Single relationship object -> subject with the target type: inverse of the fact.
                if ( pathLength == 1
                        && relation.getStartNodeId() == pair.objId
                        && relation.getEndNodeId() == pair.subId
                        && hasTargetRelation)
                    return Evaluation.INCLUDE_AND_PRUNE;
            }
            // The bare start node is never a rule body, but its neighbours must be explored.
            if ( pathLength == 0 )
                return Evaluation.EXCLUDE_AND_CONTINUE;
            // Closed single edge of the target type: trivial (e.g. the training fact itself).
            if ( pathLength == 1 && hasTargetRelation && closed )
                return Evaluation.EXCLUDE_AND_PRUNE;
            // Closed paths are only kept when they were walked from the subject side.
            if ( closed && fromSource )
                return Evaluation.INCLUDE_AND_PRUNE;
            else if ( closed )
                return Evaluation.EXCLUDE_AND_PRUNE;
            // Paths that come back to their own start node are discarded.
            if (selfloop(path))
                return Evaluation.EXCLUDE_AND_PRUNE;
            // Open path: include up to depth, keep expanding while shorter than depth.
            return Evaluation.of( pathLength <= depth, pathLength < depth );
        }
    };
}
其中值得注意的是，hasTargetRelation 判断的是 path 的最后一条关系（lastRelationship）的类型是否与 pair 的 target 相同；由于它只在 pathLength == 1 的分支中被使用，实际上只对第一个 atom（firstAtom）起作用
1. 如果是pathLength =1,并且pairs的尾与path的头相同,以及path的尾与pairs的头相同,并且target相同
if ( pathLength == 1
&& relation.getStartNodeId() == pair.objId
&& relation.getEndNodeId() == pair.subId
&& hasTargetRelation)
return Evaluation.INCLUDE_AND_PRUNE;
2.执行startnode和endnode单一节点的时候
if ( pathLength == 0 )
return Evaluation.EXCLUDE_AND_CONTINUE;
3.当pathlength=1,当前的path与pairs是头头,尾尾相同或者是path的头与pairs的尾相同,不纳入规则,并且不能继续遍历这些节点,要跳出不然就是死循环
if ( pathLength == 1 && hasTargetRelation && closed )
return Evaluation.EXCLUDE_AND_PRUNE;
4.中间有元素,也就是path不止为1,头头尾尾相同,纳入规则,如果头尾倒序相同不能纳入为规则
if ( closed && fromSource )
return Evaluation.INCLUDE_AND_PRUNE;
else if ( closed )
return Evaluation.EXCLUDE_AND_PRUNE;
5. 如果 path 的起点和终点是同一个节点且长度不为 0（即 path 绕回了起点，形成回路），则排除该 path，并且不再沿它继续遍历
if (selfloop(path))
return Evaluation.EXCLUDE_AND_PRUNE;
/** Returns true for a non-empty path that ends on the same node it started from. */
private static boolean selfloop(Path path) {
    if (path.length() == 0) {
        return false;
    }
    return path.startNode().equals(path.endNode());
}
6.如果两个长度都满足 pathLength <= depth, pathLength < depth,执行INCLUDE_AND_CONTINUE
如果只满足pathLength <= depth,执行INCLUDE_AND_PRUNE,包含这条规则但不再遍历下去了
如果不满足pathLength <= depth,超过了范围不能包含这条规则,并且不对path的节点继续遍历
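三目运算符的后一个分支（includes 为 false，即 pathLength > depth）中，pathLength < depth 不可能成立，因此实际上只会得到 EXCLUDE_AND_PRUNE，EXCLUDE_AND_CONTINUE 在这里不会出现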
return Evaluation.of( pathLength <= depth, pathLength < depth );
/**
 * Maps the pair of flags to the matching Evaluation constant.
 *
 * @param includes  whether the current path belongs in the result
 * @param continues whether traversal should go on beyond the current path
 */
public static Evaluation of( boolean includes, boolean continues )
{
    if ( includes )
    {
        return continues ? INCLUDE_AND_CONTINUE : INCLUDE_AND_PRUNE;
    }
    return continues ? EXCLUDE_AND_CONTINUE : EXCLUDE_AND_PRUNE;
}
判断是否是closed
fromSource path的前与pair的前比较,相同则为fromSource
fromSource为true表示,前和前相同
在fromSource的基础上
如果fromSource正确,则比较path的后与pair的后是否相同,如果相同则为closed
如果fromSource错误,则比较path的后与pair的前是否相同,如果相同则为closed
closed 为 true 表示 path 的两个端点与 pair 的两个端点都对应相同：可以是正序相同（path 从 sub 出发、在 obj 结束），也可以是倒序相同（path 从 obj 出发、在 sub 结束）。traversal 的起点只会是 pair 的 sub 或 obj
/**
 * Returns true when the path ends at the "other" endpoint of {@code pair}.
 *
 * <p>That endpoint is the object if the path starts at the subject, and the subject otherwise.
 */
private static boolean pathIsClosed(Path path, Pair pair) {
    long expectedEnd = path.startNode().getId() == pair.subId ? pair.objId : pair.subId;
    return path.endNode().getId() == expectedEnd;
}
Evaluator 用来决定在每一个位置（用一个 Path 表示）上：是否应该继续遍历，以及该节点是否包含在结果中。对于一个给定的 Path，Evaluator 要对该遍历分支采取下面四个动作中的一种：
- Evaluation.INCLUDE_AND_CONTINUE：把这个节点包含在结果中，并且继续遍历。
- Evaluation.INCLUDE_AND_PRUNE：把这个节点包含在结果中，但不再沿该分支继续遍历。
- Evaluation.EXCLUDE_AND_CONTINUE：不把这个节点包含在结果中，但继续遍历。
- Evaluation.EXCLUDE_AND_PRUNE：不把这个节点包含在结果中，也不再沿该分支继续遍历。

.traverse(startNode, endNode);
第一条path是数据 (547)-[PUBLISHES,986]->(871)
第二条path是数据 (547)-[PUBLISHES,1649]->(1003)
for (Path path : traverser) {
Rule rule = Context.createTemplate(path, pair);
长度为2
第一次
第二次
长度为3
注意pair是一直保持不变的,path是会因traverse而不断的更新
/**
 * Turns a sampled path into a rule template.
 *
 * <p>The template's head is built from the training pair, and its body atoms follow the path's
 * relationships in order.
 */
public static Rule createTemplate(Path path, Pair pair) {
    List body = buildBodyAtoms(path);
    return new Template(new Atom(pair), body);
}
对(547)-[PUBLISHES,1649]->(1003)<-[PUBLISHES,1323]-(98),从左到右生成bodyatoms,第一个节点是547,第二个节点是1003
得到下面的结果
注意看都是正序存储的,1003是endnode,都在后面,publishes(xxx,title91)
具体方向是在里面有表示sub到obj是从左到右,outgoing箭头是正向,incoming箭头是反向
/**
 * Creates one body atom per relationship of {@code path}, in path order.
 *
 * <p>Each atom pairs a relationship with the node the walk reached it from, i.e. the node at the
 * same position in the path.
 */
private static List buildBodyAtoms(Path path) {
    List nodes = Lists.newArrayList( path.nodes() );
    List relationships = Lists.newArrayList( path.relationships() );
    List atoms = Lists.newArrayList();
    int position = 0;
    while ( position < relationships.size() ) {
        atoms.add( new Atom( nodes.get( position ), relationships.get( position ) ) );
        position++;
    }
    return atoms;
}
List bodyAtoms = Lists.newArrayList();
List relationships = Lists.newArrayList( path.relationships() );
List nodes = Lists.newArrayList( path.nodes() );
(547)-[PUBLISHES,1779]->(924)
Atom head = new Atom(pair);
pair 的方向是 outgoing 的，直接从 annotated_train.txt 中读取
/**
 * Builds a head atom from a training instance.
 *
 * <p>Copies the pair's relationship type, subject/object names and node ids. The direction is
 * always OUTGOING (subject to object).
 */
public Atom(Pair pair) {
    type = pair.type;
    predicate = pair.type.name();
    subjectId = pair.subId;
    subject = pair.subName;
    objectId = pair.objId;
    object = pair.objName;
    direction = Direction.OUTGOING;
}
值得注意的是path的方向不一定是outgoing,要判断
以(924)<-[PUBLISHES,1114]-(98)为例
nodes.get(0)获得924
relationship.get(0) (98)-[PUBLISHES,1114]->(924)
bodyatom就表示为 PUB(98,924) 从左到右,sub是924,obj是98,只不过是incoming
bodyAtoms.add( new Atom( nodes.get( i ), relationships.get( i ) ) );
判断此时的endnode与source值即左边的节点是否相同,如果相同则是inverse,inverse就执行Incoming
public Atom(Node source, Relationship relationship) {
// System.out.println(relationship.getEndNode());
boolean inverse = source.equals(relationship.getEndNode());
type = relationship.getType();
predicate = relationship.getType().name();
if ( inverse ) {
direction = Direction.INCOMING;
subject = (String) relationship.getEndNode().getProperty(Settings.NEO4J_IDENTIFIER);
object = (String) relationship.getStartNode().getProperty(Settings.NEO4J_IDENTIFIER);
subjectId = relationship.getEndNodeId();
objectId = relationship.getStartNodeId();
}
else {
direction = Direction.OUTGOING;
subject = (String) relationship.getStartNode().getProperty(Settings.NEO4J_IDENTIFIER);
object = (String) relationship.getEndNode().getProperty(Settings.NEO4J_IDENTIFIER);
subjectId = relationship.getStartNodeId();
objectId = relationship.getEndNodeId();
}
}
return new Template(head, bodyAtoms);
/**
 * Turns a concrete rule into a template by replacing entity names with variables.
 *
 * <ul>
 *   <li>The head becomes (X, Y).</li>
 *   <li>The body atoms are renamed into a chain V0 to V1 to ... to Vn.</li>
 *   <li>The first atom's subject is bound to X when the body starts at the head's subject
 *       ({@code fromSubject}), and to Y otherwise.</li>
 *   <li>For closed rules, the last atom's object is bound to the other head variable
 *       (Y or X).</li>
 * </ul>
 * The head atom and the body atom objects are modified in place.
 */
public Template(Atom h, List b) {
    super( h, b );
    head.subject = "X";
    head.object = "Y";
    for (int position = 0; position < bodyAtoms.size(); position++) {
        Atom atom = bodyAtoms.get(position);
        atom.subject = "V" + position;
        atom.object = "V" + (position + 1);
    }
    Atom firstAtom = bodyAtoms.get(0);
    Atom lastAtom = bodyAtoms.get(bodyAtoms.size() - 1);
    firstAtom.subject = fromSubject ? "X" : "Y";
    if (closed) {
        lastAtom.object = fromSubject ? "Y" : "X";
    }
}
super( h, b );
值得注意的是，bodyatom 就是从 path 中提取出的每一个关系
lastAtom 是bodyatom中的最后一个
/**
 * Creates a rule from a head atom and its body atoms; the body list is copied.
 *
 * <p>Two flags are derived from the concrete node ids:
 * <ul>
 *   <li>{@code closed}: the last body atom's object is the head's subject or the head's
 *       object.</li>
 *   <li>{@code fromSubject}: the first body atom's subject is the head's subject.</li>
 * </ul>
 */
Rule(Atom head, List bodyAtoms) {
    this.head = head;
    this.bodyAtoms = new ArrayList<>(bodyAtoms);
    Atom lastAtom = this.bodyAtoms.get(this.bodyAtoms.size() - 1);
    closed = head.getSubjectId() == lastAtom.getObjectId()
            || head.getObjectId() == lastAtom.getObjectId();
    Atom firstAtom = this.bodyAtoms.get(0);
    fromSubject = head.getSubjectId() == firstAtom.getSubjectId();
}
值得注意的是，在一次 traverse 当中 head 始终是固定的；直到遍历完这个 pair 的规则、重新获取新的 pair 时，head 才会改变
相当于一个head会对应一个traverse的遍历,对应多次for循环生成规则,生成完以后,才重新获取一个head
this.head = head;
this.bodyAtoms = new ArrayList<>(bodyAtoms);
通过head的前与lastatom的后相同,或者head的后与lastatom的后相同,则是closed的
但是这种规则,都被exclude了。(head的后与lastatom的后相同)
closed = head.getSubjectId() == lastAtom.getObjectId() || head.getObjectId() == lastAtom.getObjectId();
firstatom是bodyatom中的第一个
通过firstatom的前与head的前是否相同
Atom firstAtom = bodyAtoms.get( 0 );
fromSubject = head.getSubjectId() == firstAtom.getSubjectId();
Atom firstAtom = bodyAtoms.get( 0 );
Atom lastAtom = bodyAtoms.get( bodyAtoms.size() - 1 );
head.subject = "X";
head.object = "Y";
下面的例子都是正向的(outgoing),(数字一,数字二),数字一大于数字二
如果(数字一,数字二),数字一小于数字二,则是反向的(incoming),也就是数字二在前,箭头是<-, 数字二在后
int variableCount = 0;
for(Atom atom : bodyAtoms) {
atom.subject = "V" + variableCount;
atom.object = "V" + ++variableCount;
如果 fromSubject 为 true，则把 firstAtom 的 subject 改为 X，否则改为 Y
if ( fromSubject ) firstAtom.subject = "X";
else firstAtom.subject = "Y";
如果 closed 且 fromSubject 为 true，则把 lastAtom 的 object 改为 Y；如果 closed 但 fromSubject 为 false，则把 lastAtom 的 object 改为 X；非 closed 时保持不变
if ( closed && fromSubject ) lastAtom.object = "Y";
else if ( closed ) lastAtom.object = "X";
(547)-[PUBLISHES,1649]->(1003)<-[PUBLISHES,1323]-(98),得到的规则如下所示
Rule rule = Context.createTemplate(path, pair);
把规则添加到队列当中
两者都是往队列尾部插入元素，不同之处在于：当队列已满时，add() 方法会抛出异常让你处理，而 offer() 方法直接返回 false（带超时参数的 offer 会先等待指定的时间再返回 false）
阻塞队列(BlockingQueue)是一个支持两个附加操作的队列。这两个附加的操作是:在队列为空时,获取元素的线程会等待队列变为非空。当队列满时,存储元素的线程会等待队列可用。阻塞队列常用于生产者和消费者的场景,生产者是往队列里添加元素的线程,消费者是从队列里拿元素的线程。阻塞队列就是生产者存放元素的容器,而消费者也只从容器里拿元素。
阻塞队列(BlockingQueue)_忘川丿的博客-CSDN博客_阻塞队列
阻塞队列 - 会飞的金鱼 - 博客园 (cnblogs.com)
队列的add()方法和offer()方法的区别_machihaoyu的博客-CSDN博客_offer方法
if (ruleQueue.offer(rule, 100, TimeUnit.MILLISECONDS))
break;
}
/**
 * Inserts {@code e} at the tail of the queue, waiting up to {@code timeout} for free space if
 * the queue is full.
 *
 * <p>Quoted from the JDK for reference. NOTE(review): {@code items.length} suggests this is
 * ArrayBlockingQueue's implementation, while generalization() creates a LinkedBlockingDeque.
 * The timed-offer contract is the same but the implementation differs; confirm against the JDK.
 *
 * @return true if the element was added, false if the timeout elapsed while the queue was full
 * @throws InterruptedException if interrupted while acquiring the lock or while waiting
 */
public boolean offer(E e, long timeout, TimeUnit unit)
        throws InterruptedException {
    checkNotNull(e);
    // Convert the timeout to nanoseconds.
    long nanos = unit.toNanos(timeout);
    final ReentrantLock lock = this.lock;
    // Acquire the queue's lock; waiting for the lock can itself be interrupted.
    lock.lockInterruptibly();
    try {
        // Loop because awaitNanos can return while the queue is still full (e.g. another
        // producer took the freed slot, or a spurious wakeup). Keep waiting for the remaining
        // time, and give up once it is used up.
        while (count == items.length) {
            if (nanos <= 0)
                return false;
            // Wait at most nanos ns for a notFull signal; returns the remaining wait time
            // (interruptible).
            nanos = notFull.awaitNanos(nanos);
        }
        enqueue(e);
        return true;
    } finally {
        lock.unlock();
    }
}
只用一条数据学习的规则
generalization(trainPairs, context);
pairs.subId=547&&pairs.objId=924
# (1\12) Start Learning Rules for Target: PUBLISHES
# Train Size: 440
# Visited/Training Instances: 1/440 | Ratio: 0.23% | Sampled Paths: 60,001 | Saturation: 100%
# Generalization: time = 6.412s | memory = 90.109mb
# Generated Templates: 33 | ClosedRules: 1 | OpenRules: 32 | len=1: 3 | len=2: 9 | len=3: 20
Logger.println(MessageFormat.format("# Visited/Training Instances: {0}/{1} | Ratio: {2}%" +
" | Sampled Paths: {3} | Saturation: {4}%"
, visitedTrainPairs.size()
, trainPairs.size()
, new DecimalFormat("###.##").format(((double) visitedTrainPairs.size() / trainPairs.size()) * 100f)
, consumer.getPathCount()
, new DecimalFormat("###.##").format(consumer.getSaturation() * 100f))
);
GlobalTimer.updateTemplateGenStats(Helpers.timerAndMemory(s, "# Generalization"));
Logger.println(Context.analyzeRuleComposition("# Generated Templates"
, context.getAbstractRules()), 1);
}
Visited/Training Instances:表示440条publishes相关的数据一共遍历了多少条,也就是pairs的数量
通过下面这行代码随机获取进行遍历
Pair pair = trainPairs.get(rand.nextInt(trainPairs.size()));
Sampled Paths:paths的数量,生成过多少条规则,生成的规则可以是重复的,统计数量时候要加上
Rule rule = ruleQueue.poll();
if(rule != null) {
if(rule.isClosed() ? rule.length() <= Settings.CAR_DEPTH : rule.length() <= Settings.INS_DEPTH)
currentBatch.add(rule);
pathCount++;
}
Saturation:当前currentBatch中有多少条与ruleFrequency最终集合是重合的,占比多少
currentBatch是当前遍历回合当中满足长度条件的规则数量
saturation = currentBatch.isEmpty() ? 0d : (double) overlaps / currentBatch.size();
Generated Templates:得到不重复,且满足条件的规则总数
接下来的数据统计使用的是这一段代码
/**
 * Summarizes a rule collection as one log line.
 *
 * <p>Format:
 * {@code "<header>: <total> | ClosedRules: <c> | OpenRules: <o> | len=1: a | len=2: b ..."}.
 * The {@code len=k} counts cover open rules only and are listed by increasing length. They are
 * omitted, without a dangling separator, when there are no open rules.
 *
 * @param header label printed at the start of the line
 * @param rules  rules to summarize
 * @return the summary line
 */
public static String analyzeRuleComposition(String header, Collection<? extends Rule> rules) {
    NumberFormat f = NumberFormat.getNumberInstance(Locale.US);
    int closedRules = 0;
    int openRules = 0;
    // TreeMap keeps the lengths sorted for the output.
    Map<Integer, Integer> lengthMap = new TreeMap<>();
    for (Rule rule : rules) {
        if (rule.isClosed()) {
            closedRules++;
        } else {
            openRules++;
            lengthMap.merge(rule.length(), 1, Integer::sum);
        }
    }
    String content = MessageFormat.format("{0}: {1} | ClosedRules: {2} | OpenRules: {3}"
            , header, rules.size(), closedRules, openRules);
    if (lengthMap.isEmpty())
        return content;
    List<String> words = new ArrayList<>();
    for (Map.Entry<Integer, Integer> entry : lengthMap.entrySet()) {
        words.add("len=" + entry.getKey() + ": " + f.format(entry.getValue()));
    }
    return content + " | " + String.join(" | ", words);
}
ClosedRules:ClosedRules的数量
OpenRules:OpenRules的数量
3个len:Open Rules len的情况,长度1有多少条,2有多少条,3有多少条
if(Settings.ESSENTIAL_TIME != -1 && Settings.INS_DEPTH != 0)
EssentialRuleGenerator.generateEssentialRules(trainPairs, validPairs, context, graph, ruleIndexFile, ruleFile);
public static void generateEssentialRules(Set trainPairs, Set validPairs
, Context context, GraphDatabaseService graph
, File tempFile, File ruleFile) {
long s = System.currentTimeMillis();
NumberFormat f = NumberFormat.getNumberInstance(Locale.US);
GlobalTimer.setEssentialStartTime(System.currentTimeMillis());
Set essentialRules = new HashSet<>();
BlockingQueue tempFileContents = new LinkedBlockingDeque<>(10000000);
BlockingQueue ruleFileContents = new LinkedBlockingDeque<>(10000000);
for (Rule rule : context.getAbstractRules()) {
if(!rule.isClosed() && rule.length() == 1)
essentialRules.add(rule);
}
Multimap trainObjToSub = MultimapBuilder.hashKeys().hashSetValues().build();
Multimap validObjToSub = MultimapBuilder.hashKeys().hashSetValues().build();
Multimap trainSubToObj = MultimapBuilder.hashKeys().hashSetValues().build();
Multimap validSubToObj = MultimapBuilder.hashKeys().hashSetValues().build();
for (Pair trainPair : trainPairs) {
trainObjToSub.put(trainPair.objId, trainPair.subId);
trainSubToObj.put(trainPair.subId, trainPair.objId);
}
for (Pair validPair : validPairs) {
validObjToSub.put(validPair.objId, validPair.subId);
validSubToObj.put(validPair.subId, validPair.objId);
}
ExecutorService executors = new SemaphoredThreadPool(Settings.THREAD_NUMBER);
RuleWriter tempFileWriter = new RuleWriter(0, executors, tempFile, tempFileContents, true);
RuleWriter ruleFileWriter = new RuleWriter(0, executors, ruleFile, ruleFileContents, true);
Set specializedRules = new HashSet<>();
try {
for (Rule rule : essentialRules) {
if(GlobalTimer.stopEssential()) break;
BlockingQueue contents = new LinkedBlockingDeque<>();
Set> futures = new HashSet<>();
context.ruleFrequency.remove(rule);
Set groundings = generateBodyGrounding(rule, graph);
Multimap originalToTails = MultimapBuilder.hashKeys().hashSetValues().build();
Multimap tailToOriginals = MultimapBuilder.hashKeys().hashSetValues().build();
for (Pair grounding : groundings) {
originalToTails.put(grounding.subId, grounding.objId);
tailToOriginals.put(grounding.objId, grounding.subId);
}
Multimap trainAnchoringToOriginals = rule.isFromSubject() ? trainObjToSub : trainSubToObj;
Multimap validAnchoringToOriginals = rule.isFromSubject() ? validObjToSub : validSubToObj;
for (Long anchoring : trainAnchoringToOriginals.keySet()) {
if(GlobalTimer.stopEssential()) break;
Collection validOriginals = validAnchoringToOriginals.get(anchoring);
futures.add(executors.submit(new CreateHAR(rule, anchoring, trainAnchoringToOriginals.get(anchoring)
, validOriginals, originalToTails.keySet()
, graph, contents, ruleFileContents, context)));
Set candidates = new HashSet<>();
for (Long original : trainAnchoringToOriginals.get(anchoring)) {
if(GlobalTimer.stopEssential()) break;
for (Long tail : originalToTails.get(original)) {
if(GlobalTimer.stopEssential()) break;
Pair candidate = new Pair(anchoring, tail);
if (!candidates.contains(candidate) && !trivialCheck(rule, anchoring, tail)) {
candidates.add(candidate);
futures.add(executors.submit(new CreateBAR(rule, candidate, trainAnchoringToOriginals.get(anchoring)
, validOriginals, tailToOriginals.get(tail)
, graph, contents, ruleFileContents, context)));
}
}
}
}
for (Future> future : futures) {
future.get();
}
if(!contents.isEmpty()) {
specializedRules.add(rule);
rule.stats.compute();
tempFileContents.put("ABS: " + context.getIndex(rule) + "\t"
+ ((Template) rule).toRuleIndexString() + "\t"
+ f.format(rule.getStandardConf()) + "\t"
+ f.format(rule.getSmoothedConf()) + "\t"
+ f.format(rule.getPcaConf()) + "\t"
+ f.format(rule.getApcaConf()) + "\t"
+ f.format(rule.getHeadCoverage()) + "\t"
+ f.format(rule.getValidPrecision()) + "\n"
+ String.join("\t", contents) + "\n");
}
}
executors.shutdown();
executors.awaitTermination(1L, TimeUnit.MINUTES);
tempFileWriter.join();
ruleFileWriter.join();
} catch (Exception e) {
e.printStackTrace();
System.exit(-1);
}
GlobalTimer.updateGenEssentialStats(Helpers.timerAndMemory(s, "# Generate Essentials"));
Logger.println("# Specialized Essential Templates: " + f.format(specializedRules.size()) + " | " +
"Generated Essential Rules: " + f.format(context.getEssentialRules()), 1);
}
参考：NumberFormat（数字格式化类）——小池先生的博客，CSDN
// US-locale number formatter for the confidence values and counts emitted by this phase.
NumberFormat f = NumberFormat.getNumberInstance(Locale.US);
// Record when the essential phase started.
// NOTE(review): presumably read by GlobalTimer.stopEssential() — confirm.
GlobalTimer.setEssentialStartTime(System.currentTimeMillis());
context 变量中保存了前面泛化阶段得到的规则（ruleFrequency）；下面从中筛选出长度为 1 的开放规则作为 essential rules：
// Essential templates: the open (non-closed) rules of length 1.
for (Rule rule : context.getAbstractRules()) {
if(!rule.isClosed() && rule.length() == 1)
essentialRules.add(rule);
}
规则中既可能是 sub 在左、obj 在右，也可能是 sub 在右、obj 在左，因此这里同时建立了 ObjToSub 和 SubToObj 两个方向相反的索引。
但值得注意的是，不管是 ObjToSub 还是 SubToObj，都包含了 incoming 和 outgoing 的规则。
// Index train and validation pairs in both directions: objId -> subIds and subId -> objIds.
// hashSetValues() means duplicate (key, value) entries are stored only once.
Multimap trainObjToSub = MultimapBuilder.hashKeys().hashSetValues().build();
Multimap validObjToSub = MultimapBuilder.hashKeys().hashSetValues().build();
Multimap trainSubToObj = MultimapBuilder.hashKeys().hashSetValues().build();
Multimap validSubToObj = MultimapBuilder.hashKeys().hashSetValues().build();
for (Pair trainPair : trainPairs) {
trainObjToSub.put(trainPair.objId, trainPair.subId);
trainSubToObj.put(trainPair.subId, trainPair.objId);
}
for (Pair validPair : validPairs) {
validObjToSub.put(validPair.objId, validPair.subId);
validSubToObj.put(validPair.subId, validPair.objId);
}
// Worker pool sized by Settings.THREAD_NUMBER.
ExecutorService executors = new SemaphoredThreadPool(Settings.THREAD_NUMBER);
允许同时执行任务的最大线程数为 Settings.THREAD_NUMBER（这里配置的是 6）。
参考：Semaphore 详解——JackieZhengChina 的博客，CSDN
参考：并发编程 Semaphore 的使用和详解——丽闪无敌的博客，CSDN
/**
 * Creates a fixed-size pool and its semaphore.
 *
 * <ul>
 *   <li>Core and maximum size are both {@code nThreads}, with a keep-alive of 0 ms.</li>
 *   <li>The work queue is bounded at {@code nThreads * 2} tasks.</li>
 *   <li>The semaphore has {@code nThreads} permits.</li>
 * </ul>
 *
 * <p>NOTE(review): how {@code semaphore} gates task submission lives outside this
 * constructor (presumably in an overridden execute/submit) — confirm there.
 *
 * @param nThreads number of worker threads
 * @throws IllegalArgumentException if {@code nThreads < 1} (the queue capacity and the pool
 *                                  size must both be positive)
 */
public SemaphoredThreadPool(int nThreads) {
    super(nThreads, nThreads, 0L, TimeUnit.MILLISECONDS
            , new LinkedBlockingDeque<Runnable>(nThreads * 2));
    semaphore = new Semaphore(nThreads);
}
// Two writer threads, started inside the RuleWriter constructor. They append queued lines to
// tempFile and ruleFile until the executor terminates and their queue is empty.
// Both use id 0, so their thread names are identical.
RuleWriter tempFileWriter = new RuleWriter(0, executors, tempFile, tempFileContents, true);
RuleWriter ruleFileWriter = new RuleWriter(0, executors, ruleFile, ruleFileContents, true);
/**
 * Creates the writer thread and starts it immediately.
 *
 * <p>NOTE(review): calling {@code start()} from the constructor lets the thread observe
 * {@code this} before a subclass (if any) finishes construction. Callers rely on the
 * auto-start (they never call {@code start()} themselves), so it is kept.
 *
 * @param id       suffix for the thread name {@code "EssentialRule-RuleWriter-" + id}
 * @param service  executor whose termination, together with an empty queue, ends the writer
 * @param file     destination file
 * @param contents queue of lines to write, one element per line
 * @param append   {@code true} to append to {@code file}, {@code false} to truncate it
 */
RuleWriter(int id, ExecutorService service, File file, BlockingQueue<String> contents, boolean append) {
    super("EssentialRule-RuleWriter-" + id);
    this.id = id;
    this.service = service;
    this.file = file;
    this.contents = contents;
    this.append = append;
    start();
}
/**
 * Writes queued lines to {@code file} until {@code service} has terminated and
 * {@code contents} is empty, so every line enqueued before termination is written.
 *
 * <p>Uses a timed poll rather than a non-blocking one. The non-blocking version spun on an
 * empty queue and kept a CPU core busy for the whole generation phase.
 *
 * <p>Write errors, and interruption (which would leave queued lines unwritten), print a
 * diagnostic and terminate the JVM with status -1, matching the existing I/O failure path.
 */
@Override
public void run() {
    try (PrintWriter writer = new PrintWriter(new FileWriter(file, append))) {
        while (!service.isTerminated() || !contents.isEmpty()) {
            // Wait briefly for work; a null result just means "re-check the exit condition".
            String line = contents.poll(100L, java.util.concurrent.TimeUnit.MILLISECONDS);
            if (line != null) {
                writer.println(line);
            }
        }
        // PrintWriter never throws on write failure; checkError() flushes and reports it.
        if (writer.checkError()) {
            System.err.println("Failed to write rules to " + file);
            System.exit(-1);
        }
    } catch (IOException e) {
        e.printStackTrace();
        System.exit(-1);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        e.printStackTrace();
        System.exit(-1);
    }
}
}