Sum of Distances in Tree Problem
Problem
There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.
You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.
Return an array answer of length n where answer[i] is the sum of the distances between the ith node in the tree and all other nodes.
Examples
Example 1:
[Figure: tree with edges 0-1, 0-2, 2-3, 2-4 and 2-5; node 0 has children 1 and 2, and node 2 has children 3, 4 and 5.]
Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: The tree is shown above.
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.
Hence, answer[0] = 8, and so on.
Example 2:
[Figure: a single node, 0.]
Input: n = 1, edges = []
Output: [0]
Example 3:
[Figure: nodes 0 and 1 joined by one edge.]
Input: n = 2, edges = [[1,0]]
Output: [1,1]
Solution
Method 1 - Pre-order and Post-order Traversal
Solve for root
Let's start with Example 1 and suppose we only need the sum of distances for the root, node 0.
[Figure: the Example 1 tree with the root, node 0, highlighted.]
Then we can see the following:
sumOfDistance(0) = (sumOfDistance(1) + number of nodes in subtree of 1) + (sumOfDistance(2) + number of nodes in subtree of 2)
sumOfDistance(0) = (0 + 1) + (3 + 4) = 8
Here sumOfDistance(v) is the sum of distances from v to the nodes in its own subtree. sumOfDistance(1) is 0 because node 1 has no children.
[Figure: the subtree rooted at 1, a single node.]
sumOfDistance(2) is 3 because node 2 has three children, all leaves, each at distance 1.
[Figure: the subtree rooted at 2, with children 3, 4 and 5.]
We add the number of nodes in each child's subtree because every node in that subtree is exactly 1 unit farther from the parent than from the child.
For example, the subtree rooted at 2 has 4 nodes (2, 3, 4 and 5), so seen from 0 it contributes sumOfDistance(2) + 4 = 3 + 4 = 7.
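The same reasoning can be written as a recurrence. In this sketch, S(v) stands for sumOfDistance(v) and cnt(v) for the number of nodes in v's subtree; this notation is introduced here only for illustration, and the last two lines plug in the Example 1 values.

```latex
\begin{aligned}
\mathrm{cnt}(v) &= 1 + \sum_{c \in \mathrm{children}(v)} \mathrm{cnt}(c) \\
S(v) &= \sum_{c \in \mathrm{children}(v)} \bigl( S(c) + \mathrm{cnt}(c) \bigr) \\
S(2) &= (0 + 1) + (0 + 1) + (0 + 1) = 3 \\
S(0) &= \bigl( S(1) + \mathrm{cnt}(1) \bigr) + \bigl( S(2) + \mathrm{cnt}(2) \bigr) = (0 + 1) + (3 + 4) = 8
\end{aligned}
```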
We need two arrays, cnt and ans: cnt[v] stores the number of nodes in the subtree of v, and ans[v] stores sumOfDistance(v) as discussed above. The dfs function in the code below fills both arrays in a single post-order traversal from root 0.
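A minimal sketch of this first pass, assuming the tree is rooted at 0 and stored as an adjacency list. The class name RootPass and the helper adjacency are hypothetical names used only for this illustration. It mirrors the dfs function in the final code; for Example 1 it should leave cnt = [6, 1, 4, 1, 1, 1] and ans = [8, 0, 3, 0, 0, 0]. Only ans[0] is a full answer at this point, since the other entries count distances inside their own subtree only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class RootPass {
    // Hypothetical helper: builds an undirected adjacency list from the edge array.
    static List<List<Integer>> adjacency(int n, int[][] edges) {
        List<List<Integer>> g = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            g.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            g.get(e[0]).add(e[1]);
            g.get(e[1]).add(e[0]);
        }
        return g;
    }

    // Post-order pass: afterwards cnt[v] is the size of v's subtree and
    // ans[v] is the sum of distances from v to the nodes in its subtree.
    static void dfs(List<List<Integer>> g, int cur, int parent, int[] cnt, int[] ans) {
        for (int child : g.get(cur)) {
            if (child == parent) {
                continue;
            }
            dfs(g, child, cur, cnt, ans);
            // Each node in child's subtree is one edge farther from cur than from child.
            ans[cur] += ans[child] + cnt[child];
            cnt[cur] += cnt[child];
        }
        cnt[cur]++; // count cur itself
    }

    public static void main(String[] args) {
        int n = 6;
        int[][] edges = {{0, 1}, {0, 2}, {2, 3}, {2, 4}, {2, 5}};
        int[] cnt = new int[n];
        int[] ans = new int[n];
        dfs(adjacency(n, edges), 0, -1, cnt, ans);
        System.out.println(Arrays.toString(cnt)); // [6, 1, 4, 1, 1, 1]
        System.out.println(Arrays.toString(ans)); // [8, 0, 3, 0, 0, 0]
    }
}
```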
Solve for every node
Running the dfs function once with every node as the root would solve the problem, but each run takes O(N) time, so the total is O(N^2). We can bring this down to O(N) with the re-rooting technique: compute cnt and ans once for root 0, then derive each child's answer from its parent's answer.
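For comparison, here is a sketch of the O(N^2) approach: root the tree at every node in turn and run a subtree pass like dfs from there. The class BruteForce and its methods are hypothetical. Returning {size, sum} pairs instead of filling shared cnt/ans arrays is a choice made here so that each root starts from a clean state. It is mainly useful as a reference to test the O(N) version against.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class BruteForce {
    // Returns {size of cur's subtree, sum of distances from cur to that subtree},
    // treating parent as the node cur was reached from.
    static int[] dfs(List<List<Integer>> g, int cur, int parent) {
        int size = 1;
        int sum = 0;
        for (int child : g.get(cur)) {
            if (child == parent) {
                continue;
            }
            int[] sub = dfs(g, child, cur);
            size += sub[0];
            // Every node of the child's subtree is one edge farther from cur.
            sum += sub[1] + sub[0];
        }
        return new int[] {size, sum};
    }

    // O(n) per root, O(n^2) overall: root the whole tree at every node.
    static int[] sumOfDistances(int n, int[][] edges) {
        List<List<Integer>> g = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            g.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            g.get(e[0]).add(e[1]);
            g.get(e[1]).add(e[0]);
        }
        int[] answer = new int[n];
        for (int root = 0; root < n; root++) {
            answer[root] = dfs(g, root, -1)[1];
        }
        return answer;
    }

    public static void main(String[] args) {
        int[][] edges = {{0, 1}, {0, 2}, {2, 3}, {2, 4}, {2, 5}};
        System.out.println(Arrays.toString(sumOfDistances(6, edges))); // [8, 12, 6, 10, 10, 10]
    }
}
```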
Suppose we take 2 as the root instead of 0. What changes?
[Figure: the same tree re-rooted at 2: node 2 has children 0, 3, 4 and 5, and node 0 has child 1.]
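The cnt[2] nodes in the subtree of 2 (nodes 2, 3, 4 and 5) each get 1 unit closer to the new root, and the remaining n - cnt[2] nodes (0 and 1) each get 1 unit farther from it. Therefore:
ans[2] = ans[0] - cnt[2] + (n - cnt[2])
i.e., for any newRoot that is a child of oldRoot in the tree rooted at 0 (so cnt[newRoot] is still the size of newRoot's subtree):
ans[newRoot] = ans[oldRoot] - cnt[newRoot] + (n - cnt[newRoot]) = ans[oldRoot] + n - 2 * cnt[newRoot];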
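Checking the formula on Example 1, using the subtree sizes from rooting at 0 (cnt = [6, 1, 4, 1, 1, 1], n = 6) and ans[0] = 8. Moving the root from 2 to 4 or to 5 gives the same value as moving it to 3, which reproduces the expected output [8,12,6,10,10,10]:

```latex
\begin{aligned}
\mathrm{ans}[2] &= \mathrm{ans}[0] - \mathrm{cnt}[2] + (n - \mathrm{cnt}[2]) = 8 - 4 + (6 - 4) = 6 \\
\mathrm{ans}[1] &= \mathrm{ans}[0] - \mathrm{cnt}[1] + (n - \mathrm{cnt}[1]) = 8 - 1 + (6 - 1) = 12 \\
\mathrm{ans}[3] &= \mathrm{ans}[2] - \mathrm{cnt}[3] + (n - \mathrm{cnt}[3]) = 6 - 1 + (6 - 1) = 10
\end{aligned}
```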
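The dfs2 function applies this update in pre-order, so ans[cur] is already final when it is used to compute ans[child] for each child.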
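dfs and dfs2 recurse once per tree level, so a very deep tree (for example, a long path) may exhaust the default Java call stack. A possible iterative variant of the same two passes is sketched here; the class name IterativeRerooting is hypothetical, and it assumes n >= 1 as in the problem. It records an order in which every parent precedes its children, walks that order backwards to accumulate cnt and ans (the job of dfs), then walks it forwards applying the re-rooting formula (the job of dfs2).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class IterativeRerooting {
    static int[] sumOfDistancesInTree(int n, int[][] edges) {
        List<List<Integer>> g = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            g.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            g.get(e[0]).add(e[1]);
            g.get(e[1]).add(e[0]);
        }

        // Explicit-stack traversal from root 0: a node is appended to order
        // only after its parent, and each node is pushed at most once.
        int[] parent = new int[n];
        int[] order = new int[n];
        int[] stack = new int[n];
        boolean[] seen = new boolean[n];
        Arrays.fill(parent, -1);
        int top = 0;
        int idx = 0;
        stack[top++] = 0;
        seen[0] = true;
        while (top > 0) {
            int cur = stack[--top];
            order[idx++] = cur;
            for (int next : g.get(cur)) {
                if (!seen[next]) {
                    seen[next] = true;
                    parent[next] = cur;
                    stack[top++] = next;
                }
            }
        }

        int[] cnt = new int[n];
        int[] ans = new int[n];
        // First pass (like dfs): children are finished before their parent.
        for (int i = n - 1; i >= 0; i--) {
            int v = order[i];
            cnt[v] += 1; // count v itself
            if (parent[v] != -1) {
                cnt[parent[v]] += cnt[v];
                ans[parent[v]] += ans[v] + cnt[v];
            }
        }
        // Second pass (like dfs2): a parent's answer is final before its children use it.
        for (int i = 1; i < n; i++) {
            int v = order[i];
            ans[v] = ans[parent[v]] - cnt[v] + (n - cnt[v]);
        }
        return ans;
    }

    public static void main(String[] args) {
        int[][] edges = {{0, 1}, {0, 2}, {2, 3}, {2, 4}, {2, 5}};
        System.out.println(Arrays.toString(sumOfDistancesInTree(6, edges))); // [8, 12, 6, 10, 10, 10]
    }
}
```

The arrays here play the same roles as cnt and ans in the recursive version; only the traversal mechanism changes.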
Code
Java
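import java.util.ArrayList;
import java.util.List;

class Solution {
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        List<Integer>[] tree = buildGraph(n, edges);
        int[] cnt = new int[n]; // cnt[v]: number of nodes in v's subtree (rooted at 0)
        int[] ans = new int[n]; // after dfs: distances within v's subtree; after dfs2: final answer
        dfs(tree, 0, -1, cnt, ans);  // post-order: fill cnt, and ans[0] becomes final
        dfs2(tree, 0, -1, cnt, ans); // pre-order: move the root from each node to its children
        return ans;
    }

    @SuppressWarnings("unchecked")
    private List<Integer>[] buildGraph(int n, int[][] edges) {
        List<Integer>[] tree = new List[n];
        for (int i = 0; i < n; i++) {
            tree[i] = new ArrayList<>();
        }
        // The tree is undirected, so store each edge in both directions.
        for (int[] edge : edges) {
            tree[edge[0]].add(edge[1]);
            tree[edge[1]].add(edge[0]);
        }
        return tree;
    }

    // Post-order: every child is processed before cur.
    private void dfs(List<Integer>[] tree, int cur, int pre, int[] cnt, int[] ans) {
        for (int child : tree[cur]) {
            if (child != pre) {
                dfs(tree, child, cur, cnt, ans);
                // Every node in child's subtree is one edge farther from cur.
                ans[cur] += ans[child] + cnt[child];
                cnt[cur] += cnt[child];
            }
        }
        cnt[cur]++; // count cur itself
    }

    // Pre-order: ans[cur] is final before its children are updated.
    private void dfs2(List<Integer>[] tree, int cur, int pre, int[] cnt, int[] ans) {
        for (int child : tree[cur]) {
            if (child != pre) {
                // cnt[child] nodes get 1 closer, n - cnt[child] nodes get 1 farther (n == cnt.length).
                ans[child] = ans[cur] - cnt[child] + cnt.length - cnt[child];
                dfs2(tree, child, cur, cnt, ans);
            }
        }
    }
}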
Complexity
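- Time: O(N), since each of the two DFS passes visits every node and edge once.
- Space: O(N) for the adjacency list, the cnt and ans arrays, and the recursion stack.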