Amount of Time for Binary Tree to Be Infected
Medium · Updated: Aug 2, 2025
Practice on:
Problem
You are given the root of a binary tree with unique values, and an integer start. At minute 0, an infection starts from the node with value start.
Each minute, a node becomes infected if:
- The node is currently uninfected.
- The node is adjacent to an infected node.
Return the number of minutes needed for the entire tree to be infected.
Examples
Example 1:
graph TD; A(1) --- E(5) & C(3):::blue E ~~~ N1:::hidden E --- D(4) D --- I(9) & B(2) C --- J(10) & F(6) classDef blue fill:#ADD8E6,stroke:#000,stroke-width:1px; classDef hidden display:none
Input: root = [1,5,3,null,4,10,6,9,2], start = 3
Output: 4
Explanation: The following nodes are infected during:
- Minute 0: Node 3
- Minute 1: Nodes 1, 10 and 6
- Minute 2: Node 5
- Minute 3: Node 4
- Minute 4: Nodes 9 and 2
It takes 4 minutes for the whole tree to be infected so we return 4.
Example 2:
graph TD; A(1):::blue classDef blue fill:#ADD8E6,stroke:#000,stroke-width:1px;
Input: root = [1], start = 1
Output: 0
Explanation: At minute 0, the only node in the tree is infected so we return 0.
Constraints:
- The number of nodes in the tree is in the range [1, 10^5].
- 1 <= Node.val <= 10^5
- Each node has a unique value.
- A node with a value of start exists in the tree.
Solution
Method 1 – BFS with Parent Mapping
Intuition
The infection spreads to adjacent nodes (parent or children) each minute, similar to a breadth-first search (BFS) traversal. By mapping each node to its parent, we can simulate the infection spreading in all directions from the starting node.
Approach
- Traverse the tree to build a mapping from each node to its parent and locate the start node.
- Use BFS starting from the start node, marking nodes as infected (visited).
- At each minute, infect all uninfected adjacent nodes (left, right, parent).
- Track the time taken until all nodes are infected.
- Return the total minutes needed.
Code
C++
class Solution {
public:
int amountOfTime(TreeNode* root, int start) {
unordered_map<int, TreeNode*> par;
queue<TreeNode*> q;
function<void(TreeNode*, TreeNode*)> dfs = [&](TreeNode* node, TreeNode* p) {
if (!node) return;
par[node->val] = p;
dfs(node->left, node);
dfs(node->right, node);
};
dfs(root, nullptr);
unordered_set<int> vis;
TreeNode* st = nullptr;
function<void(TreeNode*)> find = [&](TreeNode* node) {
if (!node) return;
if (node->val == start) st = node;
find(node->left);
find(node->right);
};
find(root);
q.push(st);
vis.insert(st->val);
int ans = -1;
while (!q.empty()) {
int sz = q.size();
ans++;
for (int i = 0; i < sz; ++i) {
TreeNode* cur = q.front(); q.pop();
for (TreeNode* nxt : {cur->left, cur->right, par[cur->val]}) {
if (nxt && !vis.count(nxt->val)) {
vis.insert(nxt->val);
q.push(nxt);
}
}
}
}
return ans;
}
};
Go
// amountOfTime returns the number of minutes an infection that starts at the
// node whose value is start needs to reach every node of the tree. Each minute
// the infection crosses one edge, towards children and parent alike.
//
// Nodes are identified by their value, so values must be unique. The result is
// -1 when no node holds start, which includes a nil root.
func amountOfTime(root *TreeNode, start int) int {
	par := map[int]*TreeNode{} // node value -> parent (nil for the root)
	var st *TreeNode

	// Single pass with an explicit stack: record parent links and find the
	// start node at the same time.
	var stack []*TreeNode
	if root != nil {
		par[root.Val] = nil
		stack = append(stack, root)
	}
	for len(stack) > 0 {
		node := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if node.Val == start {
			st = node
		}
		for _, child := range [2]*TreeNode{node.Left, node.Right} {
			if child != nil {
				par[child.Val] = node
				stack = append(stack, child)
			}
		}
	}

	// Nothing to infect from: report it instead of dereferencing nil.
	if st == nil {
		return -1
	}

	vis := map[int]bool{st.Val: true}
	frontier := []*TreeNode{st}

	// Level-order BFS; every processed frontier is one minute. Starting at -1
	// makes the frontier that holds only st count as minute 0.
	ans := -1
	for len(frontier) > 0 {
		ans++
		var next []*TreeNode
		for _, cur := range frontier {
			for _, nxt := range [3]*TreeNode{cur.Left, cur.Right, par[cur.Val]} {
				if nxt != nil && !vis[nxt.Val] {
					vis[nxt.Val] = true
					next = append(next, nxt)
				}
			}
		}
		frontier = next
	}
	return ans
}
Java
class Solution {
    /**
     * Computes how many minutes an infection that starts at the node whose
     * value is {@code start} needs to reach every node of the tree. Each
     * minute the infection crosses one edge, towards children and parent alike.
     *
     * <p>Nodes are identified by their value, so values must be unique.
     *
     * @param root  root of the tree; may be {@code null}
     * @param start value of the node that is infected at minute 0
     * @return minutes until every node is infected, or -1 when no node holds
     *         {@code start} (which includes the empty tree)
     */
    public int amountOfTime(TreeNode root, int start) {
        Map<Integer, TreeNode> par = new HashMap<>(); // node value -> parent (null for the root)
        Queue<TreeNode> q = new LinkedList<>();
        TreeNode st = null;

        // One iterative pass records every parent link and locates the start
        // node, so a list-shaped tree cannot overflow the call stack.
        if (root != null) {
            par.put(root.val, null);
            q.offer(root);
        }
        while (!q.isEmpty()) {
            TreeNode node = q.poll();
            if (node.val == start) {
                st = node;
            }
            if (node.left != null) {
                par.put(node.left.val, node);
                q.offer(node.left);
            }
            if (node.right != null) {
                par.put(node.right.val, node);
                q.offer(node.right);
            }
        }

        // Nothing to infect from: report it instead of throwing a NullPointerException.
        if (st == null) {
            return -1;
        }

        Set<Integer> vis = new HashSet<>();
        vis.add(st.val);
        q.offer(st); // the queue is empty again at this point

        // Level-order BFS; every processed level is one minute. Starting at -1
        // makes the level that holds only the start node count as minute 0.
        int ans = -1;
        while (!q.isEmpty()) {
            ans++;
            for (int remaining = q.size(); remaining > 0; remaining--) {
                TreeNode cur = q.poll();
                for (TreeNode nxt : new TreeNode[]{cur.left, cur.right, par.get(cur.val)}) {
                    // Set.add returns true only for values not seen before.
                    if (nxt != null && vis.add(nxt.val)) {
                        q.offer(nxt);
                    }
                }
            }
        }
        return ans;
    }
}
Kotlin
class Solution {
    /**
     * Returns the number of minutes an infection that starts at the node whose
     * value is [start] needs to reach every node of the tree, spreading one
     * edge per minute to children and parent alike.
     *
     * Nodes are identified by their value. The result is -1 when no node holds
     * [start], which includes a `null` [root].
     */
    fun amountOfTime(root: TreeNode?, start: Int): Int {
        val parentOf = HashMap<Int, TreeNode?>()
        var origin: TreeNode? = null

        // Pre-order walk that records each node's parent and remembers the
        // first node whose value equals `start`.
        fun index(node: TreeNode?, parent: TreeNode?) {
            node ?: return
            parentOf[node.`val`] = parent
            if (origin == null && node.`val` == start) origin = node
            index(node.left, node)
            index(node.right, node)
        }
        index(root, null)

        val source = origin ?: return -1

        val infected = hashSetOf(source.`val`)
        var frontier: List<TreeNode> = listOf(source)

        // Each processed frontier is one minute; starting at -1 makes the
        // frontier that holds only the source count as minute 0.
        var minutes = -1
        while (frontier.isNotEmpty()) {
            minutes++
            val next = mutableListOf<TreeNode>()
            for (node in frontier) {
                for (neighbor in arrayOf(node.left, node.right, parentOf[node.`val`])) {
                    if (neighbor == null) continue
                    // add() is true only for values that were not infected yet.
                    if (infected.add(neighbor.`val`)) next.add(neighbor)
                }
            }
            frontier = next
        }
        return minutes
    }
}
Python
class Solution:
    def amountOfTime(self, root: Optional[TreeNode], start: int) -> int:
        """Return the minutes needed for an infection to cover the whole tree.

        The infection starts at the node whose value is ``start`` and crosses
        one edge per minute, towards children and parent alike. Nodes are
        identified by their value, so values must be unique.

        Args:
            root: Root of the tree, or ``None`` for an empty tree.
            start: Value of the node that is infected at minute 0.

        Returns:
            Minutes until every node is infected, or -1 when no node holds
            ``start`` (which includes the empty tree).
        """
        from collections import deque

        parent = {}  # node value -> parent node (None for the root)
        origin = None

        # Explicit stack instead of recursion: a list-shaped tree would
        # otherwise exceed the interpreter's recursion limit.
        stack = [(root, None)] if root is not None else []
        while stack:
            node, above = stack.pop()
            parent[node.val] = above
            if node.val == start:
                origin = node
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, node))

        # Nothing to infect from: report it instead of failing on None.val.
        if origin is None:
            return -1

        infected = {origin.val}
        frontier = deque([origin])  # popleft() is O(1); list.pop(0) is O(n)

        # Level-order BFS; every processed level is one minute. Starting at -1
        # makes the level that holds only the origin count as minute 0.
        minutes = -1
        while frontier:
            minutes += 1
            for _ in range(len(frontier)):
                cur = frontier.popleft()
                for nxt in (cur.left, cur.right, parent[cur.val]):
                    if nxt is not None and nxt.val not in infected:
                        infected.add(nxt.val)
                        frontier.append(nxt)
        return minutes
