Problem

Given a binary tree, find the size of the largest independent subset (LISS). A subset of nodes is independent if no two nodes in the subset are directly connected by an edge.

Example tree:

        9
       / \
      5   11
     / \    \
    3   7    10
     \
      4

Some valid independent subsets: [9,3,7,10], [5,11,4]. The largest independent subset of this tree has size 4; [9,3,7,10] is one such set ([9,4,7,10] is another).
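To make the definition concrete, here is a small brute-force checker for the example tree. It is a sketch that assumes a minimal `TreeNode` class with `val`, `left` and `right` fields and unique node values; the helper names (`build_example`, `collect`, `is_independent`, `brute_force_liss`) are hypothetical. It enumerates subsets from largest to smallest, so it is only practical for tiny trees, but it is handy for cross-checking the methods below.

```python
from itertools import combinations
from typing import List, Optional, Sequence, Tuple


class TreeNode:
    # Minimal binary-tree node assumed for this sketch.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_example() -> TreeNode:
    # 9 -> (5, 11); 5 -> (3, 7); 3 -> right child 4; 11 -> right child 10.
    return TreeNode(9,
                    TreeNode(5, TreeNode(3, None, TreeNode(4)), TreeNode(7)),
                    TreeNode(11, None, TreeNode(10)))


def collect(root: Optional[TreeNode]) -> Tuple[List[int], List[Tuple[int, int]]]:
    # Return all node values and all (parent, child) edges, identified by value.
    vals: List[int] = []
    edges: List[Tuple[int, int]] = []
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        vals.append(node.val)
        for child in (node.left, node.right):
            if child is not None:
                edges.append((node.val, child.val))
                stack.append(child)
    return vals, edges


def is_independent(subset: Sequence[int], edges: List[Tuple[int, int]]) -> bool:
    # Independent means no edge has both endpoints in the subset.
    chosen = set(subset)
    return not any(u in chosen and v in chosen for u, v in edges)


def brute_force_liss(root: Optional[TreeNode]) -> int:
    # Try subset sizes from largest to smallest; exponential, tiny trees only.
    vals, edges = collect(root)
    for k in range(len(vals), 0, -1):
        if any(is_independent(c, edges) for c in combinations(vals, k)):
            return k
    return 0


if __name__ == "__main__":
    _, example_edges = collect(build_example())
    print(is_independent([9, 3, 7, 10], example_edges))  # True
    print(is_independent([9, 5], example_edges))         # False: 9-5 is an edge
    print(brute_force_liss(build_example()))             # 4
```

On the example tree the brute force agrees with the size-4 answer above, and `[9, 5]` is rejected because 9 and 5 share an edge.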

Related LeetCode problem: not an exact match, but House Robber III is closely related (the same tree DP with a non-adjacency constraint, maximizing the sum of node values instead of the number of nodes).
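The connection is easiest to see in code. The sketch below is a hedged illustration, not LeetCode's reference solution: it reuses the include/exclude recurrence of Method 2 below but lets an included node contribute `node.val` instead of 1. `TreeNode` and `max_weight_independent` are assumed names for this example; with every value equal to 1 it returns the LISS size.

```python
from typing import Optional, Tuple


class TreeNode:
    # Minimal binary-tree node assumed for this sketch.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def max_weight_independent(root: Optional[TreeNode]) -> int:
    """Largest sum of node values over node sets with no parent-child pair."""

    def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
        # Returns (best sum with node included, best sum with node excluded).
        if node is None:
            return (0, 0)
        l_in, l_ex = dfs(node.left)
        r_in, r_ex = dfs(node.right)
        incl = node.val + l_ex + r_ex             # node taken: children must be skipped
        excl = max(l_in, l_ex) + max(r_in, r_ex)  # node skipped: each child is free
        return (incl, excl)

    return max(dfs(root))


if __name__ == "__main__":
    # 3 -> (2, 3); 2 -> right child 3; the second 3 -> right child 1.
    t = TreeNode(3, TreeNode(2, None, TreeNode(3)), TreeNode(3, None, TreeNode(1)))
    print(max_weight_independent(t))  # 7 (3 + 3 + 1)
```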

Solution

We present two methods:

  • Method 1 — Naive recursion (exponential) to illustrate the idea.
  • Method 2 — Optimal O(n) DP with two states per node (include / exclude), computed in a single post-order traversal.

Method 1 — Naive recursion (include grandchildren)

Intuition

If we include a node v in the independent set, we cannot include its immediate children, but we may include its grandchildren. If we exclude v, we are free to include its children. The naive recursion evaluates both choices and keeps the larger result.

Approach

  1. If root is null, return 0.
  2. Compute sizeIncluding = 1 + sum(liss(g) for g in grandchildren of root).
  3. Compute sizeExcluding = liss(root.left) + liss(root.right).
  4. Return max(sizeIncluding, sizeExcluding).

This recalculates many subproblems and runs exponentially in the worst case.
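One way to see the repeated work is to count calls. The sketch below mirrors the Method 1 recursion in Python and increments a counter on every call, including calls on empty subtrees. `TreeNode`, `make_path`, `naive_liss` and `count_calls` are hypothetical helpers, and a skewed "path" tree is used because it makes the growth easy to see.

```python
from typing import List, Optional, Tuple


class TreeNode:
    # Minimal binary-tree node assumed for this sketch.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def make_path(n: int) -> Optional[TreeNode]:
    # Skewed tree 1 -> 2 -> ... -> n where every node has only a left child.
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, root)
    return root


def naive_liss(root: Optional[TreeNode], counter: List[int]) -> int:
    # Same recursion as Method 1; counter[0] counts every call, including on None.
    counter[0] += 1
    if root is None:
        return 0
    incl = 1
    for child in (root.left, root.right):
        if child is not None:
            for grandchild in (child.left, child.right):
                if grandchild is not None:
                    incl += naive_liss(grandchild, counter)
    excl = naive_liss(root.left, counter) + naive_liss(root.right, counter)
    return max(incl, excl)


def count_calls(n: int) -> Tuple[int, int]:
    # Returns (LISS size, number of recursive calls) for a path of n nodes.
    counter = [0]
    size = naive_liss(make_path(n), counter)
    return size, counter[0]


if __name__ == "__main__":
    for n in (5, 10, 20):
        print(n, count_calls(n))
```

A path of 5 nodes takes 29 calls, 10 nodes take 341, and 20 nodes take 42185: each extra node multiplies the count by roughly 1.6, the Fibonacci-like growth behind the exponential bound.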

Code

C++
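#include <algorithm>  // std::max

// Assumes the usual LeetCode-style TreeNode { int val; TreeNode *left, *right; }.
class Solution {
 public:
  int liss(TreeNode* root) {
    if (!root) return 0;
    // Option 1: include root; skip its children and add the grandchildren's answers.
    int incl = 1;
    if (root->left) {
      if (root->left->left) incl += liss(root->left->left);
      if (root->left->right) incl += liss(root->left->right);
    }
    if (root->right) {
      if (root->right->left) incl += liss(root->right->left);
      if (root->right->right) incl += liss(root->right->right);
    }
    // Option 2: exclude root and recurse on both children.
    int excl = liss(root->left) + liss(root->right);
    return std::max(incl, excl);
  }
};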

Go
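package main

// Assumes the usual LeetCode-style type: type TreeNode struct { Val int; Left, Right *TreeNode }
func liss(root *TreeNode) int {
    if root == nil { return 0 }
    // Option 1: include root; skip its children and add the grandchildren's answers.
    incl := 1
    if root.Left != nil {
        if root.Left.Left != nil { incl += liss(root.Left.Left) }
        if root.Left.Right != nil { incl += liss(root.Left.Right) }
    }
    if root.Right != nil {
        if root.Right.Left != nil { incl += liss(root.Right.Left) }
        if root.Right.Right != nil { incl += liss(root.Right.Right) }
    }
    // Option 2: exclude root and recurse on both children.
    excl := liss(root.Left) + liss(root.Right)
    if incl > excl { return incl }
    return excl
}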

Java
// Assumes the usual LeetCode-style TreeNode with int val and TreeNode left, right fields.
class Solution {
  public int liss(TreeNode root) {
    if (root == null) return 0;
    // Option 1: include root; skip its children and add the grandchildren's answers.
    int incl = 1;
    if (root.left != null) {
      if (root.left.left != null) incl += liss(root.left.left);
      if (root.left.right != null) incl += liss(root.left.right);
    }
    if (root.right != null) {
      if (root.right.left != null) incl += liss(root.right.left);
      if (root.right.right != null) incl += liss(root.right.right);
    }
    // Option 2: exclude root and recurse on both children.
    int excl = liss(root.left) + liss(root.right);
    return Math.max(incl, excl);
  }
}

Python
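from __future__ import annotations  # keeps the TreeNode annotation unevaluated
from typing import Optional

# Assumes the usual LeetCode-style TreeNode with val, left and right attributes.
class Solution:
    def liss(self, root: Optional[TreeNode]) -> int:
        if root is None: return 0
        # Option 1: include root; skip its children and add the grandchildren's answers.
        incl = 1
        if root.left:
            if root.left.left: incl += self.liss(root.left.left)
            if root.left.right: incl += self.liss(root.left.right)
        if root.right:
            if root.right.left: incl += self.liss(root.right.left)
            if root.right.right: incl += self.liss(root.right.right)
        # Option 2: exclude root and recurse on both children.
        excl = self.liss(root.left) + self.liss(root.right)
        return max(incl, excl)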

Complexity

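  • ⏰ Time complexity: exponential in the worst case, because the same subtrees are solved repeatedly. On a skewed (path-like) tree every call recurses into both a child and a grandchild, so the call count follows a Fibonacci-like recurrence and grows roughly as O(1.618^n); O(2^n) is a valid but looser upper bound.
  • 🧺 Space complexity: O(h) for the recursion stack, where h is the tree height.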
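A common alternative remedy, not one of the two methods described here, is to memoize the Method 1 recurrence so each node's answer is computed only once. This sketch assumes nodes are hashable by identity, which holds for a plain Python class that does not override `__eq__`; `TreeNode` and `liss_memo` are hypothetical names. It runs in O(n) time but keeps an O(n) cache, which the two-state approach avoids.

```python
from typing import Dict, Optional


class TreeNode:
    # Minimal binary-tree node assumed for this sketch.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def liss_memo(root: Optional[TreeNode]) -> int:
    # Method 1's recurrence with a cache keyed by node identity.
    memo: Dict[TreeNode, int] = {}

    def solve(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        if node in memo:
            return memo[node]
        incl = 1  # take node: skip children, add the grandchildren's answers
        for child in (node.left, node.right):
            if child is not None:
                incl += solve(child.left) + solve(child.right)
        excl = solve(node.left) + solve(node.right)  # skip node: children are free
        memo[node] = max(incl, excl)
        return memo[node]

    return solve(root)
```

On the example tree from the problem statement it returns 4, and every node's value is computed once no matter how many ancestors ask for it.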

Method 2 — Optimal DP (two-state per node)

Intuition

For each node v compute two values: incl[v] = size of the largest independent set in the subtree rooted at v that contains v; excl[v] = size of the largest one that does not contain v. Then:

  • incl[v] = 1 + sum(excl[child] for child in children) (if v is included, children cannot be included)
  • excl[v] = sum(max(incl[child], excl[child]) for child in children) (if v is excluded, each child may be included or excluded)

A post-order traversal computes these values in O(n) time.
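As a worked example, the sketch below runs the two recurrences on the example tree from the problem statement and records each node's (incl, excl) pair. The `TreeNode` class and the `build_example` and `two_state_table` helpers are assumptions for illustration; the table is keyed by node value, which works here only because the example's values are unique.

```python
from typing import Dict, Optional, Tuple


class TreeNode:
    # Minimal binary-tree node assumed for this sketch.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_example() -> TreeNode:
    # 9 -> (5, 11); 5 -> (3, 7); 3 -> right child 4; 11 -> right child 10.
    return TreeNode(9,
                    TreeNode(5, TreeNode(3, None, TreeNode(4)), TreeNode(7)),
                    TreeNode(11, None, TreeNode(10)))


def two_state_table(root: Optional[TreeNode]) -> Dict[int, Tuple[int, int]]:
    # Post-order pass recording (incl, excl) for every node, keyed by node value.
    table: Dict[int, Tuple[int, int]] = {}

    def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
        if node is None:
            return (0, 0)
        l_in, l_ex = dfs(node.left)
        r_in, r_ex = dfs(node.right)
        incl = 1 + l_ex + r_ex                    # v taken: children excluded
        excl = max(l_in, l_ex) + max(r_in, r_ex)  # v skipped: each child free
        table[node.val] = (incl, excl)
        return (incl, excl)

    dfs(root)
    return table


if __name__ == "__main__":
    for val, (incl, excl) in two_state_table(build_example()).items():
        print(val, incl, excl)
```

The pairs come out as node 4: (1, 0), node 3: (1, 1), node 7: (1, 0), node 5: (2, 2), node 10: (1, 0), node 11: (1, 1), node 9: (4, 3). At the root, including 9 gives 4 and excluding it gives 3, so the answer is 4, consistent with [9,3,7,10].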

Approach

  1. Do a post-order DFS. For each node v compute and return (incl, excl) as described.
  2. Answer for the whole tree is max(incl[root], excl[root]).
  3. This visits each node once and performs O(1) work per node.
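The steps above return only the size. If an actual maximum set is wanted, a second, top-down pass over the stored pairs can recover one; this extension is not part of the original method, and `TreeNode` and `liss_set` are hypothetical names. A node is taken when its parent was not taken and incl ≥ excl; a taken node forces its children out, which is exactly what the incl recurrence assumed.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    # Minimal binary-tree node assumed for this sketch.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def liss_set(root: Optional[TreeNode]) -> List[int]:
    # Pass 1 (post-order): store (incl, excl) for every node.
    pairs: Dict[TreeNode, Tuple[int, int]] = {}

    def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
        if node is None:
            return (0, 0)
        l_in, l_ex = dfs(node.left)
        r_in, r_ex = dfs(node.right)
        pairs[node] = (1 + l_ex + r_ex, max(l_in, l_ex) + max(r_in, r_ex))
        return pairs[node]

    dfs(root)

    # Pass 2 (top-down): take a node only if its parent was not taken
    # and including it is at least as good as excluding it.
    chosen: List[int] = []
    stack = [(root, False)] if root is not None else []
    while stack:
        node, parent_taken = stack.pop()
        incl, excl = pairs[node]
        take = not parent_taken and incl >= excl
        if take:
            chosen.append(node.val)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, take))
    return chosen
```

On the example tree this picks 9, 10, 7 and 3 (in stack order), that is, the set [9,3,7,10].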

Code

C++
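#include <algorithm>  // std::max
#include <utility>    // std::pair

// Assumes the usual LeetCode-style TreeNode { int val; TreeNode *left, *right; }.
class Solution {
 public:
  // Returns {incl, excl} for the subtree rooted at node.
  std::pair<int, int> dfs(TreeNode* node) {
    if (!node) return {0, 0};
    auto L = dfs(node->left);
    auto R = dfs(node->right);
    int incl = 1 + L.second + R.second;  // include node -> children excluded
    int excl = std::max(L.first, L.second) + std::max(R.first, R.second);  // exclude node -> each child free
    return {incl, excl};
  }

  int liss(TreeNode* root) {
    auto res = dfs(root);
    return std::max(res.first, res.second);
  }
};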

Go
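package main

// Assumes the usual LeetCode-style type: type TreeNode struct { Val int; Left, Right *TreeNode }

// dfs returns (incl, excl) for the subtree rooted at node.
func dfs(node *TreeNode) (int, int) {
    if node == nil { return 0, 0 }
    lf, le := dfs(node.Left)
    rf, re := dfs(node.Right)
    incl := 1 + le + re               // include node -> children excluded
    excl := max(lf, le) + max(rf, re) // exclude node -> each child free
    return incl, excl
}

func liss(root *TreeNode) int {
    a, b := dfs(root)
    return max(a, b)
}

// Local helper; Go 1.21+ also has a built-in max, which this declaration shadows.
func max(a, b int) int { if a > b { return a }; return b }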

Java
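// Assumes the usual LeetCode-style TreeNode with int val and TreeNode left, right fields.
class Solution {
  // Returns {incl, excl} for the subtree rooted at node.
  private int[] dfs(TreeNode node) {
    if (node == null) return new int[]{0, 0};
    int[] L = dfs(node.left);
    int[] R = dfs(node.right);
    int incl = 1 + L[1] + R[1];                              // include node -> children excluded
    int excl = Math.max(L[0], L[1]) + Math.max(R[0], R[1]);  // exclude node -> each child free
    return new int[]{incl, excl};
  }

  public int liss(TreeNode root) {
    int[] res = dfs(root);
    return Math.max(res[0], res[1]);
  }
}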

Python
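from __future__ import annotations  # keeps the TreeNode annotation unevaluated
from typing import Optional, Tuple

# Assumes the usual LeetCode-style TreeNode with val, left and right attributes.
class Solution:
    def dfs(self, node: Optional[TreeNode]) -> Tuple[int, int]:
        # Returns (incl, excl) for the subtree rooted at node.
        if not node: return (0, 0)
        li, le = self.dfs(node.left)
        ri, re = self.dfs(node.right)
        incl = 1 + le + re                # include node -> children excluded
        excl = max(li, le) + max(ri, re)  # exclude node -> each child free
        return (incl, excl)

    def liss(self, root: Optional[TreeNode]) -> int:
        incl, excl = self.dfs(root)
        return max(incl, excl)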

Complexity

  • ⏰ Time complexity: O(n) — single post-order traversal visiting each node once.
  • 🧺 Space complexity: O(h) for the recursion stack, where h is the tree height (O(log n) for a balanced tree, O(n) for a skewed one). Each call returns its (incl, excl) pair instead of storing it, so no per-node table is needed.
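Because the recursion depth equals the tree height, a very deep skewed tree can exceed CPython's default recursion limit of 1000 frames. The sketch below runs the same post-order DP with an explicit stack; `TreeNode` and `liss_iterative` are assumed names, and it keeps each node's pair in a dictionary, so it trades the call stack for O(n) extra memory.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    # Minimal binary-tree node assumed for this sketch.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def liss_iterative(root: Optional[TreeNode]) -> int:
    # (incl, excl) per node; None maps to (0, 0) so missing children need no special case.
    state: Dict[Optional[TreeNode], Tuple[int, int]] = {None: (0, 0)}
    stack: List[Tuple[TreeNode, bool]] = [(root, False)] if root is not None else []
    while stack:
        node, children_done = stack.pop()
        if children_done:
            # Second visit: both children already have their pairs.
            l_in, l_ex = state[node.left]
            r_in, r_ex = state[node.right]
            state[node] = (1 + l_ex + r_ex, max(l_in, l_ex) + max(r_in, r_ex))
        else:
            # First visit: schedule a revisit, then process the children first.
            stack.append((node, True))
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, False))
    return max(state[root])
```

For a path of 10,000 nodes it returns 5000 without touching the recursion limit, while the recursive versions would need about 10,000 nested calls.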

Created at: 2018-02-07T09:21:59+01:00 Updated at: 2018-02-07T09:21:59+01:00