Input: N = 5, K = 6
Output: 0
Explanation: No partition of 5 can have all parts ≥ 6, so count = 0 and aggregate = 0.
Observation: every valid partition sums to N, so each one contributes exactly N to the aggregate. Summing over all valid partitions multiplies N by their count, so the aggregate is N * P(N, K), where P(N, K) is the number of partitions of N with all parts ≥ K.
We therefore only need to count partitions of N whose smallest part is at least K. Define F(n, m) = number of partitions of n with all parts ≥ m. Each partition is generated exactly once, in non-decreasing order, by always choosing the next part ≥ the previous one. Then P(N, K) = F(N, K) and the answer is N * F(N, K).
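To make the observation concrete, the sketch below enumerates the partitions explicitly and sums every part, so the result can be compared with N times the count. It is a brute-force check for small inputs only; the helper names `partitions_min_part` and `aggregate_brute_force` are illustrative, and the inputs N = 10, K = 3 are an assumed example, not one from the article.

```python
from typing import Iterator, List


def partitions_min_part(n: int, m: int) -> Iterator[List[int]]:
    """Yield every partition of n whose parts are all >= m, in non-decreasing order."""
    if n == 0:
        yield []  # the empty remainder completes a partition
        return
    for x in range(m, n + 1):
        # x is the smallest remaining part; later parts must be >= x.
        for rest in partitions_min_part(n - x, x):
            yield [x] + rest


def aggregate_brute_force(N: int, K: int) -> int:
    """Sum all parts of all valid partitions by explicit enumeration."""
    return sum(sum(p) for p in partitions_min_part(N, K))


parts = list(partitions_min_part(10, 3))
print(parts)                                      # [[3, 3, 4], [3, 7], [4, 6], [5, 5], [10]]
print(aggregate_brute_force(10, 3), 10 * len(parts))  # 50 50
```

Enumeration is exponential in general, so this is only useful as a reference to test the faster methods against.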
Two useful recurrences:
Summation form (direct branching on first part x):
$$
F(n, m) =
\begin{cases}
0 & \text{if } n < m \\
1 & \text{if } m \le n < 2m \quad \text{(only partition is } \{n\}\text{)} \\
1 + \displaystyle\sum_{x=m}^{n-m} F(n - x,\, x) & \text{otherwise}
\end{cases}
$$
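As a check on the summation form, take N = 10 and K = 3 (an illustrative input, not one of the article's examples). The valid partitions are {10}, {3, 7}, {4, 6}, {5, 5} and {3, 3, 4}, and expanding one level gives:

$$
\begin{aligned}
F(10, 3) &= 1 + \sum_{x=3}^{7} F(10 - x,\, x) \\
&= 1 + F(7, 3) + F(6, 4) + F(5, 5) + F(4, 6) + F(3, 7) \\
&= 1 + 2 + 1 + 1 + 0 + 0 = 5
\end{aligned}
$$

Here F(7, 3) = 2 counts {7} and {3, 4}; F(6, 4) and F(5, 5) each count only the single-part partition; the last two terms are 0 because n < m. The aggregate is therefore 10 * 5 = 50.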
Optimized two-way recurrence (include-or-skip m):
$$
F(n, m) =
\begin{cases}
0 & \text{if } n < m \\
1 & \text{if } m \le n < 2m \\
F(n, m + 1) + F(n - m, m) & \text{otherwise}
\end{cases}
$$
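Explanation: for n ≥ 2m, a partition either uses at least one m (remove one copy of m; the remaining n - m must still be split into parts ≥ m, giving F(n - m, m)) or uses no m at all (all parts ≥ m + 1, giving F(n, m + 1)). Each state now does O(1) work instead of a loop, so with memoization or tabulation over the O(N^2) states (n, m) the total time is O(N^2).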
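One way to gain confidence that the two recurrences count the same thing is to evaluate both on every small (n, m) pair. The sketch below is a test harness written for that purpose; the names `f_sum` and `f_two_way` are illustrative, and `functools.lru_cache` is used only to keep the check fast.

```python
from functools import lru_cache

# Both functions use the base cases above: 0 if n < m, 1 if m <= n < 2m.


@lru_cache(maxsize=None)
def f_sum(n: int, m: int) -> int:
    """Summation form: branch on the smallest part x."""
    if n < m:
        return 0
    if n < 2 * m:
        return 1
    return 1 + sum(f_sum(n - x, x) for x in range(m, n - m + 1))


@lru_cache(maxsize=None)
def f_two_way(n: int, m: int) -> int:
    """Include-or-skip form: use at least one m, or use no m at all."""
    if n < m:
        return 0
    if n < 2 * m:
        return 1
    return f_two_way(n, m + 1) + f_two_way(n - m, m)


# Compare the two recurrences on every small (n, m) pair.
mismatches = [(n, m) for n in range(1, 61) for m in range(1, 62)
              if f_sum(n, m) != f_two_way(n, m)]
print(mismatches)        # []
print(f_two_way(10, 3))  # 5
```

With m = 1 the function counts unrestricted partitions, so f_two_way(10, 1) should equal the partition number p(10) = 42, which is another easy sanity check.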
Naive recursion: branch on the smallest part x (x ≥ the current minimum m) and recursively count partitions of the remaining n - x with new minimum x. The base cases stop the recursion, but without memoization the same (n, m) pairs are recomputed many times.
```
function F(n, m):
    if n < m: return 0
    if n < 2*m: return 1        # only {n}
    total = 1                   # the partition {n}
    for x in [m .. n - m]:      # x = smallest part
        total += F(n - x, x)
    return total
```
```cpp
// Number of partitions of n whose parts are all >= m (summation form).
long long countPartitionsNaive(long long n, long long m) {
    if (n < m) return 0;
    if (n < 2 * m) return 1;   // only {n}
    long long total = 1;       // the partition {n}
    for (long long x = m; x <= n - m; ++x) {
        total += countPartitionsNaive(n - x, x);  // x is the smallest part
    }
    return total;
}

long long aggregatePartitionSumNaive(int N, int K) {
    return (long long)N * countPartitionsNaive(N, K);
}
```
```java
class PartitionNaive {
    long count(long n, long m) {
        if (n < m) return 0;
        if (n < 2 * m) return 1;   // only {n}
        long total = 1;            // the partition {n}
        for (long x = m; x <= n - m; x++) {
            total += count(n - x, x);
        }
        return total;
    }

    long aggregatePartitionSum(int N, int K) {
        return (long) N * count(N, K);
    }
}
```
```python
def count_partitions_naive(n: int, m: int) -> int:
    if n < m:
        return 0
    if n < 2 * m:
        return 1  # only {n}
    total = 1  # the partition {n}
    for x in range(m, n - m + 1):
        total += count_partitions_naive(n - x, x)
    return total


def aggregate_partition_sum_naive(N: int, K: int) -> int:
    return N * count_partitions_naive(N, K)
```
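Use the include-or-skip recurrence F(n, m) = F(n, m + 1) + F(n - m, m) with memoization so that each (n, m) state is computed only once. There are O(N^2) states and each takes O(1) work, so time and memory are both O(N^2). By contrast, the naive recursion makes at least as many calls as there are partitions, a count that can grow super-polynomially in N.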
```java
class PartitionMemo {
    long[][] memo;

    // F(n, m) via include-or-skip: skip m entirely, or use one copy of m.
    long dfs(int n, int m) {
        if (n < m) return 0L;
        if (n < 2 * m) return 1L;
        if (memo[n][m] != -1L) return memo[n][m];
        memo[n][m] = dfs(n, m + 1) + dfs(n - m, m);
        return memo[n][m];
    }

    long aggregatePartitionSum(int N, int K) {
        memo = new long[N + 1][N + 2];
        for (int i = 0; i <= N; i++) java.util.Arrays.fill(memo[i], -1L);
        return (long) N * dfs(N, K);
    }
}
```
```python
from functools import lru_cache


def aggregate_partition_sum_memo(N: int, K: int) -> int:
    @lru_cache(maxsize=None)
    def f(n: int, m: int) -> int:
        if n < m:
            return 0
        if n < 2 * m:
            return 1
        return f(n, m + 1) + f(n - m, m)  # skip m, or use one copy of m

    return N * f(N, K)
```
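To see how much recomputation the memo removes, the sketch below counts recursive calls in the naive summation form and distinct states cached by the memoized include-or-skip form. The helper names are hypothetical and exist only for this measurement. For tiny inputs the two numbers are similar (for N = 10, K = 3 the naive version makes 8 calls while the memo caches 9 states), but the gap widens quickly as N grows.

```python
from functools import lru_cache


def count_calls_naive(N: int, K: int) -> int:
    """Count recursive calls made by the un-memoized summation form."""
    calls = 0

    def f(n: int, m: int) -> int:
        nonlocal calls
        calls += 1
        if n < m:
            return 0
        if n < 2 * m:
            return 1
        return 1 + sum(f(n - x, x) for x in range(m, n - m + 1))

    f(N, K)
    return calls


def count_states_memo(N: int, K: int) -> int:
    """Count distinct (n, m) states evaluated by the memoized include-or-skip form."""
    @lru_cache(maxsize=None)
    def f(n: int, m: int) -> int:
        if n < m:
            return 0
        if n < 2 * m:
            return 1
        return f(n, m + 1) + f(n - m, m)

    f(N, K)
    return f.cache_info().currsize  # one cache entry per distinct state


for N in (10, 20, 30):
    print(N, count_calls_naive(N, 1), count_states_memo(N, 1))
```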
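Bottom-up tabulation fills the same table iteratively with dp[n][m] = F(n, m): m runs from N down to 1 because dp[n][m] reads dp[n][m + 1], and n runs upward because dp[n][m] reads dp[n - m][m]. A guard for K > N returns 0 directly, since F(N, K) = 0 in that case and dp[N][K] would otherwise be out of bounds for K > N + 1.

```java
class PartitionBottomUp {
    long aggregatePartitionSum(int N, int K) {
        if (K > N) return 0;  // no valid partition; also avoids indexing past dp's columns
        // dp[n][m] = F(n, m); column N + 1 stays 0 as a sentinel for dp[n][m + 1]
        long[][] dp = new long[N + 1][N + 2];
        for (int m = N; m >= 1; m--) {
            for (int n = 0; n <= N; n++) {
                if (n < m) dp[n][m] = 0;
                else if (n < 2 * m) dp[n][m] = 1;
                else dp[n][m] = dp[n][m + 1] + dp[n - m][m];
            }
        }
        return (long) N * dp[N][K];
    }
}
```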
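```python
def aggregate_partition_sum_bottom_up(N: int, K: int) -> int:
    if K > N:
        return 0  # no valid partition; also avoids an IndexError on dp[N][K]
    # dp[n][m] = F(n, m); column N + 1 stays 0 as a sentinel for dp[n][m + 1]
    dp = [[0] * (N + 2) for _ in range(N + 1)]
    for m in range(N, 0, -1):
        for n in range(N + 1):
            if n < m:
                dp[n][m] = 0
            elif n < 2 * m:
                dp[n][m] = 1
            else:
                dp[n][m] = dp[n][m + 1] + dp[n - m][m]
    return N * dp[N][K]
```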
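If memory matters, a different formulation (not the article's) reduces the table to one dimension. It treats the allowed part sizes K, K + 1, ..., N as coin denominations and counts multisets with the standard coin-change DP, which takes O(N^2) time and O(N) space. This is a minimal sketch assuming K ≥ 1; the function name is hypothetical.

```python
def aggregate_partition_sum_1d(N: int, K: int) -> int:
    """Hypothetical O(N)-space variant: count partitions of N into parts
    from K..N with the coin-change recurrence, then multiply by N."""
    if K < 1:
        raise ValueError("K must be at least 1")  # assumption: parts are positive
    ways = [0] * (N + 1)
    ways[0] = 1  # one way to build 0: the empty partition
    for part in range(K, N + 1):
        # Adding part sizes in increasing order counts each multiset once.
        for total in range(part, N + 1):
            ways[total] += ways[total - part]
    return N * ways[N]
```

For example, aggregate_partition_sum_1d(10, 3) counts the same five partitions as F(10, 3) and returns 50. For K > N the outer loop never runs and the result is 0, so no separate guard is needed.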