Problem

Given the root of a complete binary tree, return the number of nodes in the tree.

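Definition and Properties of Complete Binary Tree

In a complete binary tree, every level except possibly the last is completely filled, and all nodes on the last level are as far left as possible. A complete tree whose last level is also full is a perfect binary tree. A perfect tree with h levels has exactly 2^h - 1 nodes.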

Examples

Example 1:

      1
    /   \
   2     3
  / \   /
 4   5 6
Input: root = [1,2,3,4,5,6]
Output: 6
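The input `[1,2,3,4,5,6]` lists node values in level order. The methods below assume a LeetCode-style `TreeNode` class with `val`, `left`, and `right` fields, which is not defined here. The following is a minimal sketch of such a class plus a hypothetical helper, `fromLevelOrder`, that builds a complete tree from a level-order array using heap indexing (the children of index `i` sit at `2i + 1` and `2i + 2`).

```java
/** Hypothetical helpers for building a complete binary tree from a level-order array. */
class TreeBuilder {
    // Minimal LeetCode-style node (an assumption; the original does not define TreeNode).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Builds a complete tree: the node at index i gets children at 2i + 1 and 2i + 2.
    // Assumes the array has no gaps (no null entries), which holds for a complete tree.
    static TreeNode fromLevelOrder(int[] vals) {
        if (vals.length == 0) {
            return null;
        }
        TreeNode[] nodes = new TreeNode[vals.length];
        for (int i = 0; i < vals.length; i++) {
            nodes[i] = new TreeNode(vals[i]);
        }
        for (int i = 0; i < vals.length; i++) {
            int l = 2 * i + 1, r = 2 * i + 2;
            if (l < vals.length) nodes[i].left = nodes[l];
            if (r < vals.length) nodes[i].right = nodes[r];
        }
        return nodes[0];
    }
}
```

For Example 1, `fromLevelOrder(new int[]{1, 2, 3, 4, 5, 6})` yields the tree drawn above, with 6 as the left child of 3 and no right child under 3.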

Solution

Method 1 - Basic Brute Force

Code

Java
public int countNodes(TreeNode root) {
    // An empty subtree contributes no nodes
    if (root == null)
        return 0;
    // Count this node plus every node in both subtrees
    return 1 + countNodes(root.left) + countNodes(root.right);
}

However, this solution ignores the properties of a complete binary tree: it visits every node and would work the same way on any binary tree.

Complexity

  • ⏰ Time complexity: O(n), since every node is visited once
  • 🧺 Space complexity: O(h) for the recursion stack, which is O(log n) because a complete binary tree is balanced
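A self-contained sketch of Method 1 that can be run on Example 1. The nested `TreeNode` class and the `fromLevelOrder` helper are assumptions for illustration, not part of the original solution. Because `countNodes` never relies on completeness, it also counts an arbitrary binary tree correctly.

```java
class BruteForceCount {
    // Minimal node type (assumed; mirrors the usual LeetCode TreeNode).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Method 1: count this node plus every node in both subtrees.
    static int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Hypothetical helper: builds a complete tree from a level-order array (heap indexing).
    static TreeNode fromLevelOrder(int[] vals) {
        TreeNode[] nodes = new TreeNode[vals.length];
        for (int i = 0; i < vals.length; i++) nodes[i] = new TreeNode(vals[i]);
        for (int i = 0; i < vals.length; i++) {
            if (2 * i + 1 < vals.length) nodes[i].left = nodes[2 * i + 1];
            if (2 * i + 2 < vals.length) nodes[i].right = nodes[2 * i + 2];
        }
        return vals.length == 0 ? null : nodes[0];
    }
}
```

Calling `countNodes(fromLevelOrder(new int[]{1, 2, 3, 4, 5, 6}))` returns 6, matching Example 1.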

Method 2 - Calculate Leftmost and Rightmost Heights

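Steps to solve this problem:

  1. Get the height of the leftmost path (the number of levels from the root down through left children).
  2. Get the height of the rightmost path.
  3. If they are equal, the tree is perfect, so the number of nodes is 2^h - 1.
  4. If they are not equal, recursively count the nodes in the left and right subtrees and add 1 for the root.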

Case 1 - When the Leftmost and Rightmost Heights Are Equal

Refer to the diagram below: the tree is perfect with h = 3 levels, so the number of nodes is 2^h - 1 = 2^3 - 1 = 7.
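Why 2^h - 1 holds: a perfect tree has 2^k nodes on level k (counting the root as level 0), so the total over h levels is a geometric series. For the Case 1 tree with h = 3:

```latex
N = \sum_{k=0}^{h-1} 2^k = 2^h - 1,
\qquad h = 3:\quad N = 1 + 2 + 4 = 2^3 - 1 = 7
```

The Java code below computes 2^h with a left shift, so (1 << h) - 1 gives the same count.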

Case 2 - When They Are Not

Consider Case 2 in the diagram above, a complete tree with 17 nodes.

At the root, the leftmost height is 5 and the rightmost height is 4.

Since they differ, we recurse into the left subtree, rooted at 2: there l = 4 and r = 3. They are still not equal, so we recurse again, and so on.

On the right, the subtree rooted at 3 is perfect (l = r = 3), so that call returns 2^3 - 1 = 7 right away.
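The Case 2 heights can be checked without building a tree. This sketch assumes the nodes of a complete tree with n nodes are numbered 1..n in level order (heap numbering, which matches the node values in the dry run below), so the left child of node i is 2i and the right child is 2i + 1, and a node exists exactly when its number is at most n. The class and method names are hypothetical.

```java
class HeapHeights {
    // Levels on the leftmost path below node i: keep following 2i while it stays <= n.
    static int leftHeight(int i, int n) {
        int levels = 0;
        for (long j = i; j <= n; j = 2 * j) {
            levels++;
        }
        return levels;
    }

    // Levels on the rightmost path below node i: keep following 2i + 1 while it stays <= n.
    static int rightHeight(int i, int n) {
        int levels = 0;
        for (long j = i; j <= n; j = 2 * j + 1) {
            levels++;
        }
        return levels;
    }
}
```

With n = 17, node 1 gives 5 and 4, node 2 gives 4 and 3, and node 3 gives 3 and 3, matching the numbers above. `long` is used so that `2 * j + 1` cannot overflow for large n.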

Code

Java
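public int countNodes(TreeNode root) {
	if (root == null) {
		return 0;
	}

	// Number of levels along the leftmost and rightmost paths
	int left = getLeftHeight(root) + 1;
	int right = getRightHeight(root) + 1;

	if (left == right) {
		// Perfect tree: 2^h - 1 nodes
		return (1 << left) - 1;
	}
	// Not perfect: count both subtrees plus the root
	return countNodes(root.left) + countNodes(root.right) + 1;
}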

public int getLeftHeight(TreeNode n) {
	if (n == null) return 0;

	int height = 0;
	while (n.left != null) {
		height++;
		n = n.left;
	}
	return height;
}

public int getRightHeight(TreeNode n) {
	if (n == null) return 0;

	int height = 0;
	while (n.right != null) {
		height++;
		n = n.right;
	}
	return height;
}

Complexity

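  • ⏰ Time complexity: O(log² n). Each height walk costs O(h) = O(log n). Wherever the recursion continues, at least one of the two subtrees is perfect and returns right after its height check, so there are O(log n) levels of recursion with O(1) calls each, giving O(log n) · O(log n).
  • 🧺 Space complexity: O(h) = O(log n) for the recursion stack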
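To see the O(log² n) behavior in practice, this sketch instruments the Method 2 code with a hypothetical `calls` counter and runs it on a complete tree with 1,000 nodes. The nested `TreeNode` class and the `buildComplete` helper (heap numbering, nodes valued 1..n) are assumptions for the demo.

```java
class InstrumentedHeights {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    static int calls = 0; // hypothetical counter, not part of the original solution

    static int countNodes(TreeNode root) {
        calls++;
        if (root == null) {
            return 0;
        }
        int left = getLeftHeight(root) + 1;   // levels on the leftmost path
        int right = getRightHeight(root) + 1; // levels on the rightmost path
        if (left == right) {
            return (1 << left) - 1;           // perfect subtree: 2^h - 1
        }
        return countNodes(root.left) + countNodes(root.right) + 1;
    }

    static int getLeftHeight(TreeNode n) {
        int height = 0;
        while (n.left != null) { height++; n = n.left; }
        return height;
    }

    static int getRightHeight(TreeNode n) {
        int height = 0;
        while (n.right != null) { height++; n = n.right; }
        return height;
    }

    // Builds a complete tree with nodes numbered 1..n in level order (children 2i and 2i + 1).
    static TreeNode buildComplete(int n) {
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) nodes[i] = new TreeNode(i);
        for (int i = 1; i <= n; i++) {
            if (2 * i <= n) nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) nodes[i].right = nodes[2 * i + 1];
        }
        return n == 0 ? null : nodes[1];
    }
}
```

The count comes out as 1,000, yet `calls` stays well under 50, because at each level one of the two recursive calls stops right after its height check.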

Method 3 - Without Using Separate Height Method 🏆

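This version walks down the leftmost and rightmost paths together in a single loop instead of calling separate height methods. The loop stops when the rightmost path ends; if the leftmost path has ended at the same point, the tree is perfect (its last level is full), and the answer is just 2^height - 1. Otherwise, it recurses into both subtrees. In a complete tree, at least one of the two subtrees is always perfect, so at least one of the two recursive calls stops immediately after its walk. The runtime is again O(log² n).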
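The key claim, that at least one of the two subtrees of any node in a complete tree is perfect, can be spot-checked by brute force. This sketch again assumes heap numbering (nodes 1..n, children 2i and 2i + 1) and reuses the simultaneous left/right walk of the Method 3 loop; the class and method names are hypothetical.

```java
class PerfectSubtreeCheck {
    // Mirrors the Method 3 loop: walk the leftmost (l) and rightmost (r) paths
    // together until r runs off the tree; the subtree is perfect iff l has too.
    // An empty subtree (i > n) counts as perfect, matching countNodes(null) == 0.
    static boolean isPerfect(long i, int n) {
        long l = i, r = i;
        while (r <= n) {
            l = 2 * l;
            r = 2 * r + 1;
        }
        return l > n;
    }

    // True if every node of a complete tree with n nodes has a perfect left or right subtree.
    static boolean oneChildAlwaysPerfect(int n) {
        for (long i = 1; i <= n; i++) {
            if (!isPerfect(2 * i, n) && !isPerfect(2 * i + 1, n)) {
                return false;
            }
        }
        return true;
    }
}
```

For the Case 2 tree (n = 17), `isPerfect(2, 17)` is false and `isPerfect(3, 17)` is true, which is why only the left side keeps recursing.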

Code

Java
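public int countNodes(TreeNode root) {
	if (root == null)
		return 0;
	TreeNode left = root, right = root;
	int height = 0;
	// Walk the leftmost and rightmost paths together. In a complete tree the
	// leftmost path is never shorter, so stop when the right walker falls off.
	while (right != null) {
		left = left.left;
		right = right.right;
		height++;
	}
	// If the left walker fell off too, the tree is perfect: 2^height - 1 nodes
	if (left == null) {
		return (1 << height) - 1;
	}
	// Otherwise count the root plus both subtrees (at least one of them is perfect)
	return 1 + countNodes(root.left) + countNodes(root.right);
}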
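A self-contained, runnable sketch of Method 3. The nested `TreeNode` class and the hypothetical `fromLevelOrder` helper are assumptions added so the code can be tried on Example 1 and on the 17-node Case 2 tree.

```java
class CountCompleteNodes {
    // Minimal node type (assumed; mirrors the usual LeetCode TreeNode).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Method 3: walk the leftmost and rightmost paths in one loop.
    static int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        TreeNode left = root, right = root;
        int height = 0;
        // Safe for complete trees: the left walker never falls off before the right one.
        while (right != null) {
            left = left.left;
            right = right.right;
            height++;
        }
        if (left == null) {
            return (1 << height) - 1; // perfect tree: 2^height - 1 nodes
        }
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Hypothetical helper: builds a complete tree from a level-order array (heap indexing).
    static TreeNode fromLevelOrder(int[] vals) {
        TreeNode[] nodes = new TreeNode[vals.length];
        for (int i = 0; i < vals.length; i++) nodes[i] = new TreeNode(vals[i]);
        for (int i = 0; i < vals.length; i++) {
            if (2 * i + 1 < vals.length) nodes[i].left = nodes[2 * i + 1];
            if (2 * i + 2 < vals.length) nodes[i].right = nodes[2 * i + 2];
        }
        return vals.length == 0 ? null : nodes[0];
    }
}
```

On `[1,2,3,4,5,6]` it returns 6, and on `[1,2,...,17]` it returns 17.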

Dry Run

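Let's dry-run Method 3 on the tree from Case 2: root = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17]. Because the values are listed in level order, each node's value equals its level-order position, so the children of node v are 2v and 2v + 1 (when those exist).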

First call - root val = 1
left = root (node value 1)
right = root (node value 1)

Walk down the leftmost and rightmost paths together:

Iteration 1 => 
left = left.left (node value 2)
right = right.right (node value 3)
height++ => 1

Iteration 2 => 
left = left.left (node value 4)
right = right.right (node value 7)
height++ => 2

Iteration 3 => 
left = left.left (node value 8)
right = right.right (node value 15)
height++ => 3

Iteration 4 => 
left = left.left (node value 16)
right = right.right (null)
height++ => 4
  • right is now null, so the loop exits.
  • left != null ⇨ the tree is not a perfect binary tree.

So, we recurse with:

return 1 + countNodes(root.left) + countNodes(root.right)
=> 
return 1 + countNodes(2) + countNodes(3)

Second call - root.left - root val = 2

Iteration 1 => 
left = left.left (node value 4)
right = right.right (node value 5)
height++ => 1

Iteration 2 => 
left = left.left (node value 8)
right = right.right (node value 11)
height++ => 2

Iteration 3 => 
left = left.left (node value 16)
right = right.right (null)
height++ => 3

Again left != null, so this subtree is not perfect, and we recurse.

So, we recurse with:

return 1 + countNodes(root.left) + countNodes(root.right)
=> 
return 1 + countNodes(4) + countNodes(5)

Third call - root.left.left - root val = 4

Iteration 1 => 
left = left.left (node value 8)
right = right.right (node value 9)
height++ => 1

Iteration 2 => 
left = left.left (node value 16)
right = right.right (null)
height++ => 2

Again left != null, so this subtree is not perfect, and we recurse.

So, we recurse with:

return 1 + countNodes(root.left) + countNodes(root.right)
=> 
return 1 + countNodes(8) + countNodes(9)

Fourth call - root.left.left.left - root val = 8

Iteration 1 => 
left = left.left (node value 16)
right = right.right (node value 17)
height++ => 1

Iteration 2 => 
left = left.left (null)
right = right.right (null)
height++ => 2

This time left == null, so the subtree is perfect, and we return:

nodes = 2 ^ 2 - 1 = 3

Fifth call - root.left.left.right - root val = 9

Iteration 1 => 
left = left.left (null)
right = right.right (null)
height++ => 1

Again left == null, so the subtree is perfect (a single leaf), and we return:

nodes = 2 ^ 1 - 1 = 1

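Combining the results on the way back up:

  • countNodes(4) = 1 + 3 + 1 = 5
  • countNodes(5) is perfect (nodes 5, 10, 11), so it returns 2 ^ 2 - 1 = 3
  • countNodes(2) = 1 + 5 + 3 = 9
  • countNodes(3) is perfect (nodes 3, 6, 7, 12, 13, 14, 15), so it returns 2 ^ 3 - 1 = 7
  • countNodes(1) = 1 + 9 + 7 = 17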
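To check the dry run mechanically, this sketch adds a hypothetical `trace` map to Method 3 that records, for each call, the root's value and the value that call returned. It builds the 17-node tree with heap numbering (an assumption for the demo), so node values match the dry run.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class TracedCount {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Hypothetical trace: root value -> value returned, in the order the calls finish.
    static final Map<Integer, Integer> trace = new LinkedHashMap<>();

    static int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        TreeNode left = root, right = root;
        int height = 0;
        while (right != null) {
            left = left.left;
            right = right.right;
            height++;
        }
        int result;
        if (left == null) {
            result = (1 << height) - 1; // perfect subtree
        } else {
            result = 1 + countNodes(root.left) + countNodes(root.right);
        }
        trace.put(root.val, result);
        return result;
    }

    // Builds a complete tree with nodes numbered 1..n in level order (children 2i and 2i + 1).
    static TreeNode buildComplete(int n) {
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) nodes[i] = new TreeNode(i);
        for (int i = 1; i <= n; i++) {
            if (2 * i <= n) nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) nodes[i].right = nodes[2 * i + 1];
        }
        return n == 0 ? null : nodes[1];
    }
}
```

Running `countNodes(buildComplete(17))` records seven calls (nodes 8, 9, 4, 5, 2, 3, 1 as they finish) with the same results as the dry run above.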