一个通用的TreeIterator

因工作需要，写了一个通用的树迭代器（按先序、深度优先的顺序遍历）。主体逻辑参考了 AOM（www.operamasks.org）的 ComponentIterator，并做了以下改进：
1. 接受任意节点类型（泛型参数）
2. 加入了一个栈（stack）来跟踪父节点（节点无需提供 getParent 方法）
3. 加入了一个函数对象（Fun1）来过滤子树：过滤函数返回 false 的节点及其整棵子树都会被跳过

主类 TreeIterator：
/**
 * Generic pre-order (depth-first) iterator over a tree whose nodes are of type {@code T}.
 *
 * <p>Subclasses describe the tree by implementing {@link #getChildren(Object)}. Nodes need not
 * expose a parent reference. Instead, the iterator keeps a stack that holds, for every ancestor of
 * the current node, the partially consumed iterator over that ancestor's children.
 *
 * <p>An optional subtree filter prunes the traversal. A node for which the filter returns
 * {@code false} is skipped together with all of its descendants. The filter is consulted for
 * {@code root} only when {@code includeRoot} is {@code true}. Even then, rejecting the root only
 * suppresses the root itself; its children are still visited, each subject to the filter.
 *
 * <p>The traversal is lazy: neither {@link #getChildren(Object)} nor the filter is invoked before
 * the first call to {@link #hasNext()} or {@link #next()}. An overriding {@code getChildren} may
 * therefore rely on fields initialized by the subclass constructor. {@code null} children are
 * skipped. {@link #remove()} is not supported.
 *
 * @param <T> the node type
 */
public abstract class TreeIterator<T> implements Iterator<T>
{
	private final T root;
	private final boolean includeRoot;
	private final Fun1<T, Boolean> subTreeFilter;

	/**
	 * One entry per ancestor of {@link #current}: the iterator over that ancestor's children,
	 * positioned just after the child on the current path. The top entry belongs to the parent.
	 */
	private final Deque<Iterator<T>> ancestors = new ArrayDeque<Iterator<T>>();

	/** Last node produced by the traversal; {@code null} once the traversal is exhausted. */
	private T current;
	/** Node to hand out on the next {@link #next()} call; {@code null} if not yet computed. */
	private T next;
	/** Whether the first node has been computed (deferred out of the constructor). */
	private boolean started;

	/**
	 * Creates an iterator over the tree rooted at {@code root}.
	 *
	 * @param root          the tree root; if {@code null}, the iteration is empty
	 * @param includeRoot   whether {@code root} itself is returned first (if the filter accepts it)
	 * @param subTreeFilter optional subtree filter, or {@code null} to visit every node
	 */
	public TreeIterator(T root, boolean includeRoot, Fun1<T, Boolean> subTreeFilter)
	{
		this.root = root;
		this.includeRoot = includeRoot;
		this.subTreeFilter = subTreeFilter;
		// Deliberately no traversal here: getChildren is overridable, and calling it from this
		// constructor would run before the subclass constructor has initialized its fields.
	}

	/** Returns {@code true} when there is no filter or the filter accepts {@code node}'s subtree. */
	private boolean isSubTreeAccepted(T node)
	{
		return subTreeFilter == null || subTreeFilter.apply(node);
	}

	@Override
	public boolean hasNext()
	{
		if (!started)
		{
			started = true;
			if (root != null)
			{
				current = next = (includeRoot && isSubTreeAccepted(root)) ? root : getNextNode(root);
			}
		}
		else if (next == null && current != null)
		{
			// The previous node has been handed out; advance past it.
			current = next = getNextNode(current);
		}
		return next != null;
	}

	@Override
	public T next()
	{
		if (!hasNext())
		{
			throw new NoSuchElementException();
		}
		T result = next;
		next = null;
		return result;
	}

	/** Always throws: removal is not supported. */
	@Override
	public void remove()
	{
		throw new UnsupportedOperationException();
	}

