Problem#
You are given the root
of a binary tree with unique values, and an integer start
. At minute 0
, an infection starts from the node with value start
.
Each minute, a node becomes infected if:
The node is currently uninfected.
The node is adjacent to an infected node.
Return the number of minutes needed for the entire tree to be infected.
Examples#
Example 1:
graph TD;
A(1) --- E(5) & C(3):::blue
E ~~~ N1:::hidden
E --- D(4)
D --- I(9) & B(2)
C --- J(10) & F(6)
classDef blue fill:#ADD8E6,stroke:#000,stroke-width:1px;
classDef hidden display:none
1
2
3
4
5
6
7
8
9
Input: root = [ 1 , 5 , 3 , null , 4 , 10 , 6 , 9 , 2 ], start = 3
Output: 4
Explanation: The following nodes are infected during:
- Minute 0 : Node 3
- Minute 1 : Nodes 1 , 10 and 6
- Minute 2 : Node 5
- Minute 3 : Node 4
- Minute 4 : Nodes 9 and 2
It takes 4 minutes for the whole tree to be infected so we return 4.
Example 2:
graph TD;
A(1):::blue
classDef blue fill:#ADD8E6,stroke:#000,stroke-width:1px;
1
2
3
Input: root = [ 1 ], start = 1
Output: 0
Explanation: At minute 0 , the only node in the tree is infected so we return 0.
Constraints:
The number of nodes in the tree is in the range [1, 10^5]
.
1 <= Node.val <= 10^5
Each node has a unique value.
A node with a value of start
exists in the tree.
Solution#
Method 1 – BFS with Parent Mapping#
Intuition#
The infection spreads to adjacent nodes (parent or children) each minute, similar to a breadth-first search (BFS) traversal. By mapping each node to its parent, we can simulate the infection spreading in all directions from the starting node.
Approach#
Traverse the tree to build a mapping from each node to its parent and locate the start node.
Use BFS starting from the start node, marking nodes as infected (visited).
At each minute, infect all uninfected adjacent nodes (left, right, parent).
Track the time taken until all nodes are infected.
Return the total minutes needed.
Code#
Cpp
Go
Java
Kotlin
Python
Rust
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution {
public :
int amountOfTime(TreeNode* root, int start) {
unordered_map< int , TreeNode*> par;
queue< TreeNode*> q;
function< void (TreeNode* , TreeNode* )> dfs = [& ](TreeNode* node, TreeNode* p) {
if (! node) return ;
par[node-> val] = p;
dfs(node-> left, node);
dfs(node-> right, node);
};
dfs(root, nullptr );
unordered_set< int > vis;
TreeNode* st = nullptr ;
function< void (TreeNode* )> find = [& ](TreeNode* node) {
if (! node) return ;
if (node-> val == start) st = node;
find(node-> left);
find(node-> right);
};
find(root);
q.push(st);
vis.insert(st-> val);
int ans = - 1 ;
while (! q.empty()) {
int sz = q.size();
ans++ ;
for (int i = 0 ; i < sz; ++ i) {
TreeNode* cur = q.front(); q.pop();
for (TreeNode* nxt : {cur-> left, cur-> right, par[cur-> val]}) {
if (nxt && ! vis.count(nxt-> val)) {
vis.insert(nxt-> val);
q.push(nxt);
}
}
}
}
return ans;
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// amountOfTime returns the number of minutes needed for an infection that
// starts at the node whose value is start to reach every node of the tree.
// Each minute the infection moves from every infected node to its parent and
// its children. Node values are used as map keys, so they are assumed unique.
// It returns -1 when root is nil or no node carries the value start.
func amountOfTime(root *TreeNode, start int) int {
	// Single iterative pass: record each node's parent and find the origin.
	parent := map[int]*TreeNode{}
	var origin *TreeNode
	pending := []*TreeNode{}
	if root != nil {
		pending = append(pending, root)
	}
	for len(pending) > 0 {
		node := pending[len(pending)-1]
		pending = pending[:len(pending)-1]
		if node.Val == start {
			origin = node
		}
		for _, child := range []*TreeNode{node.Left, node.Right} {
			if child != nil {
				parent[child.Val] = node
				pending = append(pending, child)
			}
		}
	}
	// Guard against a nil dereference when there is nowhere to start from.
	if origin == nil {
		return -1
	}

	infected := map[int]bool{start: true}
	frontier := []*TreeNode{origin}
	// Starts at -1 because the first BFS level corresponds to minute 0.
	minutes := -1
	for len(frontier) > 0 {
		minutes++
		// Build the next level in a fresh slice instead of re-slicing the
		// queue head, so already-processed nodes are not kept alive.
		next := make([]*TreeNode, 0, len(frontier))
		for _, cur := range frontier {
			// parent[cur.Val] is nil for the root (missing key).
			for _, nb := range []*TreeNode{cur.Left, cur.Right, parent[cur.Val]} {
				if nb != nil && !infected[nb.Val] {
					infected[nb.Val] = true
					next = append(next, nb)
				}
			}
		}
		frontier = next
	}
	return minutes
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Solution {
    /**
     * Computes how many minutes it takes for an infection that starts at the
     * node valued {@code start} to reach every node, spreading each minute to
     * the parent and children of every infected node.
     *
     * <p>The parent map and the start node are discovered in one iterative
     * traversal, so very deep (list-shaped) trees do not exhaust the call stack.
     *
     * @param root  root of the binary tree; values are used as map keys and are
     *              therefore assumed unique
     * @param start value of the node where the infection begins
     * @return the minute at which the last node is infected, or -1 when the
     *         tree is empty or has no node valued {@code start}
     */
    public int amountOfTime(TreeNode root, int start) {
        Map<Integer, TreeNode> parent = new HashMap<>();
        TreeNode origin = null;
        Queue<TreeNode> pending = new LinkedList<>();
        if (root != null) {
            pending.offer(root);
        }
        while (!pending.isEmpty()) {
            TreeNode node = pending.poll();
            if (node.val == start) {
                origin = node;
            }
            for (TreeNode child : new TreeNode[] {node.left, node.right}) {
                if (child != null) {
                    parent.put(child.val, node);
                    pending.offer(child);
                }
            }
        }
        // Without a start node there is nothing to spread from.
        if (origin == null) {
            return -1;
        }

        Set<Integer> infected = new HashSet<>();
        infected.add(start);
        Queue<TreeNode> frontier = new LinkedList<>();
        frontier.offer(origin);
        // Starts at -1 because the first BFS level corresponds to minute 0.
        int minutes = -1;
        while (!frontier.isEmpty()) {
            minutes++;
            for (int remaining = frontier.size(); remaining > 0; remaining--) {
                TreeNode cur = frontier.poll();
                // parent.get(...) yields null for the root, which is skipped below.
                for (TreeNode next : new TreeNode[] {cur.left, cur.right, parent.get(cur.val)}) {
                    // Set.add returns true only when the value was not yet infected.
                    if (next != null && infected.add(next.val)) {
                        frontier.offer(next);
                    }
                }
            }
        }
        return minutes;
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Solution {
    /**
     * Returns the number of minutes until the whole tree is infected when the
     * infection starts at the node valued [start] and spreads every minute to
     * the parent and children of each infected node.
     *
     * The parent lookup and the start node are gathered in a single iterative
     * traversal, so the depth of the tree does not affect call-stack usage.
     * Node values serve as map keys and are assumed unique.
     *
     * @return the minute at which the last node is infected, or -1 when [root]
     * is null or no node holds [start].
     */
    fun amountOfTime(root: TreeNode?, start: Int): Int {
        val parent = mutableMapOf<Int, TreeNode>()
        var origin: TreeNode? = null
        val pending = ArrayDeque<TreeNode>()
        if (root != null) pending.add(root)
        while (pending.isNotEmpty()) {
            val node = pending.removeFirst()
            if (node.`val` == start) origin = node
            for (child in listOfNotNull(node.left, node.right)) {
                parent[child.`val`] = node
                pending.add(child)
            }
        }
        // No start node: the BFS below would never run, so the result is -1.
        val source = origin ?: return -1

        val infected = mutableSetOf(start)
        val frontier = ArrayDeque<TreeNode>()
        frontier.add(source)
        // Starts at -1 because the first BFS level corresponds to minute 0.
        var minutes = -1
        while (frontier.isNotEmpty()) {
            repeat(frontier.size) {
                val cur = frontier.removeFirst()
                // The root has no parent entry; listOfNotNull drops the null.
                for (next in listOfNotNull(cur.left, cur.right, parent[cur.`val`])) {
                    // add() returns true only for a value not yet infected.
                    if (infected.add(next.`val`)) frontier.add(next)
                }
            }
            minutes++
        }
        return minutes
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution:
    def amountOfTime(self, root: Optional["TreeNode"], start: int) -> int:
        """Return the minutes needed for the infection to cover the whole tree.

        The infection begins at the node whose value is ``start`` and, each
        minute, spreads from every infected node to its parent and children.

        Args:
            root: Root of the binary tree. Node values are used as dictionary
                keys, so they are assumed to be unique.
            start: Value of the node where the infection begins.

        Returns:
            The minute at which the last node becomes infected, or -1 when the
            tree is empty or no node has the value ``start``.
        """
        from collections import deque

        # One iterative pass builds the child -> parent map and finds the start
        # node; recursion would hit the interpreter's recursion limit on deep
        # (list-shaped) trees.
        parent = {}
        origin = None
        stack = [root] if root is not None else []
        while stack:
            node = stack.pop()
            if node.val == start:
                origin = node
            for child in (node.left, node.right):
                if child is not None:
                    parent[child.val] = node
                    stack.append(child)
        if origin is None:
            return -1

        infected = {start}
        # deque gives O(1) pops from the front; list.pop(0) is O(n) per pop.
        frontier = deque([origin])
        # Starts at -1 because the first BFS level corresponds to minute 0.
        minutes = -1
        while frontier:
            for _ in range(len(frontier)):
                cur = frontier.popleft()
                # parent.get(...) is None for the root, which is skipped below.
                for nxt in (cur.left, cur.right, parent.get(cur.val)):
                    if nxt is not None and nxt.val not in infected:
                        infected.add(nxt.val)
                        frontier.append(nxt)
            minutes += 1
        return minutes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
impl Solution {
    /// Returns the number of minutes needed for an infection that starts at the
    /// node whose value is `start` to reach every node of the tree.
    ///
    /// Each minute the infection spreads from every infected node to its parent
    /// and its children. Node values are used as map keys, so they are assumed
    /// to be unique.
    ///
    /// Returns `-1` when the tree is empty or contains no node valued `start`.
    pub fn amount_of_time(root: Option<Rc<RefCell<TreeNode>>>, start: i32) -> i32 {
        use std::collections::{HashMap, HashSet, VecDeque};

        type Link = Rc<RefCell<TreeNode>>;

        // One iterative pass records each node's parent and locates the start
        // node, so a list-shaped tree cannot overflow the call stack here.
        let mut parent: HashMap<i32, Link> = HashMap::new();
        let mut origin: Option<Link> = None;
        let mut stack: Vec<Link> = root.into_iter().collect();
        while let Some(node) = stack.pop() {
            let inner = node.borrow();
            if inner.val == start {
                origin = Some(Rc::clone(&node));
            }
            for child in inner.left.iter().chain(inner.right.iter()) {
                parent.insert(child.borrow().val, Rc::clone(&node));
                stack.push(Rc::clone(child));
            }
        }

        let origin = match origin {
            Some(node) => node,
            None => return -1,
        };

        let mut infected: HashSet<i32> = HashSet::new();
        infected.insert(start);
        let mut frontier: VecDeque<Link> = VecDeque::new();
        frontier.push_back(origin);

        // Starts at -1 because the first BFS level corresponds to minute 0.
        let mut minutes = -1;
        while !frontier.is_empty() {
            // The range bound is evaluated once, so only the current level is drained.
            for _ in 0..frontier.len() {
                if let Some(current) = frontier.pop_front() {
                    let inner = current.borrow();
                    // The root has no parent entry, so `get` contributes nothing for it.
                    let neighbours = inner
                        .left
                        .iter()
                        .chain(inner.right.iter())
                        .chain(parent.get(&inner.val));
                    for next in neighbours {
                        // `insert` returns true only for a value not yet infected.
                        if infected.insert(next.borrow().val) {
                            frontier.push_back(Rc::clone(next));
                        }
                    }
                }
            }
            minutes += 1;
        }
        minutes
    }
}
Complexity#
⏰ Time complexity: O(n)
, where n
is the number of nodes (each node is visited at most twice).
🧺 Space complexity: O(n)
for parent mapping and visited set.