Vertical Order Traversal of a Binary Tree Problem

Problem

Given the root of a binary tree, calculate the vertical order traversal of the binary tree.

For each node at position (row, col), its left and right children will be at positions (row + 1, col - 1) and (row + 1, col + 1) respectively. The root of the tree is at (0, 0).

The vertical order traversal of a binary tree is a list of top-to-bottom orderings for each column index starting from the leftmost column and ending on the rightmost column. There may be multiple nodes in the same row and same column. In such a case, sort these nodes by their values.

Return the vertical order traversal of the binary tree.
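Here is a minimal Java sketch of the coordinate rule in practice. It assumes a LeetCode-style TreeNode (val, left, right) and a hypothetical build helper that turns the level-order array notation used in the examples (null marks a missing child) into a tree. The coordinates method then walks the tree and records each node's (row, col).

```java
import java.util.*;

// Sketch: assign (row, col) to each node using the rule from the problem statement.
// TreeNode mirrors LeetCode's definition; build(...) is a hypothetical helper.
class Coordinates {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Builds a tree from level-order notation, e.g. [3,9,20,null,null,15,7];
    // null marks a missing child.
    static TreeNode build(Integer... vals) {
        if (vals.length == 0 || vals[0] == null) return null;
        TreeNode root = new TreeNode(vals[0]);
        Queue<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < vals.length) {
            TreeNode node = queue.poll();
            if (vals[i] != null) {                       // left child slot
                node.left = new TreeNode(vals[i]);
                queue.add(node.left);
            }
            i++;
            if (i < vals.length && vals[i] != null) {    // right child slot
                node.right = new TreeNode(vals[i]);
                queue.add(node.right);
            }
            i++;
        }
        return root;
    }

    // Returns one {row, col, val} triple per node, in preorder.
    static List<int[]> coordinates(TreeNode root) {
        List<int[]> out = new ArrayList<>();
        walk(root, 0, 0, out);
        return out;
    }

    private static void walk(TreeNode node, int row, int col, List<int[]> out) {
        if (node == null) return;
        out.add(new int[] {row, col, node.val});
        walk(node.left, row + 1, col - 1, out);   // left child: (row + 1, col - 1)
        walk(node.right, row + 1, col + 1, out);  // right child: (row + 1, col + 1)
    }
}
```

For Example 1, node 9 lands at (1, -1) and node 15 at (2, 0), which matches the columns listed in its explanation.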

Examples

Example 1:

Input:
root = [3,9,20,null,null,15,7]
Output:
 [ [9],[3,15],[20],[7] ]
Explanation:
Column -1: Only node 9 is in this column.
Column 0: Nodes 3 and 15 are in this column in that order from top to bottom.
Column 1: Only node 20 is in this column.
Column 2: Only node 7 is in this column.

Example 2:

Input:
root = [1,2,3,4,5,6,7]
Output:
 [ [4],[2],[1,5,6],[3],[7] ]
Explanation:
Column -2: Only node 4 is in this column.
Column -1: Only node 2 is in this column.
Column 0: Nodes 1, 5, and 6 are in this column.
          1 is at the top, so it comes first.
          5 and 6 are at the same position (2, 0), so we order them by their value, 5 before 6.
Column 1: Only node 3 is in this column.
Column 2: Only node 7 is in this column.

Example 3:

Input:
root = [1,2,3,4,6,5,7]
Output:
 [ [4],[2],[1,5,6],[3],[7] ]
Explanation:
This case is the exact same as example 2, but with nodes 5 and 6 swapped.
Note that the solution remains the same since 5 and 6 are in the same location and should be ordered by their values.

We have seen a similar problem, Binary Tree Traversal - Vertical Order Traversal, which is LeetCode 314. This one is LeetCode 987.

The difference between the two problems is how ties are broken. In 314, if two nodes are in the same row and column, they are ordered from left to right. In 987, if two nodes have the same position, the node with the smaller value comes first.
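A small sketch makes the difference concrete. It assumes a LeetCode-style TreeNode and uses hypothetical helpers (bfs, group, order314, order987). Both orders start from the same BFS listing of {row, col, val}. For 314 we sort by column only and rely on the stability of List.sort to keep BFS (left-to-right) order on ties. For 987 we sort by column, then row, then value.

```java
import java.util.*;

// Sketch contrasting the tie-breaking rules of LeetCode 314 and 987.
// Nodes are listed in BFS order (top to bottom, left to right) as {row, col, val}.
class TieBreaking {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
        TreeNode(int val) { this(val, null, null); }
    }

    static List<int[]> bfs(TreeNode root) {
        List<int[]> out = new ArrayList<>();
        if (root == null) return out;
        Deque<TreeNode> nodes = new ArrayDeque<>();
        Deque<int[]> pos = new ArrayDeque<>();   // {row, col} of the matching entry in nodes
        nodes.add(root);
        pos.add(new int[] {0, 0});
        while (!nodes.isEmpty()) {
            TreeNode n = nodes.poll();
            int[] p = pos.poll();
            out.add(new int[] {p[0], p[1], n.val});
            if (n.left != null) {
                nodes.add(n.left);
                pos.add(new int[] {p[0] + 1, p[1] - 1});
            }
            if (n.right != null) {
                nodes.add(n.right);
                pos.add(new int[] {p[0] + 1, p[1] + 1});
            }
        }
        return out;
    }

    // Splits a column-sorted list of {row, col, val} into one value list per column.
    static List<List<Integer>> group(List<int[]> sorted) {
        List<List<Integer>> ans = new ArrayList<>();
        for (int i = 0; i < sorted.size(); i++) {
            if (i == 0 || sorted.get(i)[1] != sorted.get(i - 1)[1]) ans.add(new ArrayList<>());
            ans.get(ans.size() - 1).add(sorted.get(i)[2]);
        }
        return ans;
    }

    // 314: sort by column only; List.sort is stable, so ties keep BFS (left-to-right) order.
    static List<List<Integer>> order314(TreeNode root) {
        List<int[]> t = bfs(root);
        t.sort(Comparator.comparingInt(a -> a[1]));
        return group(t);
    }

    // 987: sort by column, then row, then value.
    static List<List<Integer>> order987(TreeNode root) {
        List<int[]> t = bfs(root);
        t.sort(Comparator.<int[]>comparingInt(a -> a[1])
                .thenComparingInt(a -> a[0])
                .thenComparingInt(a -> a[2]));
        return group(t);
    }
}
```

On the tree from Example 3 ([1,2,3,4,6,5,7]), order314 gives column 0 as [1, 6, 5], while order987 gives [1, 5, 6].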

Solution

Method 1 - Using Priority Queue, Tracking Horizontal Distance and Depth 🏆

We can augment each TreeNode's value with its position. Let this new object be Point, with x being the horizontal distance (column) from the root and y being the depth (row) from the root:

static class Point{
    int x,y,val;
    Point(int x,int y,int val){
        this.x = x;
        this.y = y;
        this.val = val;
    }
}

Then we write a comparator for the priority queue. When comparing two points:

  1. The leftmost node comes first, so a smaller x value comes first.
  2. If the x values are the same, the one with less depth (smaller y) comes first.
  3. If both x and y are the same, the smaller value comes first.
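The three rules above map directly onto Java's Comparator combinators. The sketch below assumes the fields of Point, with y measured as depth downward from the root (0 at the root), as described above. The class name PointOrder and the pollAll helper are illustrative only.

```java
import java.util.*;

// Sketch: the three comparison rules expressed with Comparator combinators.
// Assumes y is the depth measured downward from the root (0 at the root).
class PointOrder {
    static class Point {
        final int x, y, val;
        Point(int x, int y, int val) {
            this.x = x;
            this.y = y;
            this.val = val;
        }
    }

    static final Comparator<Point> ORDER =
            Comparator.<Point>comparingInt(p -> p.x)   // 1. leftmost column first
                    .thenComparingInt(p -> p.y)        // 2. then the shallower node
                    .thenComparingInt(p -> p.val);     // 3. then the smaller value

    // Drains a priority queue ordered by ORDER and returns the values in poll order.
    static List<Integer> pollAll(List<Point> points) {
        PriorityQueue<Point> pq = new PriorityQueue<>(ORDER);
        pq.addAll(points);
        List<Integer> out = new ArrayList<>();
        while (!pq.isEmpty()) {
            out.add(pq.poll().val);
        }
        return out;
    }
}
```

With the positions from Example 2 offered in any order, the queue drains as 4, 2, 1, 5, 6, 3, 7, which is that example's expected output flattened column by column.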

To collect the points, we can do a preorder DFS, similar to Binary Tree Traversal - Vertical Order Traversal. The full solution:

import java.util.*;

class Solution {
    static class Point {
        int x, y, val;
        Point(int x, int y, int val) {
            this.x = x;
            this.y = y;
            this.val = val;
        }
    }

    public List<List<Integer>> verticalTraversal(TreeNode root) {
        List<List<Integer>> ans = new ArrayList<>();
        if (root == null) return ans;

        // Order: smaller x (column) first, then smaller y (depth), then smaller value
        PriorityQueue<Point> pq = new PriorityQueue<>(new Comparator<Point>() {
            public int compare(Point p1, Point p2) {
                if (p1.x != p2.x) return Integer.compare(p1.x, p2.x);
                if (p1.y != p2.y) return Integer.compare(p1.y, p2.y);
                return Integer.compare(p1.val, p2.val);
            }
        });

        dfs(root, 0, 0, pq);

        // The first point polled starts the leftmost column
        Point prev = pq.poll();
        List<Integer> firstColumn = new ArrayList<>();
        firstColumn.add(prev.val);
        ans.add(firstColumn);

        while (!pq.isEmpty()) {
            Point p = pq.poll();
            if (p.x != prev.x) {
                // New column: start a new list
                List<Integer> column = new ArrayList<>();
                column.add(p.val);
                ans.add(column);
            } else {
                // Same column: append to the last list
                ans.get(ans.size() - 1).add(p.val);
            }
            prev = p;
        }

        return ans;
    }

    // Preorder DFS: left child is at (x - 1, y + 1), right child at (x + 1, y + 1)
    private void dfs(TreeNode root, int x, int y, PriorityQueue<Point> pq) {
        if (root == null) return;
        pq.offer(new Point(x, y, root.val));
        dfs(root.left, x - 1, y + 1, pq);
        dfs(root.right, x + 1, y + 1, pq);
    }
}

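We can make a couple of small improvements. After the DFS, the head of the priority queue has the smallest x value (the leftmost column, which is at most 0 because the root is at x = 0). We can use it as an offset so that column x goes to index x - minX in the result list, and then we don't need to track the previous point. We can also write the comparator more compactly: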

Code:

import java.util.*;

class Solution {
    static class Point {
        int x, y, val;
        Point(int x, int y, int val) {
            this.x = x;
            this.y = y;
            this.val = val;
        }
    }

    public List<List<Integer>> verticalTraversal(TreeNode root) {
        List<List<Integer>> ans = new ArrayList<>();
        if (root == null) return ans;

        // Compare by column, then by depth, then by value
        PriorityQueue<Point> pq = new PriorityQueue<>((a, b) ->
                a.x != b.x ? Integer.compare(a.x, b.x)
                        : (a.y != b.y ? Integer.compare(a.y, b.y) : Integer.compare(a.val, b.val)));

        dfs(root, 0, 0, pq);

        // The head of the queue has the smallest x, which is <= 0 since the root is at x = 0
        int leftMost = -pq.peek().x;

        while (!pq.isEmpty()) {
            Point p = pq.poll();
            int idx = p.x + leftMost; // shift so the leftmost column lands at index 0
            while (ans.size() < idx + 1) {
                ans.add(new ArrayList<>());
            }
            ans.get(idx).add(p.val);
        }

        return ans;
    }

    private void dfs(TreeNode root, int x, int y, PriorityQueue<Point> pq) {
        if (root == null) return;
        pq.offer(new Point(x, y, root.val));
        dfs(root.left, x - 1, y + 1, pq);
        dfs(root.right, x + 1, y + 1, pq);
    }
}

Source - [3]

Note that instead of creating a Point class we can use a TreeMap, but then, to gain speed, we can't handle duplicate elements: Java TreeMap Solution - LeetCode Discuss.
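Below is a sketch of the nested-TreeMap idea. It is not taken from the linked post, and it assumes the map goes column → row → values in that cell. Keeping each cell's values in a PriorityQueue rather than a TreeSet is one way to keep duplicates. The class name NestedTreeMap and the TreeNode definition are illustrative.

```java
import java.util.*;

// Sketch of the nested-TreeMap variant: column -> row -> values in that cell.
// A PriorityQueue per cell keeps duplicate values (a TreeSet would collapse them).
class NestedTreeMap {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
        TreeNode(int val) { this(val, null, null); }
    }

    static List<List<Integer>> verticalTraversal(TreeNode root) {
        TreeMap<Integer, TreeMap<Integer, PriorityQueue<Integer>>> cols = new TreeMap<>();
        dfs(root, 0, 0, cols);
        List<List<Integer>> ans = new ArrayList<>();
        for (TreeMap<Integer, PriorityQueue<Integer>> rows : cols.values()) {  // columns left to right
            List<Integer> column = new ArrayList<>();
            for (PriorityQueue<Integer> cell : rows.values()) {               // rows top to bottom
                while (!cell.isEmpty()) column.add(cell.poll());                // values small to large
            }
            ans.add(column);
        }
        return ans;
    }

    private static void dfs(TreeNode node, int row, int col,
                            TreeMap<Integer, TreeMap<Integer, PriorityQueue<Integer>>> cols) {
        if (node == null) return;
        cols.computeIfAbsent(col, k -> new TreeMap<>())
            .computeIfAbsent(row, k -> new PriorityQueue<>())
            .offer(node.val);
        dfs(node.left, row + 1, col - 1, cols);
        dfs(node.right, row + 1, col + 1, cols);
    }
}
```

On a tree whose two grandchildren both have value 5 and both sit at (2, 0), column 0 comes out as [1, 5, 5], so neither duplicate is lost.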

Method 2 - Using TreeMap, Tracking Horizontal Distance and Depth 🥈

Similar to Method 1, we create a Point object, but instead of a priority queue we use a TreeMap. The key is the horizontal distance x, and the value is the list of points in that column. Since x is already the key, we can drop it from Point to save space. After the DFS, we sort each column's points by depth and then by value.

import java.util.*;
import java.util.stream.Collectors;

class Solution {
    static class Point {
        int y, val;
        Point(int y, int val) {
            this.y = y;
            this.val = val;
        }
    }

    public List<List<Integer>> verticalTraversal(TreeNode root) {
        // TreeMap keeps the columns (x) in ascending order
        Map<Integer, List<Point>> map = new TreeMap<>();
        dfs(root, 0, 0, map);
        List<List<Integer>> ans = new ArrayList<>();
        for (List<Point> column : map.values()) {
            // Within a column: smaller depth first, then smaller value
            column.sort(new Comparator<Point>() {
                public int compare(Point a, Point b) {
                    if (a.y == b.y) {
                        return Integer.compare(a.val, b.val);
                    }
                    return Integer.compare(a.y, b.y);
                }
            });
            ans.add(column.stream().map(p -> p.val).collect(Collectors.toList()));
        }
        return ans;
    }

    private void dfs(TreeNode root, int x, int y, Map<Integer, List<Point>> map) {
        if (root == null) {
            return;
        }
        map.putIfAbsent(x, new ArrayList<Point>());
        map.get(x).add(new Point(y, root.val));
        dfs(root.left, x - 1, y + 1, map);
        dfs(root.right, x + 1, y + 1, map);
    }
}

Source - [4]

Method 3 - Using BFS and Distance Queue and TreeMap

We can use a level-by-level traversal (see Binary Tree Level Order Traversal - Level by Level). When we finish a level, we sort the values that share a column on that level. If two nodes are in the same column, then:

  • the node on the higher level (closer to the root) goes first
  • if they are also on the same level, they are ordered from small to large

We do a BFS that carries each node's column in a parallel queue. On each level, we group the values by column in a temporary map, sort each group, and append it to a TreeMap keyed by column, which keeps the columns in left-to-right order.

Code:

import java.util.*;

class Solution {
    public List<List<Integer>> verticalTraversal(TreeNode root) {
        if (root == null) {
            return new LinkedList<>();
        }

        // Using TreeMap as keys (columns) are automatically sorted
        Map<Integer, List<Integer>> map = new TreeMap<>();
        List<List<Integer>> result = new LinkedList<>();

        // queue holds the nodes; distanceQueue holds each node's column in the same order
        Queue<TreeNode> queue = new LinkedList<>();
        queue.add(root);
        Queue<Integer> distanceQueue = new LinkedList<>();
        distanceQueue.add(0);

        while (!queue.isEmpty()) {
            int size = queue.size();
            // Values on the current level, grouped by column
            Map<Integer, List<Integer>> tmp = new HashMap<>();

            while (size > 0) {
                TreeNode current = queue.poll();
                int currDistance = distanceQueue.poll();

                List<Integer> currentList = tmp.getOrDefault(currDistance, new LinkedList<>());
                currentList.add(current.val);
                tmp.put(currDistance, currentList);

                if (current.left != null) {
                    queue.add(current.left);
                    distanceQueue.add(currDistance - 1);
                }

                if (current.right != null) {
                    queue.add(current.right);
                    distanceQueue.add(currDistance + 1);
                }
                size--;
            }
            // Nodes sharing a column on this level share a position, so sort them by value
            for (int key : tmp.keySet()) {
                List<Integer> list = tmp.get(key);
                Collections.sort(list);

                map.putIfAbsent(key, new LinkedList<>());
                map.get(key).addAll(list);
            }
        }

        // left to right, from maximum -ve distance from root to maximum +ve distance from root
        for (int i : map.keySet()) {
            result.add(map.get(i));
        }
        return result;
    }
}