	/**
	 * Computes the pre-order successor of {@code node}.
	 *
	 * @return the successor, or {@code null} when the traversal is complete
	 */
	private T getNextNode(T node)
	{
		// Descend: the first accepted child of node comes next.
		Iterator<T> children = getChildren(node);
		if (children != null)
		{
			T child = nextAccepted(children);
			if (child != null)
			{
				ancestors.push(children);
				return child;
			}
		}
		// Climb: the next accepted sibling of node, or of the nearest ancestor that still has one.
		// The root's child iterator sits at the bottom, so the traversal never leaves the root's subtree.
		while (!ancestors.isEmpty())
		{
			T sibling = nextAccepted(ancestors.peek());
			if (sibling != null)
			{
				return sibling;
			}
			ancestors.pop();
		}
		return null;
	}

	/**
	 * Advances {@code it} to its next non-null element accepted by the filter.
	 *
	 * @return that element, or {@code null} if {@code it} is exhausted first
	 */
	private T nextAccepted(Iterator<T> it)
	{
		while (it.hasNext())
		{
			T candidate = it.next();
			if (candidate != null && isSubTreeAccepted(candidate))
			{
				return candidate;
			}
		}
		return null;
	}

	/**
	 * Returns an iterator over the direct children of {@code node}, or {@code null} if it has none.
	 * Called once for each node the traversal passes through, never from the constructor.
	 */
	protected abstract Iterator<T> getChildren(T node);
}


辅助类 Fun1（用于实现子树过滤函数）及其使用的异常类 FunctionException：
/**
 * A one-argument function whose implementation may throw checked exceptions.
 *
 * <p>Implementors override {@link #f(Object)}; callers invoke {@link #apply(Object)}. {@code apply}
 * reports any exception from {@code f} as an unchecked {@link FunctionException}.
 *
 * @param <P1> the argument type
 * @param <R>  the result type
 */
public abstract class Fun1<P1, R>
{
	/**
	 * The function body.
	 *
	 * @param arg0 the argument
	 * @return the result
	 * @throws Exception any failure; {@link #apply(Object)} converts it to a {@link FunctionException}
	 */
	protected abstract R f(P1 arg0) throws Exception;

	/**
	 * Evaluates the function on {@code arg0}.
	 *
	 * @param arg0 the argument
	 * @return the value returned by {@link #f(Object)}
	 * @throws FunctionException if {@code f} throws. An exception that is already a
	 *         {@code FunctionException} is rethrown unchanged; any other exception becomes the cause.
	 */
	public final R apply(P1 arg0)
	{
		try
		{
			return f(arg0);
		}
		catch (FunctionException e)
		{
			// Already the unchecked wrapper type (e.g. from a nested apply call): don't wrap it twice.
			throw e;
		}
		catch (InterruptedException e)
		{
			// InterruptedException is normally thrown with the interrupt status already cleared.
			// Restore it so code further up the stack can still observe the interrupt.
			Thread.currentThread().interrupt();
			throw new FunctionException(e);
		}
		catch (Exception e)
		{
			throw new FunctionException(e);
		}
	}
}


/**
 * Unchecked exception that carries a failure raised inside a function body (for example, a
 * checked exception from {@code Fun1.f}). Its constructors mirror those of
 * {@link RuntimeException}.
 */
public class FunctionException extends RuntimeException
{
	private static final long serialVersionUID = 1L;

	/**
	 * Wraps {@code cause}. The detail message is {@code cause.toString()}, or {@code null} when
	 * {@code cause} is {@code null}.
	 */
	public FunctionException(Throwable cause)
	{
		super(cause);
	}

	/** Creates an exception with both a detail message and a cause. */
	public FunctionException(String message, Throwable cause)
	{
		super(message, cause);
	}

	/** Creates an exception with a detail message and no cause. */
	public FunctionException(String message)
	{
		super(message);
	}

	/** Creates an exception with neither a detail message nor a cause. */
	public FunctionException()
	{
		// the implicit no-arg RuntimeException constructor does all the work
	}
}


