Problem

You are given the root of a binary tree with unique values, and an integer start. At minute 0, an infection starts from the node with value start.

Each minute, a node becomes infected if:

  • The node is currently uninfected.
  • The node is adjacent to an infected node.

Return the number of minutes needed for the entire tree to be infected.

Examples

Example 1:

graph TD;
A(1) --- E(5) & C(3):::blue
E ~~~ N1:::hidden
E --- D(4)
D --- I(9) & B(2)
C --- J(10) & F(6)

classDef blue fill:#ADD8E6,stroke:#000,stroke-width:1px;
classDef hidden display:none
  
1
2
3
4
5
6
7
8
9
Input: root = [1,5,3,null,4,10,6,9,2], start = 3
Output: 4
Explanation: The following nodes are infected during:
- Minute 0: Node 3
- Minute 1: Nodes 1, 10 and 6
- Minute 2: Node 5
- Minute 3: Node 4
- Minute 4: Nodes 9 and 2
It takes 4 minutes for the whole tree to be infected so we return 4.

Example 2:

graph TD;
A(1):::blue

classDef blue fill:#ADD8E6,stroke:#000,stroke-width:1px;
  
1
2
3
Input: root = [1], start = 1
Output: 0
Explanation: At minute 0, the only node in the tree is infected so we return 0.

Constraints:

  • The number of nodes in the tree is in the range $[1, 10^5]$.
  • $1 \le \texttt{Node.val} \le 10^5$
  • Each node has a unique value.
  • A node with a value of start exists in the tree.

Solution

Method 1 – BFS with Parent Mapping

Intuition

The infection spreads to adjacent nodes (parent or children) each minute, similar to a breadth-first search (BFS) traversal. By mapping each node to its parent, we can simulate the infection spreading in all directions from the starting node.

Approach

  1. Traverse the tree to build a mapping from each node to its parent and locate the start node.
  2. Use BFS starting from the start node, marking nodes as infected (visited).
  3. At each minute, infect all uninfected adjacent nodes (left, right, parent).
  4. Track the time taken until all nodes are infected.
  5. Return the total minutes needed.

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution {
public:
  int amountOfTime(TreeNode* root, int start) {
    unordered_map<int, TreeNode*> par;
    queue<TreeNode*> q;
    function<void(TreeNode*, TreeNode*)> dfs = [&](TreeNode* node, TreeNode* p) {
      if (!node) return;
      par[node->val] = p;
      dfs(node->left, node);
      dfs(node->right, node);
    };
    dfs(root, nullptr);
    unordered_set<int> vis;
    TreeNode* st = nullptr;
    function<void(TreeNode*)> find = [&](TreeNode* node) {
      if (!node) return;
      if (node->val == start) st = node;
      find(node->left);
      find(node->right);
    };
    find(root);
    q.push(st);
    vis.insert(st->val);
    int ans = -1;
    while (!q.empty()) {
      int sz = q.size();
      ans++;
      for (int i = 0; i < sz; ++i) {
        TreeNode* cur = q.front(); q.pop();
        for (TreeNode* nxt : {cur->left, cur->right, par[cur->val]}) {
          if (nxt && !vis.count(nxt->val)) {
            vis.insert(nxt->val);
            q.push(nxt);
          }
        }
      }
    }
    return ans;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// amountOfTime returns the number of minutes until every node of the tree is
// infected, when the infection starts at the node whose value is start and
// spreads across one edge (child or parent) per minute. It returns -1 if no
// node has value start (excluded by the problem constraints).
func amountOfTime(root *TreeNode, start int) int {
  // One traversal both records each node's parent (keyed by its unique value)
  // and locates the start node.
  par := map[int]*TreeNode{}
  var st *TreeNode
  var walk func(node, parent *TreeNode)
  walk = func(node, parent *TreeNode) {
    if node == nil {
      return
    }
    par[node.Val] = parent
    if node.Val == start {
      st = node
    }
    walk(node.Left, node)
    walk(node.Right, node)
  }
  walk(root, nil)
  if st == nil {
    return -1
  }

  // Level-order BFS: each level is the set of nodes infected in one minute.
  vis := map[int]bool{st.Val: true}
  level := []*TreeNode{st}
  ans := -1 // processing the first level (minute 0) brings this to 0
  for len(level) > 0 {
    ans++
    var next []*TreeNode
    for _, cur := range level {
      for _, nxt := range [3]*TreeNode{cur.Left, cur.Right, par[cur.Val]} {
        if nxt != nil && !vis[nxt.Val] {
          vis[nxt.Val] = true
          next = append(next, nxt)
        }
      }
    }
    level = next
  }
  return ans
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Solution {
  /**
   * Returns the number of minutes until every node of the tree is infected,
   * when the infection starts at the node whose value is {@code start} and
   * spreads across one edge (child or parent) per minute.
   *
   * @param root  root of the tree (node values are unique)
   * @param start value of the initially infected node
   * @return minutes needed to infect the whole tree, or -1 if no node has
   *         value {@code start} (excluded by the problem constraints)
   */
  public int amountOfTime(TreeNode root, int start) {
    Map<Integer, TreeNode> par = new HashMap<>();
    TreeNode st = buildParentAndFindStart(root, start, par);
    if (st == null) return -1;

    // Level-order BFS from the start node; each processed level is one minute.
    Set<Integer> vis = new HashSet<>();
    java.util.Deque<TreeNode> q = new java.util.ArrayDeque<>();
    q.offer(st);
    vis.add(st.val);
    int ans = -1; // processing the first level (minute 0) brings this to 0
    while (!q.isEmpty()) {
      ans++;
      for (int i = q.size(); i > 0; i--) {
        TreeNode cur = q.poll();
        for (TreeNode nxt : new TreeNode[]{cur.left, cur.right, par.get(cur.val)}) {
          // Set.add returns true only for a node that was not yet infected.
          if (nxt != null && vis.add(nxt.val)) q.offer(nxt);
        }
      }
    }
    return ans;
  }

  /**
   * Iteratively records each node's parent in {@code par} (root maps to null)
   * and returns the node whose value is {@code start}, or null if absent.
   * An explicit stack avoids StackOverflowError on a degenerate tree.
   */
  private TreeNode buildParentAndFindStart(TreeNode root, int start, Map<Integer, TreeNode> par) {
    TreeNode st = null;
    java.util.Deque<TreeNode> stack = new java.util.ArrayDeque<>();
    if (root != null) {
      par.put(root.val, null);
      stack.push(root);
    }
    while (!stack.isEmpty()) {
      TreeNode node = stack.pop();
      if (node.val == start) st = node;
      for (TreeNode child : new TreeNode[]{node.left, node.right}) {
        if (child != null) {
          par.put(child.val, node);
          stack.push(child);
        }
      }
    }
    return st;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Solution {
  /**
   * Returns the number of minutes until every node of the tree is infected,
   * when the infection starts at the node whose value is [start] and spreads
   * across one edge (child or parent) per minute.
   *
   * @param root root of the tree (node values are unique)
   * @param start value of the initially infected node
   * @return minutes needed to infect the whole tree, or -1 if no node has
   *   value [start] (excluded by the problem constraints)
   */
  fun amountOfTime(root: TreeNode?, start: Int): Int {
    // Single iterative pass: map each node value to its parent and find the
    // start node. An explicit stack avoids StackOverflowError on a
    // degenerate (list-shaped) tree.
    val par = HashMap<Int, TreeNode?>()
    var st: TreeNode? = null
    val stack = ArrayDeque<TreeNode>()
    if (root != null) {
      par[root.`val`] = null
      stack.addLast(root)
    }
    while (stack.isNotEmpty()) {
      val node = stack.removeLast()
      if (node.`val` == start) st = node
      for (child in listOfNotNull(node.left, node.right)) {
        par[child.`val`] = node
        stack.addLast(child)
      }
    }
    val source = st ?: return -1

    // Level-order BFS from the start node; each processed level is one minute.
    val vis = hashSetOf(source.`val`)
    val q = ArrayDeque<TreeNode>().apply { addLast(source) }
    var ans = -1 // processing the first level (minute 0) brings this to 0
    while (q.isNotEmpty()) {
      repeat(q.size) {
        val cur = q.removeFirst()
        for (nxt in arrayOf(cur.left, cur.right, par[cur.`val`])) {
          // MutableSet.add returns true only for a node that was not yet infected.
          if (nxt != null && vis.add(nxt.`val`)) q.addLast(nxt)
        }
      }
      ans++
    }
    return ans
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution:
  def amountOfTime(self, root: Optional[TreeNode], start: int) -> int:
    """Return the minutes until the whole tree is infected.

    The infection starts at the node whose value is ``start`` and spreads
    across one edge (to a child or to the parent) per minute.

    Args:
      root: root of the tree; node values are unique.
      start: value of the initially infected node.

    Returns:
      Minutes needed to infect every node, or -1 if no node has value
      ``start`` (excluded by the problem constraints).
    """
    from collections import deque

    # One iterative pass: map each node value to its parent and find the start
    # node. Iterative rather than recursive so a chain-shaped tree of up to
    # 1e5 nodes cannot hit Python's recursion limit.
    par: dict[int, Optional[TreeNode]] = {}
    st: Optional[TreeNode] = None
    stack: list[tuple[TreeNode, Optional[TreeNode]]] = [(root, None)] if root else []
    while stack:
      node, parent = stack.pop()
      par[node.val] = parent
      if node.val == start:
        st = node
      for child in (node.left, node.right):
        if child:
          stack.append((child, node))
    if st is None:
      return -1

    # Level-order BFS; deque.popleft is O(1), unlike list.pop(0).
    vis: set[int] = {st.val}
    q = deque([st])
    ans = -1  # processing the first level (minute 0) brings this to 0
    while q:
      for _ in range(len(q)):
        cur = q.popleft()
        for nxt in (cur.left, cur.right, par[cur.val]):
          if nxt and nxt.val not in vis:
            vis.add(nxt.val)
            q.append(nxt)
      ans += 1
    return ans
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
impl Solution {
  /// Returns the number of minutes until every node of the tree is infected,
  /// when the infection starts at the node whose value is `start` and spreads
  /// across one edge (child or parent) per minute.
  ///
  /// Returns -1 if no node has value `start` (excluded by the problem
  /// constraints).
  pub fn amount_of_time(root: Option<Rc<RefCell<TreeNode>>>, start: i32) -> i32 {
    use std::collections::{HashMap, HashSet, VecDeque};

    // Single iterative pass: record each non-root node's parent (keyed by its
    // unique value) and locate the start node. An explicit stack avoids deep
    // recursion on a degenerate (list-shaped) tree.
    let mut par: HashMap<i32, Rc<RefCell<TreeNode>>> = HashMap::new();
    let mut st: Option<Rc<RefCell<TreeNode>>> = None;
    let mut stack: Vec<Rc<RefCell<TreeNode>>> = root.into_iter().collect();
    while let Some(node) = stack.pop() {
      let n = node.borrow();
      if n.val == start {
        st = Some(Rc::clone(&node));
      }
      for child in [n.left.clone(), n.right.clone()] {
        if let Some(c) = child {
          par.insert(c.borrow().val, Rc::clone(&node));
          stack.push(c);
        }
      }
    }

    // Level-order BFS from the start node; each processed level is one minute.
    let mut vis: HashSet<i32> = HashSet::new();
    let mut q: VecDeque<Rc<RefCell<TreeNode>>> = VecDeque::new();
    if let Some(s) = st {
      vis.insert(s.borrow().val);
      q.push_back(s);
    }
    let mut ans = -1; // processing the first level (minute 0) brings this to 0
    while !q.is_empty() {
      for _ in 0..q.len() {
        // The level size bounds the number of pops, so the queue is non-empty here.
        let cur = q.pop_front().unwrap();
        let val = cur.borrow().val;
        let left = cur.borrow().left.clone();
        let right = cur.borrow().right.clone();
        let parent = par.get(&val).cloned();
        for nxt in [left, right, parent] {
          if let Some(n) = nxt {
            // HashSet::insert returns true only for a node not yet infected.
            if vis.insert(n.borrow().val) {
              q.push_back(n);
            }
          }
        }
      }
      ans += 1;
    }
    ans
  }
}

Complexity

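  • ⏰ Time complexity: O(n), where n is the number of nodes. Each node is visited a constant number of times: in the traversal(s) that build the parent map and locate the start node, and once more in the BFS.
  • 🧺 Space complexity: O(n) for the parent map, the visited set, and the BFS queue (plus the traversal stack).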