Problem

Given an undirected graph with `n` vertices labelled `0` to `n - 1` and a list `edges` of weighted edges `[u, v, w]`, compute the maximum weight spanning tree.

Recall that a minimum spanning tree of a connected, undirected, weighted graph is a subset of the graph's edges that forms a tree connecting all of its vertices with the smallest possible total edge weight. A maximum weight spanning tree is defined the same way, but maximizes the total edge weight instead.

Examples

Example 1

1
2
3
4
5
6
7
8
9
Input: 
edges = [[0,1,10],[1,2,6],[0,2,15]]
n = 3

Output: 
[[0,2,15],[0,1,10]]

Explanation: 
The graph has 3 vertices (0,1,2) and 3 edges. The maximum spanning tree includes edges with weights 15 and 10, totaling 25. Edge (1,2) with weight 6 is excluded to avoid cycles.

Example 2

1
2
3
4
5
6
7
8
9
Input: 
edges = [[0,1,1],[1,2,3],[2,3,2],[0,3,4],[1,3,5]]
n = 4

Output: 
[[1,3,5],[0,3,4],[1,2,3]]

Explanation: 
The maximum spanning tree includes edges with weights 5, 4, and 3, totaling 12. This connects all 4 vertices with maximum total weight.

Example 3

1
2
3
4
5
6
7
8
9
Input: 
edges = [[0,1,7]]
n = 2

Output: 
[[0,1,7]]

Explanation: 
With only two vertices, there's only one possible spanning tree using the single edge.

Solution

Method 1 - Modified Kruskal’s Algorithm

Intuition

The key insight is that finding a maximum spanning tree works exactly like finding a minimum spanning tree, except that edges are processed in descending order of weight instead of ascending order (equivalently: negate every weight and find an ordinary minimum spanning tree). We can therefore use Kruskal’s algorithm, sorting edges by weight in descending order and greedily picking the heaviest edges that don’t create cycles.

Approach

  1. Sort all edges by weight in descending order (heaviest first)
  2. Initialize a Union-Find data structure to detect cycles
  3. Iterate through sorted edges and add each edge to the result if it doesn’t create a cycle
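  4. Continue until we have exactly (n-1) edges in our spanning tree (if the edges run out first, the graph is disconnected and the selected edges form a maximum spanning forest)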
  5. Return the selected edges

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution {
public:
    // Union-Find state, reinitialised at the start of every maxSpanningTree call.
    vector<int> parent, rank;

    // Return the representative of x's set, pointing every node on the path
    // directly at the root (path compression).
    int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]);
        }
        return parent[x];
    }

    // Merge the sets containing x and y using union by rank.
    // Returns false if x and y were already in the same set, i.e. the edge
    // (x, y) would close a cycle.
    bool unite(int x, int y) {
        int px = find(x), py = find(y);
        if (px == py) return false;

        // Attach the lower-rank root under the higher-rank one to keep trees shallow.
        if (rank[px] < rank[py]) swap(px, py);
        parent[py] = px;
        if (rank[px] == rank[py]) rank[px]++;
        return true;
    }

    // Maximum spanning tree via Kruskal's algorithm, taking edges heaviest first.
    //   edges: {u, v, w} triples with 0 <= u, v < n. NOTE: sorted in place.
    //   n:     number of vertices.
    // Returns the accepted edges in non-increasing weight order. For a
    // disconnected graph this is a maximum spanning forest with fewer than
    // n - 1 edges.
    vector<vector<int>> maxSpanningTree(vector<vector<int>>& edges, int n) {
        // assign() rather than resize(): resize keeps existing elements, so a
        // reused Solution object would otherwise carry stale ranks from a
        // previous call.
        parent.assign(n, 0);
        rank.assign(n, 0);
        for (int i = 0; i < n; i++) parent[i] = i;

        sort(edges.begin(), edges.end(), [](const vector<int>& a, const vector<int>& b) {
            return a[2] > b[2];
        });

        vector<vector<int>> ans;
        for (auto& edge : edges) {
            if (unite(edge[0], edge[1])) {
                ans.push_back(edge);
                // A spanning tree on n vertices has exactly n - 1 edges.
                // Cast avoids a signed/unsigned comparison.
                if (static_cast<int>(ans.size()) == n - 1) break;
            }
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// maxSpanningTree runs Kruskal's algorithm on an undirected weighted graph.
//
// Each element of edges is {u, v, w}. The edges slice is sorted in place by
// descending weight. An edge is accepted when its endpoints lie in different
// disjoint-set components. Accepted edges are returned in acceptance order,
// and the scan stops as soon as n-1 edges have been accepted. The result is
// nil when no edge is accepted.
func maxSpanningTree(edges [][]int, n int) [][]int {
	root := make([]int, n)
	for v := range root {
		root[v] = v
	}
	rankOf := make([]int, n)

	// leader returns the representative of v, then re-points every node on
	// the walked path directly at it.
	leader := func(v int) int {
		r := v
		for root[r] != r {
			r = root[r]
		}
		for root[v] != r {
			next := root[v]
			root[v] = r
			v = next
		}
		return r
	}

	// join merges the components of a and b by rank and reports whether
	// they were distinct beforehand.
	join := func(a, b int) bool {
		ra, rb := leader(a), leader(b)
		if ra == rb {
			return false
		}
		if rankOf[ra] < rankOf[rb] {
			ra, rb = rb, ra
		}
		root[rb] = ra
		if rankOf[ra] == rankOf[rb] {
			rankOf[ra]++
		}
		return true
	}

	sort.Slice(edges, func(i, j int) bool {
		return edges[i][2] > edges[j][2]
	})

	var tree [][]int
	for _, e := range edges {
		if !join(e[0], e[1]) {
			continue
		}
		tree = append(tree, e)
		if len(tree) == n-1 {
			break
		}
	}
	return tree
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Solution {
    private int[] parent, rank;
    
    private int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]);
        }
        return parent[x];
    }
    
    private boolean unite(int x, int y) {
        int px = find(x), py = find(y);
        if (px == py) return false;
        
        if (rank[px] < rank[py]) {
            int temp = px; px = py; py = temp;
        }
        parent[py] = px;
        if (rank[px] == rank[py]) rank[px]++;
        return true;
    }
    
    public int[][] maxSpanningTree(int[][] edges, int n) {
        parent = new int[n];
        rank = new int[n];
        for (int i = 0; i < n; i++) parent[i] = i;
        
        Arrays.sort(edges, (a, b) -> Integer.compare(b[2], a[2]));
        
        List<int[]> ans = new ArrayList<>();
        for (int[] edge : edges) {
            if (unite(edge[0], edge[1])) {
                ans.add(edge);
                if (ans.size() == n - 1) break;
            }
        }
        return ans.toArray(new int[ans.size()][]);
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution:
    def maxSpanningTree(self, edges: List[List[int]], n: int) -> List[List[int]]:
        """Return a maximum-weight spanning tree using Kruskal's algorithm.

        Edges are considered heaviest first. An edge is kept when its
        endpoints are still in different Union-Find components, i.e. when it
        does not close a cycle.

        Args:
            edges: ``[u, v, w]`` triples with ``0 <= u, v < n``. The list is
                not modified.
            n: Number of vertices.

        Returns:
            The chosen edges (the same inner lists as in ``edges``) in
            non-increasing weight order. If the graph is disconnected, this is
            a maximum spanning forest with fewer than ``n - 1`` edges.
        """
        parent = list(range(n))
        rank = [0] * n

        def find(x: int) -> int:
            # Path compression: point every node on the path straight at the root.
            if parent[x] != x:
                parent[x] = find(parent[x])
            return parent[x]

        def unite(x: int, y: int) -> bool:
            """Merge the sets of x and y; return False if they were already joined."""
            px, py = find(x), find(y)
            if px == py:
                return False

            # Union by rank keeps trees shallow, which also bounds find()'s recursion.
            if rank[px] < rank[py]:
                px, py = py, px
            parent[py] = px
            if rank[px] == rank[py]:
                rank[px] += 1
            return True

        # sorted() instead of edges.sort() leaves the caller's list untouched.
        # reverse=True keeps the sort stable, so equal weights keep input order.
        by_weight = sorted(edges, key=lambda e: e[2], reverse=True)

        ans = []
        for edge in by_weight:
            if unite(edge[0], edge[1]):
                ans.append(edge)
                # A spanning tree on n vertices has exactly n - 1 edges.
                if len(ans) == n - 1:
                    break
        return ans

Complexity

  • ⏰ Time complexity: O(E log E), where E is the number of edges. Sorting the edges takes O(E log E), and the Union-Find operations take O(E·α(V)) in total, where α is the inverse Ackermann function (effectively constant)
  • 🧺 Space complexity: O(V), for the Union-Find arrays and the resulting tree, where V is the number of vertices (excluding any auxiliary space used by the sort)

Method 2 - Modified Prim’s Algorithm

Intuition

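Similar to Prim’s algorithm for minimum spanning tree, we can build the maximum spanning tree by starting from any vertex and repeatedly adding the heaviest edge that connects a vertex in our current tree to a vertex not yet included. Prim's algorithm emits edges in the order vertices join the tree, and writes each edge as `[tree vertex, new vertex, weight]`. Its output can therefore list the same tree in a different order than Kruskal's. For Example 2, starting from vertex 0, it returns `[[0,3,4],[3,1,5],[1,2,3]]`.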

Approach

  1. Start with an arbitrary vertex and mark it as visited
  2. Use a max-heap (priority queue) to store edges from visited vertices to unvisited vertices
  3. Repeatedly extract the heaviest edge that connects to an unvisited vertex
  4. Add this edge to the result and mark the new vertex as visited
  5. Add all edges from the newly visited vertex to the heap
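  6. Continue until all vertices are visited; if the heap empties first, the remaining vertices are unreachable from the start vertex, i.e. the graph is disconnected and no single spanning tree exists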

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
public:
    // Maximum spanning tree via Prim's algorithm with a max-heap: grow a tree
    // from a root by repeatedly taking the heaviest edge that leaves it.
    //   edges: {u, v, w} triples with 0 <= u, v < n.
    //   n:     number of vertices.
    // Each returned edge is {u, v, w}, where u was already in the tree and v
    // is the vertex it added. Growth restarts from every still-unvisited
    // vertex, so a disconnected graph yields a maximum spanning forest
    // (consistent with Kruskal) and n == 0 yields an empty result instead of
    // indexing visited[0].
    vector<vector<int>> maxSpanningTree(vector<vector<int>>& edges, int n) {
        // Adjacency list: graph[u] holds (neighbor, weight) pairs.
        vector<vector<pair<int, int>>> graph(n);
        for (auto& edge : edges) {
            graph[edge[0]].push_back({edge[1], edge[2]});
            graph[edge[1]].push_back({edge[0], edge[2]});
        }

        vector<bool> visited(n, false);
        vector<vector<int>> ans;

        for (int root = 0; root < n && static_cast<int>(ans.size()) < n - 1; root++) {
            if (visited[root]) continue;
            visited[root] = true;

            // Entries are {weight, from, to}. std::priority_queue is a max-heap and
            // vector<int> compares lexicographically, so the heaviest edge is on top.
            priority_queue<vector<int>> pq;
            for (auto& [neighbor, weight] : graph[root]) {
                pq.push({weight, root, neighbor});
            }

            while (!pq.empty() && static_cast<int>(ans.size()) < n - 1) {
                auto curr = pq.top();
                pq.pop();

                int weight = curr[0], u = curr[1], v = curr[2];
                // Lazy deletion: v may have joined the tree after this entry was pushed.
                if (visited[v]) continue;

                visited[v] = true;
                ans.push_back({u, v, weight});

                for (auto& [neighbor, w] : graph[v]) {
                    if (!visited[neighbor]) {
                        pq.push({w, v, neighbor});
                    }
                }
            }
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import "container/heap"

// maxSpanningTree builds a maximum-weight spanning tree with Prim's algorithm.
//
// edges holds {u, v, w} triples with 0 <= u, v < n. Each returned edge is
// {u, v, w}, where u was already in the tree and v is the vertex it added.
// Growth restarts from every still-unvisited vertex. A disconnected graph
// therefore yields a maximum spanning forest (consistent with Kruskal), and
// n == 0 yields nil instead of panicking on visited[0]. Relies on the
// MaxHeap type for the priority queue.
func maxSpanningTree(edges [][]int, n int) [][]int {
	// graph[u] lists {neighbor, weight} pairs.
	graph := make([][][2]int, n)
	for _, edge := range edges {
		u, v, w := edge[0], edge[1], edge[2]
		graph[u] = append(graph[u], [2]int{v, w})
		graph[v] = append(graph[v], [2]int{u, w})
	}

	visited := make([]bool, n)
	var ans [][]int

	for root := 0; root < n && len(ans) < n-1; root++ {
		if visited[root] {
			continue
		}
		visited[root] = true

		// Heap entries are {weight, from, to}; MaxHeap pops the largest weight first.
		pq := &MaxHeap{}
		heap.Init(pq)
		for _, neighbor := range graph[root] {
			heap.Push(pq, []int{neighbor[1], root, neighbor[0]})
		}

		for pq.Len() > 0 && len(ans) < n-1 {
			curr := heap.Pop(pq).([]int)
			weight, u, v := curr[0], curr[1], curr[2]

			// Lazy deletion: v may have joined the tree after this entry was pushed.
			if visited[v] {
				continue
			}

			visited[v] = true
			ans = append(ans, []int{u, v, weight})

			for _, neighbor := range graph[v] {
				if !visited[neighbor[0]] {
					heap.Push(pq, []int{neighbor[1], v, neighbor[0]})
				}
			}
		}
	}
	return ans
}

type MaxHeap [][]int

func (h MaxHeap) Len() int           { return len(h) }
func (h MaxHeap) Less(i, j int) bool { return h[i][0] > h[j][0] }
func (h MaxHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *MaxHeap) Push(x interface{}) {
    *h = append(*h, x.([]int))
}

func (h *MaxHeap) Pop() interface{} {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[0 : n-1]
    return x
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
    public int[][] maxSpanningTree(int[][] edges, int n) {
        List<List<int[]>> graph = new ArrayList<>();
        for (int i = 0; i < n; i++) graph.add(new ArrayList<>());
        
        for (int[] edge : edges) {
            graph.get(edge[0]).add(new int[]{edge[1], edge[2]});
            graph.get(edge[1]).add(new int[]{edge[0], edge[2]});
        }
        
        boolean[] visited = new boolean[n];
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> Integer.compare(b[0], a[0]));
        List<int[]> ans = new ArrayList<>();
        
        visited[0] = true;
        for (int[] neighbor : graph.get(0)) {
            pq.offer(new int[]{neighbor[1], 0, neighbor[0]});
        }
        
        while (!pq.isEmpty() && ans.size() < n - 1) {
            int[] curr = pq.poll();
            int weight = curr[0], u = curr[1], v = curr[2];
            
            if (visited[v]) continue;
            
            visited[v] = true;
            ans.add(new int[]{u, v, weight});
            
            for (int[] neighbor : graph.get(v)) {
                if (!visited[neighbor[0]]) {
                    pq.offer(new int[]{neighbor[1], v, neighbor[0]});
                }
            }
        }
        return ans.toArray(new int[ans.size()][]);
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import heapq

class Solution:
    def maxSpanningTree(self, edges: List[List[int]], n: int) -> List[List[int]]:
        """Return a maximum-weight spanning tree using Prim's algorithm.

        A tree is grown from a root by repeatedly taking the heaviest edge that
        leaves it. Growth restarts from every still-unvisited vertex, so a
        disconnected graph yields a maximum spanning forest (consistent with
        Kruskal). ``n == 0`` yields ``[]`` instead of an IndexError on
        ``visited[0]``.

        Args:
            edges: ``[u, v, w]`` triples with ``0 <= u, v < n``.
            n: Number of vertices.

        Returns:
            ``[u, v, w]`` lists, where ``u`` was already in the tree and ``v``
            is the vertex the edge added, in the order vertices were added.
        """
        # graph[u] holds (neighbor, weight) pairs.
        graph = [[] for _ in range(n)]
        for u, v, w in edges:
            graph[u].append((v, w))
            graph[v].append((u, w))

        visited = [False] * n
        ans = []

        for root in range(n):
            if visited[root]:
                continue
            visited[root] = True

            # heapq is a min-heap, so weights are negated to pop the heaviest edge first.
            pq = []
            for neighbor, weight in graph[root]:
                heapq.heappush(pq, (-weight, root, neighbor))

            while pq and len(ans) < n - 1:
                neg_weight, u, v = heapq.heappop(pq)
                # Lazy deletion: v may have joined the tree after this entry was pushed.
                if visited[v]:
                    continue

                visited[v] = True
                ans.append([u, v, -neg_weight])

                for neighbor, w in graph[v]:
                    if not visited[neighbor]:
                        heapq.heappush(pq, (-w, v, neighbor))

            # n - 1 edges means every vertex is connected; nothing is left to grow.
            if len(ans) == n - 1:
                break

        return ans

Complexity

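  • ⏰ Time complexity: O(E log V), where E is the number of edges and V is the number of vertices. Each edge is pushed onto the heap at most once per endpoint, so the heap holds O(E) entries and each operation costs O(log E). For a simple graph E < V², so this is O(log V)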
  • 🧺 Space complexity: O(V + E), for the adjacency list representation and the priority queue