单元测试（JUnit，可作为用法参考）：
public class TreeIteratorTest
{
	class TreeNode {
		private String value;
		private List<TreeNode> children;
		public TreeNode(String value, TreeNode... children)
		{
			super();
			this.value = value;
			this.children = Arrays.asList(children);
		}
		public TreeNode(String value)
		{
			super();
			this.value = value;
			this.children = null;
		}
		public String getValue()
		{
			return value;
		}
		public void setValue(String value)
		{
			this.value = value;
		}
		public List<TreeNode> getChildren()
		{
			return children;
		}
		public void setChildren(List<TreeNode> children)
		{
			this.children = children;
		}
		@Override
		public String toString()
		{
			return "TreeNode [value=" + value + "]";
		}
	}
	
	private TreeNode root;
	
	private TreeNode node(String value) {
		return new TreeNode(value);
	}
	
	private TreeNode node(String value, TreeNode... nodes) {
		return new TreeNode(value, nodes);
	}
	
	class TreeNodeIterator extends TreeIterator<TreeNode> {
		
		public TreeNodeIterator(TreeNode root, boolean includeRoot, Fun1<TreeNode, Boolean> subTreeFilter)
		{
			super(root, includeRoot, subTreeFilter);
		}

		@Override
		protected Iterator<TreeNode> getChildren(TreeNode node)
		{
			return node.getChildren() != null ? node.getChildren().iterator() : null;
		}
	}
	
	@Before
	public void setUp() throws Exception
	{
		root = node("root",
			node("child1",
				node ("child11"),
				node ("child12")),
			node("child2"),
			node("child3",
				node ("child31"),
				node ("child32",
					node("child321"),
					node("child322")),
				node ("child33"))
		);
	}
	
	@Test
	public void testFullScan() throws Exception
	{
		TreeNodeIterator it = new TreeNodeIterator(root, true, null);
		StringBuilder sb = new StringBuilder();
		while (it.hasNext()) {
			sb.append(it.next().getValue()).append('|');
		}
		assertEquals("root|child1|child11|child12|child2|child3|child31|child32|child321|child322|child33|", sb.toString());
	}
	
	@Test
	public void testFullScanWithoutRoot() throws Exception
	{
		TreeNodeIterator it = new TreeNodeIterator(root, false, null);
		StringBuilder sb = new StringBuilder();
		while (it.hasNext()) {
			sb.append(it.next().getValue()).append('|');
		}
		assertEquals("child1|child11|child12|child2|child3|child31|child32|child321|child322|child33|", sb.toString());
	}
	
	@Test
	public void testFullScanWithoutBranch3() throws Exception
	{
		TreeNodeIterator it = new TreeNodeIterator(root, false, new Fun1<TreeNode, Boolean>() {
			@Override protected Boolean f(TreeNode arg0) throws Exception
			{
				return !arg0.value.equals("child3");
			}});
		StringBuilder sb = new StringBuilder();
		while (it.hasNext()) {
			sb.append(it.next().getValue()).append('|');
		}
		assertEquals("child1|child11|child12|child2|", sb.toString());
	}
	
	@Test
	public void testFullScanWithoutBranch32() throws Exception
	{
		TreeNodeIterator it = new TreeNodeIterator(root, false, new Fun1<TreeNode, Boolean>() {
			@Override protected Boolean f(TreeNode arg0) throws Exception
			{
				return !arg0.value.equals("child32");
			}});
		StringBuilder sb = new StringBuilder();
		while (it.hasNext()) {
			sb.append(it.next().getValue()).append('|');
		}
		assertEquals("child1|child11|child12|child2|child3|child31|child33|", sb.toString());
	}
	
	@Test
	public void testFullScanSingleNode() throws Exception
	{
		TreeNodeIterator it = new TreeNodeIterator(node("root"), true, null);
		StringBuilder sb = new StringBuilder();
		while (it.hasNext()) {
			sb.append(it.next().getValue()).append('|');
		}
		assertEquals("root|", sb.toString());
	}

	@After
	public void tearDown() throws Exception
	{}

}

你可能感兴趣的:(iterator)