Problem

Implement a function to check whether a binary tree is height-balanced.

Definition

Height-balanced binary tree: a binary tree in which, for every node, the heights of the left and right subtrees never differ by more than 1.
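Stated formally, the definition is as follows. Here h(T) is the height of a (sub)tree T counted in nodes, so the empty tree has height 0 (the convention the solutions below use), and L_v and R_v are shorthand introduced here for the left and right subtrees of a node v:

```latex
h(\varnothing) = 0, \qquad h(v) = 1 + \max\bigl(h(L_v),\, h(R_v)\bigr)

T \text{ is height-balanced} \iff \bigl|\, h(L_v) - h(R_v) \,\bigr| \le 1 \ \text{ for every node } v \in T
```

In Example 2 below, the condition fails at the root: h(L) = 3 and h(R) = 1.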

Examples

Example 1:

  3
 / \
9   20
   / \
  15  7
Input: root = [3,9,20,null,null,15,7]
Output: true

Example 2:

        1
       / \
      2   2
     / \
    3   3
   / \
  4   4
Input: root = [1,2,2,3,3,null,null,4,4]
Output: false

Example 3:

Input: root = []
Output: true
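The Input lines use a level-order encoding, as in the examples: values are listed level by level, left to right, null marks a missing child, and children of missing nodes (as well as trailing nulls) are omitted. A minimal Java sketch of decoding that format, assuming a simple TreeNode class with val/left/right fields; TreeNode, LevelOrder and fromLevelOrder are illustrative names, not part of the original problem:

```java
import java.util.ArrayDeque;
import java.util.Queue;

class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}

class LevelOrder {
    // Hypothetical helper: builds a tree from a level-order array in which
    // null marks a missing child and children of missing nodes are not listed.
    static TreeNode fromLevelOrder(Integer[] values) {
        if (values.length == 0 || values[0] == null) {
            return null; // empty tree
        }
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> parents = new ArrayDeque<>(); // nodes still waiting for children
        parents.add(root);
        int i = 1;
        while (i < values.length && !parents.isEmpty()) {
            TreeNode parent = parents.remove();
            // the next value is this parent's left child
            if (values[i] != null) {
                parent.left = new TreeNode(values[i]);
                parents.add(parent.left);
            }
            i++;
            // the value after it (if present) is the right child
            if (i < values.length && values[i] != null) {
                parent.right = new TreeNode(values[i]);
                parents.add(parent.right);
            }
            i++;
        }
        return root;
    }

    // Height = number of nodes on the longest root-to-leaf path (0 for an empty tree)
    static int height(TreeNode node) {
        return node == null ? 0 : 1 + Math.max(height(node.left), height(node.right));
    }
}
```

For Example 1 this yields a tree of height 3 whose node 20 has children 15 and 7; for Example 2 it yields a tree of height 4; for Example 3 it returns null.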

Solution

Method 1 - Difference of Height at Each Node Should Not Be Greater Than 1

Code

Java
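public boolean isBalanced(Node root) {
    if (root == null) {
        return true; // an empty tree is balanced
    }
    int lh = height(root.left);
    int rh = height(root.right);

    // balanced if both subtrees are balanced and their heights differ by at most 1
    return isBalanced(root.left) && isBalanced(root.right)
        && Math.abs(lh - rh) <= 1;
}

// Height = number of nodes on the longest root-to-leaf path (0 for an empty tree)
private int height(Node node) {
    if (node == null) {
        return 0;
    }
    return 1 + Math.max(height(node.left), height(node.right));
}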

Complexity

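  • Time complexity - O(n^2) in the worst case. Computing the height of a subtree takes time proportional to its size, and the heights are recomputed at every node on the way down; for a skewed (linked-list-like) tree this adds up to O(n^2).
  • Space complexity - O(h) for the recursion stack, where h is the height of the tree (O(n) in the worst case).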

Refer to the problem Maximum Depth of Binary Tree for more on computing the height.
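To make the quadratic worst case concrete, here is a small Java experiment: a sketch in which the TreeNode class, the heightVisits counter and the chain helper are illustrative additions. It counts how many nodes height() visits when Method 1 runs on a skewed tree of n nodes, each with only a left child. A node whose subtree has k nodes triggers a height() call over its k - 1 descendants, and Method 1 recurses into every node, so the total is 0 + 1 + ... + (n - 1) = n(n - 1)/2.

```java
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}

class Method1Cost {
    static long heightVisits = 0; // number of non-null nodes visited by height()

    static int height(TreeNode node) {
        if (node == null) {
            return 0;
        }
        heightVisits++;
        return 1 + Math.max(height(node.left), height(node.right));
    }

    // Method 1, with the same evaluation order as the code above
    static boolean isBalanced(TreeNode root) {
        if (root == null) {
            return true;
        }
        int lh = height(root.left);
        int rh = height(root.right);
        return isBalanced(root.left) && isBalanced(root.right)
            && Math.abs(lh - rh) <= 1;
    }

    // Hypothetical helper: a skewed tree of n nodes, each with only a left child
    static TreeNode chain(int n) {
        TreeNode root = null;
        for (int i = 0; i < n; i++) {
            TreeNode node = new TreeNode(i);
            node.left = root;
            root = node;
        }
        return root;
    }
}
```

For n = 1000 the counter ends at 499,500, even though the tree has only 1,000 nodes.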

Method 2 - Optimise Height Calculation From Method 1

When the tree is heavily unbalanced, for example a million nodes deep on one side and three deep on the other, Method 1 becomes very slow, because the heights on the deep side are recomputed at every level of the recursion. To fix this, we can change how the height is calculated.

Method 1 computes the same subtree heights again and again. Instead, a single recursive function can report both whether a subtree is balanced and its height; in C, the height can be returned through a pointer:

Code

C
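#include <stdbool.h>
#include <stdlib.h>

typedef struct Node {
    int val;
    struct Node *left, *right;
} Node;

// Returns whether the tree is balanced and stores its height in *height
bool isBalanced(Node *root, int *height) {
    if (root == NULL) {
        *height = 0; // an empty tree has height 0 and is balanced
        return true;
    }
    int heightLeft = 0, heightRight = 0;
    // each recursive call fills in the height of its subtree
    bool balanced = isBalanced(root->left, &heightLeft) &&
                    isBalanced(root->right, &heightRight);
    *height = (heightLeft > heightRight ? heightLeft : heightRight) + 1;
    return balanced && abs(heightLeft - heightRight) <= 1;
}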
Java

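How do we do that in Java? Java passes every argument by value (object references included), so a method cannot assign to a caller's int variable the way the C version writes through a pointer. Instead, we can return an object that holds both the height and the balanced flag.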
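A tiny Java sketch of the difference (PassByValueDemo, Box, setInt and setBox are made-up names for illustration): assigning to an int parameter changes only the method's local copy, while writing to a field of an object that was passed in is visible to the caller, because the copied reference still points to the same object.

```java
class PassByValueDemo {
    static class Box {
        int value;
    }

    // Assigns to the parameter: only the local copy changes
    static void setInt(int height) {
        height = 42;
    }

    // Writes through the (copied) reference: the caller's object changes
    static void setBox(Box height) {
        height.value = 42;
    }

    static int[] run() {
        int plain = 0;
        setInt(plain);
        Box box = new Box();
        setBox(box);
        return new int[] {plain, box.value};
    }
}
```

PassByValueDemo.run() returns {0, 42}: the int is unchanged, while the Box field is updated.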

Using Pair class
public boolean isBalanced(TreeNode root) {
    return heightBalancedPair(root).balanced;
}

// Holds the height of a subtree together with whether it is balanced
static class HeightBalancedPair {
    final int height;
    final boolean balanced;

    HeightBalancedPair(int height, boolean balanced) {
        this.height = height;
        this.balanced = balanced;
    }
}

private HeightBalancedPair heightBalancedPair(TreeNode root) {
    if (root == null) {
        return new HeightBalancedPair(0, true); // empty subtree
    }
    HeightBalancedPair left = heightBalancedPair(root.left);
    HeightBalancedPair right = heightBalancedPair(root.right);

    int height = Math.max(left.height, right.height) + 1;
    boolean balanced = Math.abs(left.height - right.height) <= 1
        && left.balanced && right.balanced;

    return new HeightBalancedPair(height, balanced);
}
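On Java 16 or later, the pair can be written as a record, which generates the constructor and accessor methods automatically. A minimal sketch assuming Java 16+; TreeNode, BalancedWithRecord and HeightInfo are illustrative names:

```java
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class BalancedWithRecord {
    // Immutable (height, balanced) pair; records require Java 16+
    record HeightInfo(int height, boolean balanced) {}

    static boolean isBalanced(TreeNode root) {
        return info(root).balanced();
    }

    private static HeightInfo info(TreeNode root) {
        if (root == null) {
            return new HeightInfo(0, true); // empty subtree
        }
        HeightInfo left = info(root.left);
        HeightInfo right = info(root.right);
        int height = Math.max(left.height(), right.height()) + 1;
        boolean balanced = left.balanced() && right.balanced()
            && Math.abs(left.height() - right.height()) <= 1;
        return new HeightInfo(height, balanced);
    }
}
```

With the trees from Examples 1, 2 and 3, isBalanced returns true, false and true respectively.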
Using -1 to denote “not balanced”

Another way to write the above code is to drop the boolean entirely and return only the height, using -1 as a marker. We check the height of each subtree as we recurse down from the root: if the subtree is balanced, the function returns its actual height; if it is not, the function returns -1, and every caller returns -1 as soon as it sees it, skipping any subtree it has not yet examined.

public int checkHeight(TreeNode root) {
    if (root == null)
        return 0;

    int leftHeight = checkHeight(root.left);
    if (leftHeight == -1)
        return -1; //not balanced

    int rightHeight = checkHeight(root.right);
    if (rightHeight == -1)
        return -1;

    int heightDiff = Math.abs(leftHeight - rightHeight);
    if (heightDiff > 1)
        return -1;
    else
        return Math.max(leftHeight, rightHeight) + 1;
}

public boolean isBalanced(TreeNode root) {
    return checkHeight(root) != -1; // -1 means some subtree is unbalanced
}

Complexity

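  • Time complexity - O(n), since every node is visited at most once and does O(1) work.
  • Space complexity - O(h) for the recursion stack, where h is the height of the tree.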

Though this looks like a small change, it lowers the worst-case running time from O(n^2) to O(n).
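The same kind of experiment on Method 2 (again a sketch; the visits counter and the chain helper are illustrative additions) shows each node being entered at most once. On a 1,000-node tree in which every node has only a left child, checkHeight enters exactly 1,000 nodes, whereas Method 1 performs n(n - 1)/2 = 499,500 height() visits on that same tree.

```java
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}

class Method2Cost {
    static long visits = 0; // number of non-null nodes entered by checkHeight()

    // checkHeight from Method 2 with a visit counter; -1 means "not balanced"
    static int checkHeight(TreeNode root) {
        if (root == null) {
            return 0;
        }
        visits++;
        int leftHeight = checkHeight(root.left);
        if (leftHeight == -1) {
            return -1; // left subtree is unbalanced: stop immediately
        }
        int rightHeight = checkHeight(root.right);
        if (rightHeight == -1) {
            return -1;
        }
        if (Math.abs(leftHeight - rightHeight) > 1) {
            return -1;
        }
        return Math.max(leftHeight, rightHeight) + 1;
    }

    // Hypothetical helper: a skewed tree of n nodes, each with only a left child
    static TreeNode chain(int n) {
        TreeNode root = null;
        for (int i = 0; i < n; i++) {
            TreeNode node = new TreeNode(i);
            node.left = root;
            root = node;
        }
        return root;
    }
}
```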

Method 3 - Difference Between minDepth and maxDepth Should Not Exceed 1

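Here a tree is considered balanced when the difference between its min depth and max depth does not exceed 1. Note that minDepth below stops at the first missing child (not only at leaves), and that this condition is stricter than the height-balanced definition above: every tree that passes it is height-balanced, but some height-balanced trees fail it.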

Both depths can be computed with short recursive functions:

Code

C
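#include <stdbool.h>
#include <stddef.h>

typedef struct Node {
    int val;
    struct Node *left, *right;
} Node;

static int min(int a, int b) { return a < b ? a : b; }
static int max(int a, int b) { return a > b ? a : b; }

// Number of nodes on the shortest path from root to a missing (NULL) child
int minDepth(Node *root) {
    if (!root) {
        return 0;
    }
    return 1 + min(minDepth(root->left), minDepth(root->right));
}

// Number of nodes on the longest root-to-leaf path (the height)
int maxDepth(Node *root) {
    if (!root) {
        return 0;
    }
    return 1 + max(maxDepth(root->left), maxDepth(root->right));
}

bool isBalanced(Node *root) {
    return (maxDepth(root) - minDepth(root)) <= 1;
}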

Complexity

  • Time complexity - O(n). maxDepth and minDepth each visit every node once, so each takes O(n) time, and the final comparison takes O(1).
  • Space complexity - O(h) for the recursion stack, where h is the height of the tree.
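The following Java sketch shows that Method 3 really checks a different, stricter property. It runs the min/max depth test and Method 2's checkHeight on one 7-node tree that is height-balanced but whose shallowest missing child and deepest leaf are two levels apart. The tree and the names TreeNode, Method3VsHeightBalanced and sample are illustrative choices, not from the original.

```java
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class Method3VsHeightBalanced {
    // Method 3: nodes on the shortest path from root to a missing child
    static int minDepth(TreeNode root) {
        if (root == null) return 0;
        return 1 + Math.min(minDepth(root.left), minDepth(root.right));
    }

    // Method 3: nodes on the longest root-to-leaf path
    static int maxDepth(TreeNode root) {
        if (root == null) return 0;
        return 1 + Math.max(maxDepth(root.left), maxDepth(root.right));
    }

    static boolean method3(TreeNode root) {
        return maxDepth(root) - minDepth(root) <= 1;
    }

    // Method 2: height if balanced, -1 otherwise
    static int checkHeight(TreeNode root) {
        if (root == null) return 0;
        int l = checkHeight(root.left);
        if (l == -1) return -1;
        int r = checkHeight(root.right);
        if (r == -1 || Math.abs(l - r) > 1) return -1;
        return Math.max(l, r) + 1;
    }

    // Every node's subtree heights differ by at most 1, yet node 3 (depth 2)
    // lacks a right child while leaf 8 sits at depth 4
    static TreeNode sample() {
        return new TreeNode(1,
            new TreeNode(2,
                new TreeNode(4, new TreeNode(8), null),
                new TreeNode(5)),
            new TreeNode(3, new TreeNode(6), null));
    }
}
```

On this tree checkHeight returns 4 (balanced), while maxDepth is 4 and minDepth is 2, so method3 reports it as unbalanced.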