Problem

You are given a network of n nodes represented as an n x n adjacency matrix graph, where the ith node is directly connected to the jth node if graph[i][j] == 1.

Some nodes, given in initial, are initially infected by malware. Whenever two nodes are directly connected and at least one of them is infected, both become infected. This spread continues until no more nodes can be infected in this manner.

Suppose M(initial) is the final number of nodes infected with malware in the entire network after the spread of malware stops.

We will remove exactly one node from initial, completely removing it and any connections from this node to any other node.

Return the node that, if removed, would minimize M(initial). If multiple nodes could be removed to minimize M(initial), return such a node with the smallest index.
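To pin down the spread rule and what "removing" a node means, here is a brute-force Python sketch that simulates the spread with a breadth-first search for every candidate and treats the removed node (and its edges) as absent. It assumes graph is a symmetric 0/1 adjacency matrix; the helper names infected_after_removal and min_malware_spread_brute_force are hypothetical. It takes O(|initial| · n^2) time, so it is best used as a reference for checking faster solutions.

```python
from collections import deque


def infected_after_removal(graph: list[list[int]], initial: list[int], removed: int) -> int:
    """Return M(initial) after deleting `removed` and all of its edges."""
    n = len(graph)
    seen = [False] * n
    queue = deque()
    # Every initial node except the removed one starts infected.
    for u in initial:
        if u != removed and not seen[u]:
            seen[u] = True
            queue.append(u)
    # Breadth-first spread; the removed node is skipped as if it did not exist.
    while queue:
        u = queue.popleft()
        for v in range(n):
            if graph[u][v] == 1 and v != removed and not seen[v]:
                seen[v] = True
                queue.append(v)
    return sum(seen)


def min_malware_spread_brute_force(graph: list[list[int]], initial: list[int]) -> int:
    """Try every candidate; (count, index) ordering breaks ties by smallest index."""
    return min(initial, key=lambda u: (infected_after_removal(graph, initial, u), u))
```

On the examples that follow it reproduces the listed outputs; for Example 2, removing node 0 leaves 2 infected nodes while removing node 1 leaves only 1.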

Examples

Example 1:

Input: graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
Output: 0

Example 2:

Input: graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1]
Output: 1

Example 3:

Input: graph = [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], initial = [0,1]
Output: 1

Solution

Method 1 – Union-Find and Component Analysis

Intuition

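Because removing a node also deletes its edges, it is not enough to count initial nodes per connected component of the whole graph. In Example 2 all three nodes form one component that contains both initial nodes, yet removing node 1 cuts node 2 off from node 0 and keeps it clean. Instead, look only at the clean nodes (those not in initial) and group them into connected components using edges between clean nodes. An initial node infects every clean component it is directly connected to. If a component is touched by exactly one initial node, removing that node keeps the whole component clean; if two or more initial nodes touch it, it is infected no matter which one is removed. The number of nodes saved by removing u is therefore the total size of the clean components that only u touches. (The removed node itself always drops out of the count, so it does not affect the comparison.) We want to maximize the number of saved nodes, and in case of a tie, return the smallest index.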
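To make the observation concrete, this Python sketch groups the clean nodes into components (union-find with path halving, merging only clean-to-clean edges) and lists the initial nodes adjacent to each component. The function name clean_components and its return shape are illustrative choices, not part of the original solution; it assumes graph is a symmetric 0/1 matrix.

```python
def clean_components(graph: list[list[int]], initial: list[int]) -> dict[int, tuple[int, list[int]]]:
    """Map each clean component's root to (size, sorted initial nodes adjacent to it)."""
    n = len(graph)
    infected = set(initial)
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving keeps trees shallow
            x = parent[x]
        return x

    # Merge only edges whose two endpoints are both clean (not in initial).
    for i in range(n):
        for j in range(i + 1, n):
            if graph[i][j] == 1 and i not in infected and j not in infected:
                parent[find(i)] = find(j)

    size: dict[int, int] = {}
    touching: dict[int, set[int]] = {}
    for v in range(n):
        if v not in infected:
            root = find(v)
            size[root] = size.get(root, 0) + 1
            touching.setdefault(root, set())

    # An initial node touches a component if it has an edge to any of its nodes.
    for u in initial:
        for v in range(n):
            if graph[u][v] == 1 and v not in infected:
                touching[find(v)].add(u)

    return {root: (size[root], sorted(touching[root])) for root in size}
```

On Example 2 it returns {2: (1, [1])}: the clean component {2} is reachable only from node 1, so removing node 1 saves one node and removing node 0 saves none. On Example 3 it reports a single clean component of size 2, again touched only by node 1.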

Approach

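  1. Mark the initial nodes, then use union-find to merge only edges whose endpoints are both clean, and count the size of each clean component.
  2. For each initial node u and each clean neighbor v, look up v's component: if no initial node has touched it yet, record u as its infector; if a different initial node already did, mark it as touched by several nodes.
  3. Each component with exactly one infector is saved by removing that infector, so add the component's size to that node's saved count.
  4. Return the initial node with the largest saved count, breaking ties by smallest index (if nothing can be saved, this is the smallest index in initial).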
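The same counting can be done without union-find. As a swapped-in alternative (not the method used in the code below), this Python sketch runs a depth-first search from each initial node that refuses to step onto other initial nodes, then credits each node with the clean nodes that only it reaches. The function name min_malware_spread_by_search is hypothetical. It costs O(|initial| · n^2) time, which is more work than the union-find version when initial is large, but it is easy to check by hand.

```python
def min_malware_spread_by_search(graph: list[list[int]], initial: list[int]) -> int:
    n = len(graph)
    infected = set(initial)
    reached_by = [0] * n  # number of initial nodes that can reach each clean node
    reach: dict[int, set[int]] = {}
    for u in initial:
        seen = {u}
        stack = [u]
        # Iterative DFS from u that never steps onto another initial node.
        while stack:
            x = stack.pop()
            for y in range(n):
                if graph[x][y] == 1 and y not in seen and y not in infected:
                    seen.add(y)
                    stack.append(y)
        seen.discard(u)  # keep only the clean nodes u can reach
        reach[u] = seen
        for y in seen:
            reached_by[y] += 1

    def saved(u: int) -> int:
        # Removing u saves exactly the clean nodes that no other initial node reaches.
        return sum(1 for y in reach[u] if reached_by[y] == 1)

    # Largest saved count wins; ties go to the smallest index.
    return min(initial, key=lambda u: (-saved(u), u))
```

A clean node stays clean after removing u exactly when no other initial node reaches it through clean nodes, which is why the search stops at other initial nodes.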

Code

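C++

class Solution {
public:
    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        int n = graph.size();
        vector<bool> isInit(n, false);
        for (int u : initial) isInit[u] = true;
        // Union-find over the clean (initially uninfected) nodes only.
        vector<int> par(n);
        iota(par.begin(), par.end(), 0);
        function<int(int)> find = [&](int x) { return par[x] == x ? x : par[x] = find(par[x]); };
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                if (graph[i][j] && !isInit[i] && !isInit[j]) par[find(i)] = find(j);
        vector<int> compSize(n, 0);
        for (int i = 0; i < n; ++i)
            if (!isInit[i]) compSize[find(i)]++;
        // infector[root]: -1 = untouched, u = touched only by u, -2 = touched by several.
        vector<int> infector(n, -1);
        for (int u : initial)
            for (int v = 0; v < n; ++v) {
                if (!graph[u][v] || isInit[v]) continue;
                int root = find(v);
                if (infector[root] == -1) infector[root] = u;
                else if (infector[root] != u) infector[root] = -2;
            }
        // A component with a single infector is saved by removing that infector.
        vector<int> saved(n, 0);
        for (int r = 0; r < n; ++r)
            if (infector[r] >= 0) saved[infector[r]] += compSize[r];
        // Largest saved count wins; ties go to the smallest index.
        int ans = -1;
        for (int u : initial)
            if (ans == -1 || saved[u] > saved[ans] || (saved[u] == saved[ans] && u < ans)) ans = u;
        return ans;
    }
};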
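
Go

func minMalwareSpread(graph [][]int, initial []int) int {
    n := len(graph)
    isInit := make([]bool, n)
    for _, u := range initial { isInit[u] = true }
    // Union-find over the clean (initially uninfected) nodes only.
    par := make([]int, n)
    for i := range par { par[i] = i }
    var find func(int) int
    find = func(x int) int { if par[x] != x { par[x] = find(par[x]) }; return par[x] }
    for i := 0; i < n; i++ {
        for j := 0; j < n; j++ {
            if graph[i][j] == 1 && !isInit[i] && !isInit[j] {
                par[find(i)] = find(j)
            }
        }
    }
    size := make([]int, n)
    for i := 0; i < n; i++ {
        if !isInit[i] { size[find(i)]++ }
    }
    // infector[root]: -1 = untouched, u = touched only by u, -2 = touched by several.
    infector := make([]int, n)
    for i := range infector { infector[i] = -1 }
    for _, u := range initial {
        for v := 0; v < n; v++ {
            if graph[u][v] == 0 || isInit[v] { continue }
            root := find(v)
            if infector[root] == -1 {
                infector[root] = u
            } else if infector[root] != u {
                infector[root] = -2
            }
        }
    }
    // A component with a single infector is saved by removing that infector.
    saved := make([]int, n)
    for r := 0; r < n; r++ {
        if infector[r] >= 0 { saved[infector[r]] += size[r] }
    }
    // Largest saved count wins; ties go to the smallest index.
    ans := -1
    for _, u := range initial {
        if ans == -1 || saved[u] > saved[ans] || (saved[u] == saved[ans] && u < ans) { ans = u }
    }
    return ans
}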

Java

class Solution {
    public int minMalwareSpread(int[][] graph, int[] initial) {
        int n = graph.length;
        boolean[] isInit = new boolean[n];
        for (int u : initial) isInit[u] = true;
        // Union-find over the clean (initially uninfected) nodes only.
        int[] par = new int[n];
        for (int i = 0; i < n; i++) par[i] = i;
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                if (graph[i][j] == 1 && !isInit[i] && !isInit[j]) par[find(par, i)] = find(par, j);
        int[] size = new int[n];
        for (int i = 0; i < n; i++) if (!isInit[i]) size[find(par, i)]++;
        // infector[root]: -1 = untouched, u = touched only by u, -2 = touched by several.
        int[] infector = new int[n];
        Arrays.fill(infector, -1);
        for (int u : initial)
            for (int v = 0; v < n; v++) {
                if (graph[u][v] == 0 || isInit[v]) continue;
                int root = find(par, v);
                if (infector[root] == -1) infector[root] = u;
                else if (infector[root] != u) infector[root] = -2;
            }
        // A component with a single infector is saved by removing that infector.
        int[] saved = new int[n];
        for (int r = 0; r < n; r++) if (infector[r] >= 0) saved[infector[r]] += size[r];
        // Largest saved count wins; ties go to the smallest index.
        int ans = -1;
        for (int u : initial)
            if (ans == -1 || saved[u] > saved[ans] || (saved[u] == saved[ans] && u < ans)) ans = u;
        return ans;
    }
    private int find(int[] par, int x) {
        if (par[x] != x) par[x] = find(par, par[x]);
        return par[x];
    }
}
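
Kotlin

class Solution {
    fun minMalwareSpread(graph: Array<IntArray>, initial: IntArray): Int {
        val n = graph.size
        val isInit = BooleanArray(n)
        for (u in initial) isInit[u] = true
        // Union-find over the clean (initially uninfected) nodes only.
        val par = IntArray(n) { it }
        fun find(x: Int): Int = if (par[x] == x) x else { par[x] = find(par[x]); par[x] }
        for (i in 0 until n) for (j in 0 until n)
            if (graph[i][j] == 1 && !isInit[i] && !isInit[j]) par[find(i)] = find(j)
        val size = IntArray(n)
        for (i in 0 until n) if (!isInit[i]) size[find(i)]++
        // infector[root]: -1 = untouched, u = touched only by u, -2 = touched by several.
        val infector = IntArray(n) { -1 }
        for (u in initial) for (v in 0 until n) {
            if (graph[u][v] == 0 || isInit[v]) continue
            val root = find(v)
            if (infector[root] == -1) infector[root] = u
            else if (infector[root] != u) infector[root] = -2
        }
        // A component with a single infector is saved by removing that infector.
        val saved = IntArray(n)
        for (r in 0 until n) if (infector[r] >= 0) saved[infector[r]] += size[r]
        // Largest saved count wins; ties go to the smallest index.
        var ans = -1
        for (u in initial)
            if (ans == -1 || saved[u] > saved[ans] || (saved[u] == saved[ans] && u < ans)) ans = u
        return ans
    }
}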
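
Python

class Solution:
    def minMalwareSpread(self, graph: list[list[int]], initial: list[int]) -> int:
        n = len(graph)
        is_init = [False] * n
        for u in initial:
            is_init[u] = True
        # Union-find over the clean (initially uninfected) nodes only.
        par = list(range(n))
        def find(x):
            if par[x] != x:
                par[x] = find(par[x])
            return par[x]
        for i in range(n):
            for j in range(n):
                if graph[i][j] and not is_init[i] and not is_init[j]:
                    par[find(i)] = find(j)
        size = [0] * n
        for i in range(n):
            if not is_init[i]:
                size[find(i)] += 1
        # infector[root]: -1 = untouched, u = touched only by u, -2 = touched by several.
        infector = [-1] * n
        for u in initial:
            for v in range(n):
                if graph[u][v] and not is_init[v]:
                    root = find(v)
                    if infector[root] == -1:
                        infector[root] = u
                    elif infector[root] != u:
                        infector[root] = -2
        # A component with a single infector is saved by removing that infector.
        saved = [0] * n
        for root in range(n):
            if infector[root] >= 0:
                saved[infector[root]] += size[root]
        # Largest saved count wins; ties go to the smallest index.
        return min(initial, key=lambda u: (-saved[u], u))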

Rust

impl Solution {
    pub fn min_malware_spread(graph: Vec<Vec<i32>>, initial: Vec<i32>) -> i32 {
        let n = graph.len();
        let mut is_init = vec![false; n];
        for &u in &initial { is_init[u as usize] = true; }
        // Union-find over the clean (initially uninfected) nodes only.
        let mut par: Vec<usize> = (0..n).collect();
        fn find(par: &mut Vec<usize>, x: usize) -> usize {
            let p = par[x];
            if p != x {
                let root = find(par, p);
                par[x] = root;
            }
            par[x]
        }
        for i in 0..n {
            for j in 0..n {
                if graph[i][j] == 1 && !is_init[i] && !is_init[j] {
                    let pi = find(&mut par, i);
                    let pj = find(&mut par, j);
                    par[pi] = pj;
                }
            }
        }
        let mut size = vec![0i32; n];
        for i in 0..n {
            if !is_init[i] {
                let root = find(&mut par, i);
                size[root] += 1;
            }
        }
        // infector[root]: -1 = untouched, u = touched only by u, -2 = touched by several.
        let mut infector = vec![-1i32; n];
        for &u in &initial {
            for v in 0..n {
                if graph[u as usize][v] == 0 || is_init[v] { continue; }
                let root = find(&mut par, v);
                if infector[root] == -1 {
                    infector[root] = u;
                } else if infector[root] != u {
                    infector[root] = -2;
                }
            }
        }
        // A component with a single infector is saved by removing that infector.
        let mut saved = vec![0i32; n];
        for r in 0..n {
            if infector[r] >= 0 { saved[infector[r] as usize] += size[r]; }
        }
        // Largest saved count wins; ties go to the smallest index.
        let mut ans = initial[0];
        for &u in &initial {
            let (su, sa) = (saved[u as usize], saved[ans as usize]);
            if su > sa || (su == sa && u < ans) { ans = u; }
        }
        ans
    }
}
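
TypeScript

class Solution {
    minMalwareSpread(graph: number[][], initial: number[]): number {
        const n = graph.length;
        const isInit: boolean[] = new Array(n).fill(false);
        for (const u of initial) isInit[u] = true;
        // Union-find over the clean (initially uninfected) nodes only.
        const par = Array.from({length: n}, (_, i) => i);
        const find = (x: number): number => par[x] === x ? x : (par[x] = find(par[x]));
        for (let i = 0; i < n; i++)
            for (let j = 0; j < n; j++)
                if (graph[i][j] && !isInit[i] && !isInit[j]) par[find(i)] = find(j);
        const size: number[] = new Array(n).fill(0);
        for (let i = 0; i < n; i++) if (!isInit[i]) size[find(i)]++;
        // infector[root]: -1 = untouched, u = touched only by u, -2 = touched by several.
        const infector: number[] = new Array(n).fill(-1);
        for (const u of initial)
            for (let v = 0; v < n; v++) {
                if (!graph[u][v] || isInit[v]) continue;
                const root = find(v);
                if (infector[root] === -1) infector[root] = u;
                else if (infector[root] !== u) infector[root] = -2;
            }
        // A component with a single infector is saved by removing that infector.
        const saved: number[] = new Array(n).fill(0);
        for (let r = 0; r < n; r++) if (infector[r] >= 0) saved[infector[r]] += size[r];
        // Largest saved count wins; ties go to the smallest index.
        let ans = -1;
        for (const u of initial)
            if (ans === -1 || saved[u] > saved[ans] || (saved[u] === saved[ans] && u < ans)) ans = u;
        return ans;
    }
}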

Complexity

  • ⏰ Time complexity: O(n^2), where n is the number of nodes; we process each cell in the adjacency matrix.
  • 🧺 Space complexity: O(n), for the parent array and component counts.