因工作需要，写了一个通用的树迭代器（深度优先、先序遍历）。主体逻辑参考了 AOM（www.operamasks.org）的 ComponentIterator，并做了以下改进：
1. 接受任意节点类型（通过泛型参数指定）。
2. 用一个栈（Deque）记录当前节点的祖先路径，因此节点类型无需提供 getParent 方法。
3. 支持传入一个函数对象（Fun1）来过滤子树：被过滤掉的子节点及其全部后代都不会被遍历。
主类：
/**
 * Generic depth-first, pre-order iterator over a tree of arbitrary node type.
 *
 * <p>Subclasses only describe how to obtain a node's children via {@link #getChildren}.
 * Nodes need no parent link. The iterator keeps its own stack of ancestors ({@code path}).
 * For every ancestor it also keeps the partially consumed child iterator ({@code iteratorMap}).
 * Together these let it resume with the next sibling once a subtree is finished.
 *
 * <p>An optional subtree filter prunes the traversal. When it returns {@code false} for a node,
 * neither that node nor any of its descendants is returned.
 *
 * <p>{@link #remove()} is not supported. No synchronization is performed.
 *
 * @param <Node> the node type; nodes are tracked by identity ({@code ==}), not {@code equals}
 */
public abstract class TreeIterator<Node> implements Iterator<Node>
{
    private final Node root;
    private final boolean includeRoot;
    /** Optional pruning predicate; {@code null} accepts every node. */
    private final Fun1<Node, Boolean> subTreeFilter;
    /** Partially consumed child iterator of every node on {@code path}, keyed by identity. */
    private final Map<Node, Iterator<Node>> iteratorMap = new IdentityHashMap<Node, Iterator<Node>>();
    /** Ancestors of {@code current}; the top of the stack is its parent. */
    private final Deque<Node> path = new ArrayDeque<Node>();
    /** Whether the first element has been computed (done lazily, see the constructor). */
    private boolean started;
    /** The node most recently produced by the traversal; the traversal resumes from it. */
    private Node current;
    /** The node the next call to {@link #next()} returns, or null if not yet computed. */
    private Node next;

    /**
     * Creates an iterator over the tree rooted at {@code root}.
     *
     * <p>No traversal work happens until the first {@link #hasNext()} or {@link #next()}.
     * That includes calls to {@link #getChildren} and to the filter. Subclasses may therefore
     * initialise their own fields after calling {@code super(...)}.
     *
     * @param root          the root of the tree to traverse
     * @param includeRoot   whether {@code root} itself is returned first. If {@code false},
     *                      the traversal starts at root's children and the filter is not
     *                      consulted for the root. If {@code true} and the filter rejects the
     *                      root, the iteration is empty.
     * @param subTreeFilter optional filter ({@code null} accepts everything). Returning
     *                      {@code false} for a node prunes that node and all its descendants.
     */
    public TreeIterator(Node root, boolean includeRoot, Fun1<Node, Boolean> subTreeFilter)
    {
        this.root = root;
        this.includeRoot = includeRoot;
        this.subTreeFilter = subTreeFilter;
    }

    /** Returns true unless the filter rejects {@code node} (and thereby its whole subtree). */
    private boolean isSubTreeAccepted(Node node)
    {
        return subTreeFilter == null || subTreeFilter.apply(node);
    }

    /**
     * Reports whether another node remains. Repeated calls do not advance the iterator.
     */
    @Override
    public boolean hasNext()
    {
        if (!started)
        {
            started = true;
            current = next = firstNode();
        }
        else if (next == null && current != null)
        {
            // The previous element was consumed; advance from it.
            current = next = getNextNode(current);
        }
        return next != null;
    }

    /**
     * Returns the next node in pre-order.
     *
     * @throws NoSuchElementException if the traversal is exhausted
     */
    @Override
    public Node next()
    {
        if (!hasNext())
        {
            throw new NoSuchElementException();
        }
        Node result = next;
        next = null; // mark as consumed so the following hasNext() advances
        return result;
    }

    /** Always throws; removal is not supported. */
    @Override
    public void remove()
    {
        throw new UnsupportedOperationException();
    }

