Problem

Given the root of a binary tree, construct a 0-indexed m x n string matrix res that represents a formatted layout of the tree. The formatted layout matrix should be constructed using the following rules:

  • Let height be the height of the tree. The number of rows m should be equal to height + 1.
  • The number of columns n should be equal to 2^(height+1) - 1.
  • Place the root node in the middle of the top row (more formally, at location res[0][(n-1)/2]).
  • For each node that has been placed in the matrix at position res[r][c], place its left child at res[r+1][c - 2^(height-r-1)] and its right child at res[r+1][c + 2^(height-r-1)].
  • Continue this process until all the nodes in the tree have been placed.
  • Any empty cells should contain the empty string "".
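To make the index rules concrete, here is a small Python sketch that computes the matrix shape, the root column, and a node's child columns for a given height. The helper names `layout_dims` and `child_columns` are hypothetical and not part of the problem statement.

```python
def layout_dims(height):
    """Return (rows, cols, root_col) for a tree whose height is counted in edges."""
    rows = height + 1
    cols = 2 ** (height + 1) - 1
    root_col = (cols - 1) // 2
    return rows, cols, root_col


def child_columns(r, c, height):
    """Columns of the left and right children of the node placed at res[r][c]."""
    offset = 2 ** (height - r - 1)
    return c - offset, c + offset


# Example 2 has height 2: a 3 x 7 matrix with the root in column 3.
rows, cols, root_col = layout_dims(2)
print(rows, cols, root_col)           # 3 7 3
print(child_columns(0, root_col, 2))  # (1, 5)  -> nodes 2 and 3
print(child_columns(1, 1, 2))         # (0, 2)  -> node 4 is node 2's right child
```

The offset halves on every row, which is why sibling subtrees never collide.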

Return the constructed matrix res.

Examples

Example 1:

  1
 / 
2
Input: root = [1,2]
Output:
[ ["","1",""],
 ["2","",""] ]

Example 2:

  1
 / \
2   3
 \
  4
Input: root = [1,2,3,null,4]
Output:
[ ["","","","1","","",""],
 ["","2","","","","3",""],
 ["","","4","","","",""] ]

Solution

Method 1 - Recursive DFS

The matrix dimensions depend only on the height of the tree, so we first compute the height, then use it to size the matrix and place every node according to the rules.
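The recursive `get_height` in the code uses the edge-counting convention: an empty tree has height -1 and a single node has height 0, so height + 1 equals the number of levels. As a cross-check, the same value can be obtained iteratively by counting levels in a breadth-first traversal. This is a swapped-in alternative rather than the method used in the solution, and `get_height_bfs` (like the `TreeNode` class, modeled on the usual LeetCode definition) is a hypothetical name.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def get_height_bfs(root):
    """Height in edges: number of levels minus one, or -1 for an empty tree."""
    if root is None:
        return -1
    levels = 0
    queue = deque([root])
    while queue:
        levels += 1
        # Drain exactly one level while enqueueing the next one.
        for _ in range(len(queue)):
            node = queue.popleft()
            if node.left is not None:
                queue.append(node.left)
            if node.right is not None:
                queue.append(node.right)
    return levels - 1
```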

The steps are:

  1. Compute the Height of the Tree: The height is the number of edges on the longest path from the root to a leaf, so a single-node tree has height 0 (and an empty tree -1).
  2. Calculate the Dimensions of the Matrix: The number of rows m is height + 1 and the number of columns n is 2^(height + 1) - 1.
  3. Place the Root Node: Place it in the middle of the top row.
  4. Place the Children Nodes: Use the given formulas to place each node’s children in the correct position.
  5. Recursive Filling: Use a recursive function to fill each node and its children into the matrix.
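Step 5 does not have to be recursive. The sketch below swaps the recursive fill for a breadth-first queue of (node, row, column) triples; because it applies the same placement formulas, it should produce the same matrix as the recursive code. The names `print_tree_bfs` and `height`, and the `TreeNode` class (modeled on the usual LeetCode definition), are assumptions for illustration.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(root):
    """Edge-count height; -1 for an empty tree."""
    if root is None:
        return -1
    return 1 + max(height(root.left), height(root.right))


def print_tree_bfs(root):
    """Build the formatted layout level by level instead of recursively."""
    if root is None:
        return []
    h = height(root)
    m, n = h + 1, 2 ** (h + 1) - 1
    res = [[""] * n for _ in range(m)]
    # Each queue entry carries a node together with its (row, column) slot.
    queue = deque([(root, 0, (n - 1) // 2)])
    while queue:
        node, r, c = queue.popleft()
        res[r][c] = str(node.val)
        # A child exists only when r < h, so the exponent below is never negative.
        if node.left is not None:
            queue.append((node.left, r + 1, c - 2 ** (h - r - 1)))
        if node.right is not None:
            queue.append((node.right, r + 1, c + 2 ** (h - r - 1)))
    return res
```

The trade-off is small: the queue replaces the O(height) call stack with storage for at most one level of nodes.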

Code

Java
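import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Assumes the usual LeetCode TreeNode class (int val; TreeNode left, right).
public class Solution {

	public List<List<String>> printTree(TreeNode root) {
		int height = getHeight(root);
		int m = height + 1;
		int n = (1 << (height + 1)) - 1; // this is 2^(height + 1) - 1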

		// Create an empty matrix
		String[][] res = new String[m][n];

		for (int i = 0; i < m; i++) {
			Arrays.fill(res[i], "");
		}

		// Fill the matrix with tree values
		fillMatrix(res, root, 0, (n - 1) / 2, height);

		List<List<String>> result = new ArrayList<>();

		for (int i = 0; i < m; i++) {
			result.add(Arrays.asList(res[i]));
		}

		return result;
	}

	public int getHeight(TreeNode root) {
		if (root == null) {
			return -1;
		}

		return 1 + Math.max(getHeight(root.left), getHeight(root.right));
	}

	public void fillMatrix(String[][] matrix, TreeNode root, int r, int c, int height) {
		if (root == null) {
			return;
		}

		matrix[r][c] = Integer.toString(root.val);

		if (root.left != null) {
			fillMatrix(matrix, root.left, r + 1, c - (1 << (height - r - 1)), height);
		}

		if (root.right != null) {
			fillMatrix(matrix, root.right, r + 1, c + (1 << (height - r - 1)), height);
		}
	}

}
Python
def print_tree(root):
    # Assumes a TreeNode class with val, left and right attributes (LeetCode's definition).
    # Calculate the height of the tree
    height = get_height(root)
    m = height + 1
    n = 2 ** (height + 1) - 1

    # Create an empty matrix
    res = [["" for _ in range(n)] for _ in range(m)]

    # Fill the matrix with tree values
    fill_matrix(res, root, 0, (n - 1) // 2, height)

    return res

def get_height(root):
    if not root:
        return -1
    return 1 + max(get_height(root.left), get_height(root.right))


def fill_matrix(matrix, root, r, c, height):
    if not root:
        return
    matrix[r][c] = str(root.val)
    if root.left:
        fill_matrix(matrix, root.left, r + 1, c - 2 ** (height - r - 1), height)
    if root.right:
        fill_matrix(matrix, root.right, r + 1, c + 2 ** (height - r - 1), height)
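To run the Python version on the examples, the `root = [...]` inputs first have to be turned into linked `TreeNode` objects. A minimal sketch of that conversion, assuming LeetCode's level-order serialization with `None` standing for `null`; `build_tree` is a hypothetical helper name.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values):
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Consume up to two entries: the left child, then the right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root
```

With this helper in scope, print_tree(build_tree([1, 2, 3, None, 4])) should reproduce the output of Example 2.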

Complexity

  • Time: O(h · 2^h), where N is the number of nodes and h is the height of the tree. Computing the height and filling the matrix each visit every node once, which is O(N), but allocating and initializing the (h + 1) x (2^(h+1) - 1) matrix costs O(h · 2^h). Since N ≤ 2^(h+1) - 1, the matrix term dominates. For a skewed tree h = N - 1, so the worst case is O(N · 2^N).
  • Space: O(h · 2^h) for the result matrix, which is O(N · 2^N) in the worst case. The recursion stack adds O(h), which is at most O(N).
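A quick calculation shows why the matrix size, not the node count, drives both bounds. The helper `matrix_cells` is hypothetical; it simply evaluates (height + 1) x (2^(height+1) - 1).

```python
def matrix_cells(height):
    """Number of cells in the (height + 1) x (2^(height + 1) - 1) result matrix."""
    return (height + 1) * (2 ** (height + 1) - 1)


# A perfect tree of height 2 holds 7 values in a 3 x 7 = 21-cell matrix.
print(matrix_cells(2))  # 21
# A left-skewed chain of 4 nodes has height 3: 4 x 15 = 60 cells for 4 values.
print(matrix_cells(3))  # 60
# A chain of 10 nodes (height 9) already needs 10 x 1023 cells.
print(matrix_cells(9))  # 10230
```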