Rust
impl Solution {
    /// Returns the number of minutes an infection that starts at the node whose
    /// value is `start` needs to reach every node of the tree, spreading one
    /// edge per minute to children and parent alike.
    ///
    /// Nodes are identified by their value. The result is -1 when no node holds
    /// `start`, which includes a `None` root.
    pub fn amount_of_time(root: Option<Rc<RefCell<TreeNode>>>, start: i32) -> i32 {
        use std::collections::{HashMap, HashSet};

        type Link = Option<Rc<RefCell<TreeNode>>>;

        /// Pre-order walk that stores each node's parent under the node's value
        /// and returns the first node whose value equals `target`, if any.
        fn index(
            node: &Link,
            parent: &Link,
            parents: &mut HashMap<i32, Link>,
            target: i32,
        ) -> Link {
            let rc = node.as_ref()?;
            let inner = rc.borrow();
            parents.insert(inner.val, parent.clone());
            let here = if inner.val == target {
                Some(Rc::clone(rc))
            } else {
                None
            };
            let from_left = index(&inner.left, node, parents, target);
            let from_right = index(&inner.right, node, parents, target);
            here.or(from_left).or(from_right)
        }

        let mut parents: HashMap<i32, Link> = HashMap::new();
        let origin = index(&root, &None, &mut parents, start);

        let mut seen: HashSet<i32> = HashSet::new();
        let mut frontier: Vec<Rc<RefCell<TreeNode>>> = Vec::new();
        if let Some(node) = origin {
            seen.insert(node.borrow().val);
            frontier.push(node);
        }

        // Each processed frontier is one minute; starting at -1 makes the
        // frontier that holds only the origin count as minute 0.
        let mut minutes = -1;
        while !frontier.is_empty() {
            minutes += 1;
            let mut next = Vec::new();
            for current in frontier {
                let inner = current.borrow();
                let neighbours = [
                    inner.left.clone(),
                    inner.right.clone(),
                    parents.get(&inner.val).cloned().flatten(),
                ];
                for neighbour in neighbours.iter().flatten() {
                    // insert() is true only for values that were not seen yet.
                    if seen.insert(neighbour.borrow().val) {
                        next.push(Rc::clone(neighbour));
                    }
                }
            }
            frontier = next;
        }
        minutes
    }
}
Complexity
- ⏰ Time complexity: O(n), where n is the number of nodes (each node is visited at most twice).
- 🧺 Space complexity: O(n) for the parent mapping and visited set.