Problem

There is an undirected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.

You are given a 2D integer array edges of length n - 1 where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree. You are also given an integer array restricted which represents restricted nodes.

Return the maximum number of nodes you can reach from node 0 without visiting a restricted node.

Note that node 0 will not be a restricted node.

Examples

Example 1

graph TD
    classDef reachable fill:#51ff51,stroke:#2b8fb8,stroke-width:2px;
    classDef restricted fill:#ffe6e6,stroke:#d9534f,stroke-width:1.5px;

    0(0):::reachable
    1(1):::reachable
    2(2):::reachable
    3(3):::reachable
    4(4):::restricted
    5(5):::restricted
    6(6)

    0 --- 1
    1 --- 2
    1 --- 3
    0 --- 4
    0 --- 5
    5 --- 6

    %% reachable: 0,1,2,3 (4 and 5 are restricted so their branches are blocked)
  
1
2
3
4
Input: n = 7, edges = [[0,1],[1,2],[3,1],[4,0],[0,5],[5,6]], restricted = [4,5]
Output: 4
Explanation: The diagram above shows the tree.
We have that [0,1,2,3] are the only nodes that can be reached from node 0 without visiting a restricted node.

Example 2

graph TD
    classDef reachable fill:#51ff51,stroke:#2b8fb8,stroke-width:2px;
    classDef restricted fill:#ffe6e6,stroke:#d9534f,stroke-width:1.5px;

    0(0):::reachable
    1(1):::restricted
    2(2):::restricted
    3(3)
    4(4):::restricted
    5(5):::reachable
    6(6):::reachable

    0 --- 1
    0 --- 2
    0 --- 5
    0 --- 4
    2 --- 3
    5 --- 6

    %% reachable: 0,5,6 (1,2,4 restricted)
  
1
2
3
4
Input: n = 7, edges = [[0,1],[0,2],[0,5],[0,4],[3,2],[6,5]], restricted = [4,2,1]
Output: 3
Explanation: The diagram above shows the tree.
We have that [0,5,6] are the only nodes that can be reached from node 0 without visiting a restricted node.

Constraints

  • 2 <= n <= 10^5
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= ai, bi < n
  • ai != bi
  • edges represents a valid tree.
  • 1 <= restricted.length < n
  • 1 <= restricted[i] < n
  • All the values of restricted are unique.

Solution

Method 1 - BFS from Root Skipping Restricted

Intuition

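We are given a tree and a set of restricted nodes. We need to count how many nodes are reachable from node 0 without visiting any restricted node. Deleting the restricted nodes splits the tree into connected components, so the answer is the size of the component that contains node 0. Since the graph is a tree (acyclic, connected), we can find that component with a BFS or DFS from node 0 that never steps onto a restricted node.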

Approach

  1. Build an adjacency list for the tree from the edges.
  2. Use a set for restricted nodes for O(1) lookup.
  3. Traverse the tree from node 0 using BFS or DFS, skipping any node that is restricted or already visited.
  4. Count the number of nodes visited.

This approach is efficient because each node is enqueued at most once and each edge is examined at most twice (once from each endpoint).

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include <vector>
#include <unordered_set>
#include <queue>
using namespace std;

class Solution {
public:
    // Returns how many nodes can be reached from node 0 in the undirected tree
    // described by `edges` without ever entering a node listed in `restricted`.
    // Node 0 itself is always counted.
    int reachableNodes(int n, vector<vector<int>>& edges, vector<int>& restricted) {
        // Undirected adjacency list: every edge is recorded in both directions.
        vector<vector<int>> graph(n);
        for (const auto& edge : edges) {
            const int a = edge[0];
            const int b = edge[1];
            graph[a].push_back(b);
            graph[b].push_back(a);
        }

        const unordered_set<int> blocked(restricted.begin(), restricted.end());
        vector<bool> seen(n, false);
        seen[0] = true;

        // `order` doubles as the BFS queue: `head` is the read position, and every
        // node ever appended has been reached, so its final size is the answer.
        vector<int> order{0};
        for (size_t head = 0; head < order.size(); ++head) {
            for (const int next : graph[order[head]]) {
                if (seen[next] || blocked.count(next) > 0) continue;
                seen[next] = true;
                order.push_back(next);
            }
        }
        return static_cast<int>(order.size());
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
package solution

// reachableNodes returns how many nodes can be reached from node 0 in the
// undirected tree described by edges without entering any node listed in
// restricted. Node 0 itself is always counted.
func reachableNodes(n int, edges [][]int, restricted []int) int {
    graph := make([][]int, n)
    for _, edge := range edges {
        a, b := edge[0], edge[1]
        graph[a] = append(graph[a], b)
        graph[b] = append(graph[b], a)
    }

    blocked := make(map[int]struct{}, len(restricted))
    for _, node := range restricted {
        blocked[node] = struct{}{}
    }

    seen := make([]bool, n)
    seen[0] = true
    // order doubles as the BFS queue: head walks forward, and every node ever
    // appended has been reached, so len(order) is the final answer.
    order := []int{0}
    for head := 0; head < len(order); head++ {
        for _, next := range graph[order[head]] {
            if seen[next] {
                continue
            }
            if _, isBlocked := blocked[next]; isBlocked {
                continue
            }
            seen[next] = true
            order = append(order, next)
        }
    }
    return len(order)
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import java.util.*;
class Solution {
    /**
     * Counts the nodes reachable from node 0 in an undirected tree without
     * entering any restricted node.
     *
     * @param n          number of nodes, labelled 0..n-1
     * @param edges      undirected edges, each given as {a, b}
     * @param restricted node labels that may not be visited
     * @return the number of reachable nodes, including node 0
     */
    public int reachableNodes(int n, int[][] edges, int[] restricted) {
        List<List<Integer>> graph = new ArrayList<>(n);
        while (graph.size() < n) {
            graph.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            int a = edge[0];
            int b = edge[1];
            graph.get(a).add(b);
            graph.get(b).add(a);
        }

        Set<Integer> blocked = new HashSet<>();
        for (int node : restricted) {
            blocked.add(node);
        }

        boolean[] seen = new boolean[n];
        seen[0] = true;
        Deque<Integer> frontier = new ArrayDeque<>();
        frontier.offer(0);
        int reached = 0;
        while (!frontier.isEmpty()) {
            int node = frontier.poll();
            reached++;
            for (int next : graph.get(node)) {
                // Skip nodes already queued and nodes we are not allowed to enter.
                if (seen[next] || blocked.contains(next)) {
                    continue;
                }
                seen[next] = true;
                frontier.offer(next);
            }
        }
        return reached;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
    /**
     * Returns how many nodes can be reached from node 0 in the undirected tree
     * [edges] without entering any node listed in [restricted].
     * Node 0 itself is always counted.
     */
    fun reachableNodes(n: Int, edges: Array<IntArray>, restricted: IntArray): Int {
        val graph = List(n) { ArrayList<Int>() }
        edges.forEach { (a, b) ->
            graph[a].add(b)
            graph[b].add(a)
        }

        val blocked = restricted.toSet()
        val seen = BooleanArray(n).also { it[0] = true }

        // `order` is both the BFS queue and the record of reached nodes:
        // `head` is the read position, so the final size is the answer.
        val order = mutableListOf(0)
        var head = 0
        while (head < order.size) {
            val node = order[head++]
            for (next in graph[node]) {
                if (seen[next] || next in blocked) continue
                seen[next] = true
                order.add(next)
            }
        }
        return order.size
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
from collections import deque, defaultdict
class Solution:
    def reachableNodes(self, n: int, edges: list[list[int]], restricted: list[int]) -> int:
        """Count the nodes reachable from node 0 without entering a restricted node.

        Args:
            n: Number of nodes, labelled 0..n-1.
            edges: Undirected edges given as ``[a, b]`` pairs.
            restricted: Node labels that may not be visited.

        Returns:
            The number of reachable nodes, including node 0 itself.
        """
        graph = defaultdict(list)
        for a, b in edges:
            graph[a].append(b)
            graph[b].append(a)

        # Seeding "seen" with the restricted nodes means they are never
        # enqueued, so one membership test covers both conditions.
        seen = set(restricted)
        seen.add(0)
        frontier = deque([0])
        reached = 0
        while frontier:
            node = frontier.popleft()
            reached += 1
            for nxt in graph[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    frontier.append(nxt)
        return reached
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
use std::collections::{VecDeque, HashSet};
impl Solution {
    /// Returns how many nodes can be reached from node 0 in the undirected tree
    /// `edges` without entering any node listed in `restricted`.
    /// Node 0 itself is always counted.
    pub fn reachable_nodes(n: i32, edges: Vec<Vec<i32>>, restricted: Vec<i32>) -> i32 {
        let size = n as usize;
        let mut graph: Vec<Vec<usize>> = vec![Vec::new(); size];
        for edge in &edges {
            let (a, b) = (edge[0] as usize, edge[1] as usize);
            graph[a].push(b);
            graph[b].push(a);
        }

        let blocked: HashSet<usize> = restricted.into_iter().map(|x| x as usize).collect();
        let mut seen = vec![false; size];
        seen[0] = true;

        let mut frontier: VecDeque<usize> = VecDeque::from(vec![0]);
        let mut reached: i32 = 0;
        while let Some(node) = frontier.pop_front() {
            reached += 1;
            for &next in &graph[node] {
                // Skip nodes already queued and nodes we are not allowed to enter.
                if seen[next] || blocked.contains(&next) {
                    continue;
                }
                seen[next] = true;
                frontier.push_back(next);
            }
        }
        reached
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
 * Counts the nodes reachable from node 0 in an undirected tree without
 * entering any restricted node.
 *
 * @param n          number of nodes, labelled 0..n-1
 * @param edges      undirected edges given as [a, b] pairs
 * @param restricted node labels that may not be visited
 * @returns the number of reachable nodes, including node 0 itself
 */
function reachableNodes(n: number, edges: number[][], restricted: number[]): number {
    const adj: number[][] = Array.from({length: n}, () => []);
    for (const [a, b] of edges) {
        adj[a].push(b);
        adj[b].push(a);
    }
    const rest = new Set(restricted);
    const vis: boolean[] = new Array(n).fill(false);
    vis[0] = true;

    // Walk the queue with a read pointer instead of Array.prototype.shift():
    // shift() is specified to move every remaining element on each call, which
    // can make this BFS quadratic for n = 1e5 on engines that don't optimize it.
    // Every node pushed has been reached, so the final queue length is the answer.
    const q: number[] = [0];
    for (let head = 0; head < q.length; head++) {
        const u = q[head];
        for (const v of adj[u]) {
            if (!vis[v] && !rest.has(v)) {
                vis[v] = true;
                q.push(v);
            }
        }
    }
    return q.length;
}

Complexity

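  • Time complexity: O(n) – each node is enqueued at most once, and each of the n - 1 edges is examined at most twice (once from each endpoint).
  • 🧺 Space complexity: O(n) – for the adjacency list, the visited array, the restricted-node set, and the queue.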