Problem
Given the root of a complete binary tree, return the number of nodes in the tree.
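Definition and Properties of Complete Binary Tree: every level except possibly the last is completely filled, and all nodes in the last level are as far left as possible.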
Examples
Example 1:
       1
     /   \
    2     3
   / \   /
  4   5 6
Input: root = [1,2,3,4,5,6]
Output: 6
Solution
Method 1 - Basic Brute Force
Code
Java
public int countNodes(TreeNode root) {
    // An empty tree has no nodes.
    if (root == null) {
        return 0;
    }
    // Count the root, then every node in both subtrees.
    return 1 + countNodes(root.left) + countNodes(root.right);
}
This solution is correct, but it ignores the properties of a complete binary tree and simply visits every node.
Complexity
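- ⏰ Time complexity: O(n), since every node is visited exactly once.
- 🧺 Space complexity: O(h) for the recursion stack, where h is the height. A complete binary tree with n nodes has h = ⌊log2 n⌋ + 1 levels, so this is O(log n).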
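To try Method 1 end to end, here is a minimal, self-contained sketch. The nested TreeNode class is an assumed minimal version of the usual LeetCode-style node, and buildComplete is a hypothetical helper that builds a complete tree holding 1..n in level order, so buildComplete(6) is the tree from Example 1.

```java
// Minimal sketch of Method 1 (brute force).
// TreeNode and buildComplete are assumptions for illustration, not part of the original solution.
class BruteForceCount {
    static class TreeNode {
        int val;
        TreeNode left, right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    // Visits every node once: O(n) time, O(h) recursion depth.
    static int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Hypothetical helper: complete tree holding 1..n in level order
    // (node i gets children 2i and 2i + 1 when they exist).
    static TreeNode buildComplete(int n) {
        if (n <= 0) {
            return null;
        }
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) {
            nodes[i] = new TreeNode(i);
        }
        for (int i = 1; 2 * i <= n; i++) {
            nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) {
                nodes[i].right = nodes[2 * i + 1];
            }
        }
        return nodes[1];
    }

    public static void main(String[] args) {
        System.out.println(countNodes(buildComplete(6))); // prints 6 (Example 1)
    }
}
```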
Method 2 - Calculate Leftmost and Rightmost Heights
Steps to solve this problem:
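- Get the height of the leftmost path (keep following left children from the root).
- Get the height of the rightmost path (keep following right children from the root).
- When they are equal, the tree is perfect and the number of nodes is 2^h - 1.
- When they are not equal, recursively count the nodes in the left and right subtrees and add 1 for the root.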
Case 1 - When the Leftmost and Rightmost Heights Are Equal
Refer to the diagram below. When both heights equal h, every level is full (the tree is perfect), so the number of nodes = 2^h - 1. For h = 3, that is 2^3 - 1 = 7.
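The formula follows from counting level by level. Measuring height as the number of levels, a perfect tree of height h has 2^i nodes on level i (the root is level 0), and the geometric sum collapses to 2^h - 1:

```latex
N(h) = \sum_{i=0}^{h-1} 2^i = 2^h - 1,
\qquad
N(3) = 1 + 2 + 4 = 2^3 - 1 = 7
```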
Case 2 - When They Are Not Equal
Consider case 2 in the diagram above. The left height is 5 and the right height is 4 (counting nodes along each path), so we recurse.
In the left subtree, rooted at 2, we get l = 4 and r = 3. Again l and r are not equal, so we recurse again, and so on.
On the right, the subtree rooted at 3 is a perfect tree with l = r = 3, so that call returns 2^3 - 1 = 7 immediately.
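The heights quoted above can be checked with a small sketch. It assumes the case 2 tree is the complete tree holding nodes 1..17 in level order (the same tree used in the dry run), and it counts heights in nodes. TreeNode, buildComplete, leftHeight and rightHeight are hypothetical helpers written for this illustration.

```java
// Sketch: leftmost/rightmost heights (in nodes) for the 17-node complete tree.
// TreeNode, buildComplete, leftHeight and rightHeight are illustrative assumptions.
class OuterHeights {
    static class TreeNode {
        int val;
        TreeNode left, right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    // Number of nodes on the path that always follows left children.
    static int leftHeight(TreeNode n) {
        int h = 0;
        for (; n != null; n = n.left) {
            h++;
        }
        return h;
    }

    // Number of nodes on the path that always follows right children.
    static int rightHeight(TreeNode n) {
        int h = 0;
        for (; n != null; n = n.right) {
            h++;
        }
        return h;
    }

    // Hypothetical helper: complete tree holding 1..n in level order.
    static TreeNode buildComplete(int n) {
        if (n <= 0) {
            return null;
        }
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) {
            nodes[i] = new TreeNode(i);
        }
        for (int i = 1; 2 * i <= n; i++) {
            nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) {
                nodes[i].right = nodes[2 * i + 1];
            }
        }
        return nodes[1];
    }

    public static void main(String[] args) {
        TreeNode root = buildComplete(17);
        System.out.println(leftHeight(root) + " " + rightHeight(root));             // 5 4
        System.out.println(leftHeight(root.left) + " " + rightHeight(root.left));   // 4 3
        System.out.println(leftHeight(root.right) + " " + rightHeight(root.right)); // 3 3
    }
}
```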
Code
Java
public int countNodes(TreeNode root) {
    if (root == null) {
        return 0;
    }
    // Heights in nodes: edges on each outer path, plus 1 for the root.
    int left = getLeftHeight(root) + 1;
    int right = getRightHeight(root) + 1;
    if (left == right) {
        // Perfect tree: 2^left - 1 nodes.
        return (1 << left) - 1;
    }
    // Not perfect: count the root plus both subtrees.
    return countNodes(root.left) + countNodes(root.right) + 1;
}

// Number of edges on the path that always goes left.
public int getLeftHeight(TreeNode n) {
    if (n == null) return 0;
    int height = 0;
    while (n.left != null) {
        height++;
        n = n.left;
    }
    return height;
}

// Number of edges on the path that always goes right.
public int getRightHeight(TreeNode n) {
    if (n == null) return 0;
    int height = 0;
    while (n.right != null) {
        height++;
        n = n.right;
    }
    return height;
}
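A self-contained, runnable sketch of Method 2 follows, assuming the same minimal TreeNode and a hypothetical buildComplete helper. Heights are counted in nodes directly, which removes the +1 adjustments; otherwise the logic matches the code above.

```java
// Sketch of Method 2: compare outer heights, recurse only when they differ.
// TreeNode and buildComplete are illustrative assumptions.
class OuterHeightCount {
    static class TreeNode {
        int val;
        TreeNode left, right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    static int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        int left = 0;
        for (TreeNode n = root; n != null; n = n.left) {
            left++; // nodes on the leftmost path
        }
        int right = 0;
        for (TreeNode n = root; n != null; n = n.right) {
            right++; // nodes on the rightmost path
        }
        if (left == right) {
            return (1 << left) - 1; // perfect tree
        }
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Hypothetical helper: complete tree holding 1..n in level order.
    static TreeNode buildComplete(int n) {
        if (n <= 0) {
            return null;
        }
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) {
            nodes[i] = new TreeNode(i);
        }
        for (int i = 1; 2 * i <= n; i++) {
            nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) {
                nodes[i].right = nodes[2 * i + 1];
            }
        }
        return nodes[1];
    }

    public static void main(String[] args) {
        System.out.println(countNodes(buildComplete(6)));  // prints 6
        System.out.println(countNodes(buildComplete(17))); // prints 17
    }
}
```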
Complexity
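- ⏰ Time complexity: O(h^2) = O((log n)^2). In a complete binary tree, at least one of the two subtrees is perfect and returns right after its height check, so only one chain of calls keeps recursing, at most h levels deep, and each call spends O(h) walking the leftmost and rightmost paths.
- 🧺 Space complexity: O(h) = O(log n), for the recursion stack.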
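To see why the recursion stays shallow, the sketch below instruments a Method 2 style count with a call counter (the static calls field is an addition for illustration, as are TreeNode and buildComplete). It is meant to check that, for complete trees of 1 to 2000 nodes, the number of calls never exceeds 2h + 1, where h = ⌊log2 n⌋ + 1 is the height in nodes. Each call still costs O(h) for the height walks, which gives the O(h^2) total.

```java
// Sketch: count how many times countNodes runs on complete trees of various sizes.
// The calls counter, TreeNode and buildComplete are illustrative assumptions.
class CallCounter {
    static class TreeNode {
        int val;
        TreeNode left, right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    static int calls = 0;

    static int countNodes(TreeNode root) {
        calls++;
        if (root == null) {
            return 0;
        }
        int left = 0;
        for (TreeNode n = root; n != null; n = n.left) {
            left++;
        }
        int right = 0;
        for (TreeNode n = root; n != null; n = n.right) {
            right++;
        }
        if (left == right) {
            return (1 << left) - 1;
        }
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Height in nodes of a complete tree with n >= 1 nodes: floor(log2 n) + 1.
    static int height(int n) {
        return 32 - Integer.numberOfLeadingZeros(n);
    }

    // Hypothetical helper: complete tree holding 1..n in level order.
    static TreeNode buildComplete(int n) {
        if (n <= 0) {
            return null;
        }
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) {
            nodes[i] = new TreeNode(i);
        }
        for (int i = 1; 2 * i <= n; i++) {
            nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) {
                nodes[i].right = nodes[2 * i + 1];
            }
        }
        return nodes[1];
    }

    public static void main(String[] args) {
        for (int n : new int[] {6, 17, 1000}) {
            calls = 0;
            int count = countNodes(buildComplete(n));
            System.out.println("n=" + n + " count=" + count + " calls=" + calls
                    + " bound=" + (2 * height(n) + 1));
        }
    }
}
```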
Method 3 - Without Separate Height Methods 🏆
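This version walks the leftmost and rightmost paths together in a single loop to find the height and whether the tree is full (perfect), meaning the last row is full. The loop stops when the right pointer falls off the tree; because the leftmost path of a complete tree is never shorter than the rightmost one, the tree is perfect exactly when the left pointer is null at that point too. If so, the answer is just 2^height - 1. Otherwise we recurse on both subtrees, and since at least one of the two subtrees is always perfect, at least one of the two recursive calls stops immediately. Again, the runtime is O((log n)^2).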
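The two facts this method leans on can be spot-checked by brute force: the leftmost path is at most one node longer than the rightmost path (with equality exactly when the subtree is perfect), and at least one child subtree of every node is perfect. This is a sketch under stated assumptions: isPerfect, checkAll, TreeNode and buildComplete are hypothetical helpers, and the check only covers the sizes it tries (complete trees of 0 to 300 nodes), so it illustrates the claim rather than proving it.

```java
// Brute-force spot check of the complete-tree facts Method 3 relies on.
// TreeNode, buildComplete, isPerfect and checkAll are illustrative assumptions.
class CompleteTreeFacts {
    static class TreeNode {
        int val;
        TreeNode left, right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    static int size(TreeNode n) {
        return n == null ? 0 : 1 + size(n.left) + size(n.right);
    }

    static int depth(TreeNode n) {
        return n == null ? 0 : 1 + Math.max(depth(n.left), depth(n.right));
    }

    // Perfect means every level is full; an empty subtree counts as perfect (0 = 2^0 - 1).
    static boolean isPerfect(TreeNode n) {
        return size(n) == (1 << depth(n)) - 1;
    }

    // Checks both facts at node n and, recursively, at every node below it.
    static boolean checkAll(TreeNode n) {
        int l = 0;
        for (TreeNode t = n; t != null; t = t.left) {
            l++;
        }
        int r = 0;
        for (TreeNode t = n; t != null; t = t.right) {
            r++;
        }
        boolean heightsOk = (l == r || l == r + 1) && ((l == r) == isPerfect(n));
        if (n == null) {
            return heightsOk;
        }
        boolean childOk = isPerfect(n.left) || isPerfect(n.right);
        return heightsOk && childOk && checkAll(n.left) && checkAll(n.right);
    }

    // Hypothetical helper: complete tree holding 1..n in level order.
    static TreeNode buildComplete(int n) {
        if (n <= 0) {
            return null;
        }
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) {
            nodes[i] = new TreeNode(i);
        }
        for (int i = 1; 2 * i <= n; i++) {
            nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) {
                nodes[i].right = nodes[2 * i + 1];
            }
        }
        return nodes[1];
    }

    public static void main(String[] args) {
        boolean allOk = true;
        for (int n = 0; n <= 300; n++) {
            allOk &= checkAll(buildComplete(n));
        }
        System.out.println(allOk ? "both facts hold for n = 0..300" : "counterexample found");
    }
}
```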
Code
Java
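public int countNodes(TreeNode root) {
    if (root == null) {
        return 0;
    }
    TreeNode left = root, right = root;
    int height = 0;
    // Walk both outer paths in lockstep until the rightmost path ends.
    // If right is null, left may still exist, since the tree is a complete BT
    // (its leftmost path is never shorter), so left.left is safe here.
    while (right != null) {
        left = left.left;
        right = right.right;
        height++;
    }
    // Both paths ended together: the tree is perfect.
    if (left == null) {
        return (1 << height) - 1;
    }
    // The leftmost path is one level deeper: count the root plus both subtrees.
    return 1 + countNodes(root.left) + countNodes(root.right);
}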
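Method 3 runs as is once a TreeNode type exists. The sketch below wraps it with an assumed minimal TreeNode and a hypothetical buildComplete helper so it can be compared with the true size of every complete tree from 0 to 500 nodes.

```java
// Sketch of Method 3 (lockstep walk), made runnable.
// TreeNode and buildComplete are illustrative assumptions.
class LockstepCount {
    static class TreeNode {
        int val;
        TreeNode left, right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    static int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        TreeNode left = root, right = root;
        int height = 0;
        // Both pointers move one level per iteration; the loop ends with the rightmost path.
        while (right != null) {
            left = left.left;
            right = right.right;
            height++;
        }
        if (left == null) {
            return (1 << height) - 1; // perfect tree
        }
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Hypothetical helper: complete tree holding 1..n in level order.
    static TreeNode buildComplete(int n) {
        if (n <= 0) {
            return null;
        }
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) {
            nodes[i] = new TreeNode(i);
        }
        for (int i = 1; 2 * i <= n; i++) {
            nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) {
                nodes[i].right = nodes[2 * i + 1];
            }
        }
        return nodes[1];
    }

    public static void main(String[] args) {
        System.out.println(countNodes(buildComplete(17))); // prints 17
    }
}
```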
Dry Run
Look at case 2 again for the dry run: root = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17]
First call - root val = 1
left = root (node value 1)
right = root (node value 1)
Walk down both outer paths until right becomes null:
Iteration 1 =>
left = left.left (node value 2)
right = right.right (node value 3)
height++ => 1
Iteration 2 =>
left = left.left (node value 4)
right = right.right (node value 7)
height++ => 2
Iteration 3 =>
left = left.left (node value 8)
right = right.right (node value 15)
height++ => 3
Iteration 4 =>
left = left.left (node value 16)
right = right.right (null)
height++ => 4
right is now null, so the loop exits.
Since left != null ⇨ the tree is not a perfect binary tree.
So, we recurse with:
return 1 + countNodes(root.left) + countNodes(root.right)
=>
return 1 + countNodes(2) + countNodes(3)
Second call - root.left - root val = 2
Iteration 1 =>
left = left.left (node value 4)
right = right.right (node value 5)
height++ => 1
Iteration 2 =>
left = left.left (node value 8)
right = right.right (node value 11)
height++ => 2
Iteration 3 =>
left = left.left (node value 16)
right = right.right (null)
height++ => 3
Again, the tree is not perfect (left != null), so we recurse with:
return 1 + countNodes(root.left) + countNodes(root.right)
=>
return 1 + countNodes(4) + countNodes(5)
Third call - root.left.left - root val = 4
Iteration 1 =>
left = left.left (node value 8)
right = right.right (node value 9)
height++ => 1
Iteration 2 =>
left = left.left (node value 16)
right = right.right (null)
height++ => 2
Again, the tree is not perfect (left != null), so we recurse with:
return 1 + countNodes(root.left) + countNodes(root.right)
=>
return 1 + countNodes(8) + countNodes(9)
Fourth call - root.left.left.left - root val = 8
Iteration 1 =>
left = left.left (node value 16)
right = right.right (node value 17)
height++ => 1
Iteration 2 =>
left = left.left (null)
right = right.right (null)
height++ => 2
This time the tree is perfect (left == null), so we calculate:
nodes = 2^2 - 1 = 3
Fifth call - root.left.left.right - root val = 9
Iteration 1 =>
left = left.left (null)
right = right.right (null)
height++ => 1
The tree is perfect again (left == null), so we calculate:
nodes = 2^1 - 1 = 1
…
The remaining calls proceed similarly, and the results are combined as the recursion unwinds.
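To replay the dry run mechanically, here is a sketch that records one line per countNodes call, in call order, and returns the total. The trace list, countWithTrace, TreeNode and buildComplete are assumptions added for illustration; the counting logic is Method 3 unchanged.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: Method 3 with a trace of every non-null call, for replaying the dry run.
// TreeNode, buildComplete, countWithTrace and the trace list are illustrative assumptions.
class TracedLockstepCount {
    static class TreeNode {
        int val;
        TreeNode left, right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    static final List<String> trace = new ArrayList<>();

    static int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        TreeNode left = root, right = root;
        int height = 0;
        while (right != null) {
            left = left.left;
            right = right.right;
            height++;
        }
        if (left == null) {
            int nodes = (1 << height) - 1;
            trace.add("root=" + root.val + " height=" + height + " perfect -> " + nodes);
            return nodes;
        }
        // Logged before recursing, so the trace follows call order.
        trace.add("root=" + root.val + " height=" + height + " recurse");
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Clears the trace, then counts.
    static int countWithTrace(TreeNode root) {
        trace.clear();
        return countNodes(root);
    }

    // Hypothetical helper: complete tree holding 1..n in level order.
    static TreeNode buildComplete(int n) {
        if (n <= 0) {
            return null;
        }
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int i = 1; i <= n; i++) {
            nodes[i] = new TreeNode(i);
        }
        for (int i = 1; 2 * i <= n; i++) {
            nodes[i].left = nodes[2 * i];
            if (2 * i + 1 <= n) {
                nodes[i].right = nodes[2 * i + 1];
            }
        }
        return nodes[1];
    }

    public static void main(String[] args) {
        int total = countWithTrace(buildComplete(17));
        trace.forEach(System.out::println);
        System.out.println("total = " + total);
    }
}
```

The first five trace lines correspond to the five calls walked through above.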