    /** Computes the first element of the traversal, or null if there is none. */
    private Node firstNode()
    {
        if (includeRoot)
        {
            // The filter prunes whole subtrees, so a rejected root means nothing to visit.
            return isSubTreeAccepted(root) ? root : null;
        }
        return getNextNode(root);
    }

    /**
     * Returns the pre-order successor of {@code node}, or null when the traversal is over.
     * It descends into the first accepted child if there is one. Otherwise it climbs the
     * ancestor stack until some ancestor (or node itself) has an accepted next sibling.
     */
    private Node getNextNode(Node node)
    {
        Iterator<Node> children = getChildren(node);
        Node child = nextAccepted(children);
        if (child != null)
        {
            // Remember where we are among node's children so its siblings can be resumed.
            iteratorMap.put(node, children);
            path.push(node);
            return child;
        }
        if (node == root)
        {
            return null;
        }
        Node sibling;
        while ((sibling = getNextSibling()) == null)
        {
            // poll() (unlike pop()) returns null on an empty stack instead of throwing.
            node = path.poll();
            if (node == null || node == root)
            {
                return null;
            }
        }
        return sibling;
    }

    /**
     * Returns the next accepted sibling of the node whose parent is on top of {@code path},
     * or null if the parent's children are exhausted. In that case the parent's iterator is
     * discarded.
     */
    private Node getNextSibling()
    {
        Node parent = path.peek();
        if (parent == null)
        {
            return null;
        }
        Node sibling = nextAccepted(iteratorMap.get(parent));
        if (sibling == null)
        {
            iteratorMap.remove(parent);
        }
        return sibling;
    }

