There is an undirected tree with n nodes labeled from 0 to n - 1.
You are given a 0-indexed integer array nums of length n, where nums[i] represents the value of the i-th node. You are also given a 2D integer array edges of length n - 1, where edges[i] = [a_i, b_i] indicates that there is an edge between nodes a_i and b_i in the tree.
You are allowed to delete some edges, splitting the tree into multiple connected components. Let the value of a component be the sum of all nums[i] for which node i is in the component.
Return the maximum number of edges you can delete such that every connected component in the tree has the same value.

Example:
Input: nums = [6,2,2,2,6], edges = [[0,1],[1,2],[1,3],[3,4]]
Output: 2
Explanation: Deleting the edges [0,1] and [3,4] creates the components [0], [1,2,3] and [4]. The sum of the values in each component equals 6. It can be proven that no better deletion exists, so the answer is 2.
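The key idea is to split the tree into components of equal sum by removing edges. The common component value must be a divisor of the total sum. For each such divisor (target), a DFS checks whether the tree can be partitioned into components of that sum by counting how many subtrees reach the target. Assuming all node values are positive (as in the original problem's constraints), it is safe to cut a subtree off greedily as soon as its sum equals the target. The partition therefore succeeds exactly when the number of such subtrees equals total / target, and the answer is that count minus one.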
class Solution {
public:
    // Returns the maximum number of edges that can be deleted from the tree so
    // that every resulting connected component has the same value (the sum of
    // nums[i] over the component's nodes).
    //
    // nums  - value of each node; assumed positive, which the greedy cut in
    //         canSplit relies on.
    // edges - the n - 1 undirected tree edges, each given as [a, b].
    //
    // The common component value must divide total = sum(nums). Candidate
    // component counts are tried from n downwards. The first count that can be
    // realised yields the answer (count - 1), so no further candidates are needed.
    int componentValue(std::vector<int>& nums, std::vector<std::vector<int>>& edges) {
        const int n = static_cast<int>(nums.size());
        if (n == 0) return 0;

        std::vector<std::vector<int>> g(n);
        for (const auto& e : edges) {
            g[e[0]].push_back(e[1]);
            g[e[1]].push_back(e[0]);
        }

        // BFS from node 0. Every node is appended to `order` after its parent,
        // so scanning `order` backwards visits children before parents. The
        // ordering is computed once and shared by every candidate, and it
        // avoids deep recursion on path-like trees.
        std::vector<int> parent(n, -1);
        std::vector<int> order;
        order.reserve(n);
        order.push_back(0);
        for (int i = 0; i < static_cast<int>(order.size()); ++i) {
            const int u = order[i];
            for (int v : g[u]) {
                if (v != parent[u]) {  // skip the edge back to u's parent
                    parent[v] = u;
                    order.push_back(v);
                }
            }
        }

        const int total = std::accumulate(nums.begin(), nums.end(), 0);
        for (int parts = n; parts >= 2; --parts) {
            if (total % parts == 0 && canSplit(nums, order, parent, total / parts)) {
                return parts - 1;
            }
        }
        return 0;  // no split possible: keep the whole tree as one component
    }

private:
    // Returns true when the tree can be cut into components whose sums all
    // equal `target`. Subtree sums are accumulated bottom-up via `order`; a
    // subtree reaching `target` is cut off and contributes nothing upward.
    static bool canSplit(const std::vector<int>& nums, const std::vector<int>& order,
                         const std::vector<int>& parent, int target) {
        std::vector<int> sub(nums);  // running sums of the not-yet-cut parts
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int u = *it;
            const int s = sub[u];
            if (s > target) return false;  // positive values: it can only grow
            if (s == target) continue;     // cut the edge above u
            if (parent[u] < 0) return false;  // root piece is short of target
            sub[parent[u]] += s;
        }
        return true;
    }
};
classSolution {
publicintcomponentValue(int[] nums, int[][] edges) {
int n = nums.length;
List<Integer>[] g =new List[n];
for (int i = 0; i < n; i++) g[i]=new ArrayList<>();
for (int[] e : edges) {
g[e[0]].add(e[1]);
g[e[1]].add(e[0]);
}
int total = 0;
for (int v : nums) total += v;
int ans = 0;
for (int k = 1; k * k <= total; k++) {
if (total % k == 0) {
for (int target : newint[]{k, total / k}) {
if (target < total) {
int[] cnt =newint[1];
dfs(0, -1, g, nums, target, cnt);
if (cnt[0]== total / target && ans < cnt[0]- 1) ans = cnt[0]- 1;
}
}
}
}
return ans;
}
privateintdfs(int u, int p, List<Integer>[] g, int[] nums, int target, int[] cnt) {
int sum = nums[u];
for (int v : g[u]) if (v != p) sum += dfs(v, u, g, nums, target, cnt);
if (sum == target) {
cnt[0]++;
return 0;
}
return sum;
}
}
class Solution {
    /**
     * Returns the maximum number of edges that can be deleted from the tree so that
     * every resulting connected component has the same value (sum of its node values).
     *
     * The common component value must divide `total = nums.sum()`. Candidate component
     * counts are tried from `n` downwards. The first count that can be realised yields
     * the answer (`count - 1`).
     *
     * @param nums value of each node; assumed positive (the greedy cut relies on it)
     * @param edges the `n - 1` undirected tree edges, each `[a, b]`
     * @return the maximum number of deletable edges, or 0 if no split works
     */
    fun componentValue(nums: IntArray, edges: Array<IntArray>): Int {
        val n = nums.size
        if (n == 0) return 0

        val g = Array(n) { mutableListOf<Int>() }
        for (e in edges) {
            g[e[0]].add(e[1])
            g[e[1]].add(e[0])
        }

        // BFS from node 0: each node lands in `order` after its parent, so a
        // reverse scan visits children before parents. Computed once, and it
        // avoids deep recursion on path-like trees.
        val parent = IntArray(n) { -1 }
        val order = IntArray(n) // order[0] == 0 is the root
        var size = 1
        var head = 0
        while (head < size) {
            val u = order[head++]
            for (v in g[u]) {
                if (v != parent[u]) { // skip the edge back to u's parent
                    parent[v] = u
                    order[size++] = v
                }
            }
        }

        val total = nums.sum()
        for (parts in n downTo 2) {
            if (total % parts == 0 && canSplit(nums, order, parent, total / parts)) {
                return parts - 1
            }
        }
        return 0
    }

    /**
     * Returns true when the tree can be cut into components whose sums all equal
     * [target]. A subtree whose running sum reaches [target] is cut off and
     * contributes nothing to its parent.
     */
    private fun canSplit(nums: IntArray, order: IntArray, parent: IntArray, target: Int): Boolean {
        val sub = nums.copyOf() // running sums of the not-yet-cut parts
        for (idx in order.size - 1 downTo 0) {
            val u = order[idx]
            val s = sub[u]
            if (s > target) return false // positive values: it can only grow
            if (s == target) continue    // cut the edge above u
            val p = parent[u]
            if (p < 0) return false      // root piece is short of target
            sub[p] += s
        }
        return true
    }
}
class Solution:
    def componentValue(self, nums: list[int], edges: list[list[int]]) -> int:
        """Return the maximum number of edges that can be deleted so that every
        resulting connected component has the same value (sum of its node values).

        The common component value must divide ``total = sum(nums)``. Candidate
        component counts are tried from ``n`` downwards. The first count that can
        be realised yields the answer (``count - 1``).

        The traversal is iterative. A recursive DFS would raise RecursionError
        on path-like trees deeper than Python's default recursion limit.

        Args:
            nums: value of each node; assumed positive (the greedy cut relies on it).
            edges: the ``n - 1`` undirected tree edges, each ``[a, b]``.

        Returns:
            The maximum number of deletable edges, or 0 if no split works.
        """
        n = len(nums)
        if n == 0:
            return 0

        g = [[] for _ in range(n)]
        for a, b in edges:
            g[a].append(b)
            g[b].append(a)

        # BFS from node 0: each node lands in `order` after its parent, so a
        # reverse scan visits children before parents.
        parent = [-1] * n
        order = [0]
        head = 0
        while head < len(order):
            u = order[head]
            head += 1
            for v in g[u]:
                if v != parent[u]:  # skip the edge back to u's parent
                    parent[v] = u
                    order.append(v)

        def can_split(target: int) -> bool:
            """True if the tree can be cut into components that each sum to target."""
            sub = nums[:]  # running sums of the not-yet-cut parts
            for u in reversed(order):
                s = sub[u]
                if s > target:
                    return False  # positive values: this piece can only grow
                if s == target:
                    continue  # cut the edge above u
                p = parent[u]
                if p < 0:
                    return False  # the root's leftover piece is short of target
                sub[p] += s
            return True

        total = sum(nums)
        for parts in range(n, 1, -1):
            if total % parts == 0 and can_split(total // parts):
                return parts - 1
        return 0
impl Solution {
pubfncomponent_value(nums: Vec<i32>, edges: Vec<Vec<i32>>) -> i32 {
let n = nums.len();
letmut g =vec![vec![]; n];
for e in&edges {
g[e[0] asusize].push(e[1] asusize);
g[e[1] asusize].push(e[0] asusize);
}
let total: i32= nums.iter().sum();
letmut ans =0;
letmut k =1;
while k * k <= total {
if total % k ==0 {
for&target in&[k, total / k] {
if target < total {
letmut cnt =0;
fndfs(u: usize, p: i32, g: &Vec<Vec<usize>>, nums: &Vec<i32>, target: i32, cnt: &muti32) -> i32 {
letmut sum = nums[u];
for&v in&g[u] {
if v asi32!= p {
sum += dfs(v, u asi32, g, nums, target, cnt);
}
}
if sum == target {
*cnt +=1;
return0;
}
sum
}
dfs(0, -1, &g, &nums, target, &mut cnt);
if cnt == total / target && ans < cnt -1 {
ans = cnt -1;
}
}
}
}
k +=1;
}
ans
}
}