As some readers of this blog might’ve guessed by now, we like applying the FFT in unexpected places when solving classic LeetCode challenges; see [1], [2].
Today we’re going to see how the Partition Equal Subset Sum problem, typically solved using dynamic programming, can actually be approached using polynomial multiplication and FFT-style tricks. The problem goes like this:
Given a list of positive integers nums, can we split it into two subsets with equal sums?
In other words: is there a subset of nums whose sum is exactly half of the total?
First, the standard DP approach
Usually, for this kind of problem, one might try a bitmasking approach to track which subset sums are possible. But in this particular challenge the constraints make that infeasible: the number of elements n can be up to 200, and each element is in the range [1, 100]. That’s far too large for bitmasking over all combinations; it would blow up in both time and memory. The classical trick here is a DP over all possible sums bounded by the target, with intuition similar to what one does in the knapsack problem.
We start by tracking all subset sums achievable using the first k numbers. Then, for each new number, we update the set of reachable sums by adding it to all the sums we’ve already seen. This leads to the following code; as always in this series, we utilise Magma, since it provides fast polynomial multiplication out of the box:
function CanPartitionDP(nums)
    total := &+nums;
    if total mod 2 ne 0 then
        return false;
    end if;
    target := total div 2;
    // reachable[s + 1] is true iff some subset sums to s (Magma sequences are 1-based)
    reachable := [false : i in [0..target]];
    reachable[1] := true;  // the empty subset sums to 0
    for n in nums do
        // iterate downwards so each number is used at most once
        for j in [target..n by -1] do
            if reachable[j - n + 1] then
                reachable[j + 1] := true;
            end if;
        end for;
    end for;
    return reachable[target + 1];
end function;
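For readers who don’t have Magma at hand, here is a rough Python equivalent of the same DP (a sketch with a plain boolean list in place of Magma’s 1-based sequences):

```python
def can_partition_dp(nums):
    """Knapsack-style DP: reachable[s] is True iff some subset of nums sums to s."""
    total = sum(nums)
    if total % 2 != 0:
        return False
    target = total // 2
    reachable = [False] * (target + 1)
    reachable[0] = True  # the empty subset sums to 0
    for n in nums:
        # iterate downwards so each number is used at most once
        for j in range(target, n - 1, -1):
            if reachable[j - n]:
                reachable[j] = True
    return reachable[target]
```

For example, can_partition_dp([1, 5, 11, 5]) is True (split as {1, 5, 5} and {11}), while can_partition_dp([1, 2, 5]) is False.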
A rough estimate of the running time of this DP approach is n * target, where target can go up to n * max(nums) / 2. If we denote max(nums) by m, the complexity ends up being around O(n² * m). That’s not bad, but it raises a natural question: can we do it differently?
Fast polynomial multiplication trick: the algorithm
Let’s assign to each number aᵢ in nums a simple polynomial:
fᵢ(x) = 1 + x^aᵢ.
Now, consider the product of all these polynomials. The coefficient bₖ in front of x^k in the final product tells us how many ways we can form the sum k using elements from nums.
So, to solve our problem, we just need to check whether the coefficient of x^target is non-zero. That means there’s some subset of the numbers that adds up to exactly target. The main task now is to multiply all those polynomials together efficiently.
And here comes the classic trick again: we group and multiply the polynomials in stages. First, we multiply n/2 pairs of polynomials, then n/4 pairs of the resulting products, and so on. In total, we perform about log(n) rounds. In the i-th round, we do n/2^i polynomial multiplications, each involving polynomials of degree roughly m * 2^(i-1).
Using FFT, multiplying two polynomials of degree k takes around O(k log k) operations. So if we add it all up across the rounds, the total time complexity is roughly:
(n/2) * m * log(m) + (n/4) * (2m * log(2m)) + (n/8) * (4m * log(4m)) + ...
We can factor this into:
(nm / 2) * Σᵢ log(2^(i-1) * m)
= (nm / 2) * [log(n) * log(m) + (0 + 1 + ... + (log(n) - 1))]
So overall, we get an approximate time complexity of:
(nm / 2) * [log(n) * log(m) + log(n) * (log(n) - 1) / 2],
which is O(n * m * log(n) * log(n * m)).
This appears to be asymptotically much faster than the original DP approach.
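To get a feel for the gap, we can plug concrete values of n and m into both estimates (a back-of-the-envelope Python calculation; constants are of course ignored, so only the ratio is meaningful):

```python
import math

def dp_cost(n, m):
    # DP estimate: n * target, with target ≈ n * m / 2
    return n * (n * m / 2)

def fft_cost(n, m):
    # (nm/2) * [log(n)*log(m) + log(n)*(log(n)-1)/2]
    L = math.log2(n)
    return (n * m / 2) * (L * math.log2(m) + L * (L - 1) / 2)

for n, m in [(200, 100), (500, 500)]:
    print(n, m, round(dp_cost(n, m) / fft_cost(n, m), 1))
```

The ratio is already above 1 at the problem’s original constraints (n = 200, m = 100) and grows as n and m increase.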
Fast polynomial multiplication trick: an implementation
Here is a short implementation of the algorithm from the section above:
function CanPartition(nums)
    total := &+nums;
    if total mod 2 ne 0 then
        return false;
    end if;
    target := total div 2;
    R<x> := PolynomialRing(Integers());
    polys := [1 + x^n : n in nums];
    // multiply the polynomials pairwise until a single product remains
    while #polys gt 1 do
        new_polys := [];
        for i in [1..#polys by 2] do
            if i + 1 le #polys then
                prod := polys[i] * polys[i+1];
            else
                prod := polys[i];  // odd one out, carried to the next round
            end if;
            Append(~new_polys, prod);
        end for;
        polys := new_polys;
    end while;
    final := polys[1];
    return Coefficient(final, target) gt 0;
end function;
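The same pairwise product tree can be sketched in Python. Note that this is only an illustration using exact integer convolution; Magma’s polynomial arithmetic uses asymptotically fast multiplication under the hood, which is where the speedup actually comes from:

```python
def polymul(p, q):
    """Exact integer convolution of two coefficient lists (index = exponent)."""
    out = [0] * (len(p) + len(q) - 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            out[i + j] += a * b
    return out

def can_partition(nums):
    total = sum(nums)
    if total % 2 != 0:
        return False
    target = total // 2
    # one polynomial 1 + x^a per number, stored as a coefficient list
    polys = [[1] + [0] * (a - 1) + [1] for a in nums]
    # multiply pairwise until a single product remains
    while len(polys) > 1:
        nxt = [polymul(polys[i], polys[i + 1]) for i in range(0, len(polys) - 1, 2)]
        if len(polys) % 2 == 1:
            nxt.append(polys[-1])  # odd one out, carried to the next round
        polys = nxt
    return polys[0][target] > 0
```

Since we only care whether the target coefficient is non-zero, one could also reduce the coefficients modulo a large prime during the multiplications to keep the integers small; in Magma we simply let the exact integer arithmetic handle it.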
Comparing the results
Finally, let’s do a runtime comparison. To do this, we generate a random set of nums with total sum divisible by 2, run both functions, and compare the results and the performance.
repeat
nums := [Random(100, 500) : i in [1..30]];
until &+nums mod 2 eq 0;
t0 := Realtime();
print CanPartition(nums);
t1 := Realtime();
print "FFT-like Time:", t1 - t0;
t2 := Realtime();
print CanPartitionDP(nums);
t3 := Realtime();
print "DP Time:", t3 - t2;
And the results on this small input are:
true
FFT-like Time: 0.010
true
DP Time: 0.030
We already see that the DP is about 3 times slower, and the difference becomes even more pronounced as we increase n and m. For example, for n = 500 and m = 500 we get:
true
FFT-like Time: 1.050
true
DP Time: 10.920
I’m not sure why, but I always feel a sense of satisfaction when a bit of slightly advanced mathematical thinking can make a real impact.