    /**
     * Advances {@code candidates} to the first non-null element the filter accepts.
     *
     * @return that element, or null if {@code candidates} is null or runs out
     */
    private Node nextAccepted(Iterator<Node> candidates)
    {
        if (candidates == null)
        {
            return null;
        }
        while (candidates.hasNext())
        {
            Node candidate = candidates.next();
            if (candidate != null && isSubTreeAccepted(candidate))
            {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Returns the children of {@code node}.
     *
     * @param node a node of the tree (never null unless the root itself is null)
     * @return an iterator over the children in traversal order, or {@code null} if the node
     *         has none; null elements produced by the iterator are skipped
     */
    protected abstract Iterator<Node> getChildren(Node node);
}
过滤器辅助类 Fun1（用于过滤子树）及其使用的异常类 FunctionException：
/**
 * A one-argument function whose implementation is allowed to throw checked exceptions.
 *
 * <p>Implementors override {@link #f}. Callers use {@link #apply}, which rethrows any
 * {@link Exception} from {@code f} wrapped in an unchecked {@link FunctionException}.
 *
 * @param <P1> the argument type
 * @param <R>  the result type
 */
public abstract class Fun1<P1, R>
{
    /**
     * Computes the function value.
     *
     * @param arg0 the argument
     * @return the result
     * @throws Exception any failure; {@link #apply} wraps it
     */
    protected abstract R f(P1 arg0) throws Exception;

    /**
     * Evaluates {@link #f} on {@code arg0}.
     *
     * @param arg0 the argument
     * @return whatever {@code f} returns
     * @throws FunctionException wrapping any {@link Exception} thrown by {@code f},
     *         including unchecked ones
     */
    public final R apply(P1 arg0)
    {
        final R value;
        try
        {
            value = f(arg0);
        }
        catch (Exception failure)
        {
            throw new FunctionException(failure);
        }
        return value;
    }
}
/**
 * Unchecked exception that {@code Fun1.apply} uses to wrap any exception thrown by
 * {@code Fun1.f}. Every constructor delegates to the {@link RuntimeException}
 * constructor with the same parameters.
 */
public class FunctionException extends RuntimeException
{
    private static final long serialVersionUID = 1L;

    /**
     * Creates an exception wrapping {@code cause}. The detail message becomes
     * {@code cause.toString()} (or null if {@code cause} is null).
     *
     * @param cause the underlying failure
     */
    public FunctionException(Throwable cause)
    {
        super(cause);
    }

    /**
     * Creates an exception with the given detail message and no cause.
     *
     * @param message the detail message
     */
    public FunctionException(String message)
    {
        super(message);
    }

    /**
     * Creates an exception with both a detail message and a cause.
     *
     * @param message the detail message
     * @param cause   the underlying failure
     */
    public FunctionException(String message, Throwable cause)
    {
        super(message, cause);
    }

    /** Creates an exception with neither a detail message nor a cause. */
    public FunctionException()
    {
        // RuntimeException() is invoked implicitly.
    }
}
单元测试（可作为用法参考）：
public class TreeIteratorTest
{
class TreeNode {
private String value;
private List<TreeNode> children;
public TreeNode(String value, TreeNode... children)
{
super();
this.value = value;
this.children = Arrays.asList(children);
}
public TreeNode(String value)
{
super();
this.value = value;
this.children = null;
}
public String getValue()
{
return value;
}
public void setValue(String value)
{
this.value = value;
}
public List<TreeNode> getChildren()
{
return children;
}
public void setChildren(List<TreeNode> children)
{
this.children = children;
}
@Override
public String toString()
{
return "TreeNode [value=" + value + "]";
}
}
private TreeNode root;
private TreeNode node(String value) {
return new TreeNode(value);
}
private TreeNode node(String value, TreeNode... nodes) {
return new TreeNode(value, nodes);
}
class TreeNodeIterator extends TreeIterator<TreeNode> {
public TreeNodeIterator(TreeNode root, boolean includeRoot, Fun1<TreeNode, Boolean> subTreeFilter)
{
super(root, includeRoot, subTreeFilter);
}
@Override
protected Iterator<TreeNode> getChildren(TreeNode node)
{
return node.getChildren() != null ? node.getChildren().iterator() : null;
}
}
@Before
public void setUp() throws Exception
{
root = node("root",
node("child1",
node ("child11"),
node ("child12")),
node("child2"),
node("child3",
node ("child31"),
node ("child32",
node("child321"),
node("child322")),
node ("child33"))
);
}
@Test
public void testFullScan() throws Exception
{
TreeNodeIterator it = new TreeNodeIterator(root, true, null);
StringBuilder sb = new StringBuilder();
while (it.hasNext()) {
sb.append(it.next().getValue()).append('|');
}
assertEquals("root|child1|child11|child12|child2|child3|child31|child32|child321|child322|child33|", sb.toString());
}
@Test
public void testFullScanWithoutRoot() throws Exception
{
TreeNodeIterator it = new TreeNodeIterator(root, false, null);
StringBuilder sb = new StringBuilder();
while (it.hasNext()) {
sb.append(it.next().getValue()).append('|');
}
assertEquals("child1|child11|child12|child2|child3|child31|child32|child321|child322|child33|", sb.toString());
}
@Test
public void testFullScanWithoutBranch3() throws Exception
{
TreeNodeIterator it = new TreeNodeIterator(root, false, new Fun1<TreeNode, Boolean>() {
@Override protected Boolean f(TreeNode arg0) throws Exception
{
return !arg0.value.equals("child3");
}});
StringBuilder sb = new StringBuilder();
while (it.hasNext()) {
sb.append(it.next().getValue()).append('|');
}
assertEquals("child1|child11|child12|child2|", sb.toString());
}
@Test
public void testFullScanWithoutBranch32() throws Exception
{
TreeNodeIterator it = new TreeNodeIterator(root, false, new Fun1<TreeNode, Boolean>() {
@Override protected Boolean f(TreeNode arg0) throws Exception
{
return !arg0.value.equals("child32");
}});
StringBuilder sb = new StringBuilder();
while (it.hasNext()) {
sb.append(it.next().getValue()).append('|');
}
assertEquals("child1|child11|child12|child2|child3|child31|child33|", sb.toString());
}
@Test
public void testFullScanSingleNode() throws Exception
{
TreeNodeIterator it = new TreeNodeIterator(node("root"), true, null);
StringBuilder sb = new StringBuilder();
while (it.hasNext()) {
sb.append(it.next().getValue()).append('|');
}
assertEquals("root|", sb.toString());
}
@After
public void tearDown() throws Exception
{}
}