Problem

Given a binary string s, you can split s into 3 non-empty strings s1, s2, and s3 where s1 + s2 + s3 = s.

Return the number of ways s can be split such that the number of ones is the same in s1, s2, and s3. Since the answer may be too large, return it modulo 10^9 + 7.

Examples

Example 1

Input: s = "10101"
Output: 4
Explanation: There are four ways to split s into 3 parts where each part contains the same number of '1's.
"1|010|1"
"1|01|01"
"10|10|1"
"10|1|01"

Example 2

Input: s = "1001"
Output: 0

Example 3

Input: s = "0000"
Output: 3
Explanation: There are three ways to split s into 3 parts.
"0|0|00"
"0|00|0"
"00|0|0"

Constraints

  • 3 <= s.length <= 10^5
  • s[i] is either '0' or '1'.

Solution

Method 1 – Prefix Counting and Combinatorics

Intuition

If the total number of ‘1’s in the string is not divisible by 3, it’s impossible to split as required. If there are no ‘1’s, any two split points will work. Otherwise, we count the number of ways to choose the split points so that each part contains exactly one-third of the total ‘1’s.

Approach

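  1. Count the total number of ‘1’s in the string.
  2. If the total is not divisible by 3, return 0.
  3. If the total is 0, the answer is the number of ways to choose 2 split points among the n-1 gaps between characters: C(n-1, 2) = (n-1)(n-2)/2, taken modulo 10^9 + 7.
  4. Otherwise, let k = total / 3:
    • The first cut can go anywhere from just after the k-th ‘1’ to just before the (k+1)-th ‘1’. That gives (number of zeros between those two ‘1’s) + 1 choices, which equals the number of positions whose prefix contains exactly k ones.
    • Likewise, the second cut has (number of zeros between the 2k-th and (2k+1)-th ‘1’) + 1 choices, which equals the number of positions whose prefix contains exactly 2k ones.
    • The answer is the product of these two counts, modulo 10^9 + 7.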

Code

C++
class Solution {
public:
    // Returns how many ways s can be cut into three non-empty pieces that
    // contain equal numbers of '1', reduced modulo 1e9+7.
    int numWays(string s) {
        const long long MOD = 1000000007LL;
        const long long len = static_cast<long long>(s.size());

        int totalOnes = 0;
        for (string::size_type i = 0; i < s.size(); ++i) {
            if (s[i] == '1') {
                ++totalOnes;
            }
        }

        // All zeros: any two of the len-1 gaps can serve as cut points, C(len-1, 2).
        if (totalOnes == 0) {
            return static_cast<int>((len - 1) * (len - 2) / 2 % MOD);
        }
        // The ones cannot be shared equally among three parts.
        if (totalOnes % 3 != 0) {
            return 0;
        }

        const int perPart = totalOnes / 3;
        long long firstCuts = 0;   // prefixes holding exactly perPart ones
        long long secondCuts = 0;  // prefixes holding exactly 2 * perPart ones
        int seen = 0;
        for (char ch : s) {
            if (ch == '1') {
                ++seen;
            }
            if (seen == perPart) {
                ++firstCuts;
            } else if (seen == 2 * perPart) {
                ++secondCuts;
            }
        }
        return static_cast<int>(firstCuts * secondCuts % MOD);
    }
};
Go
// numWays returns the number of ways to split the binary string s into three
// non-empty parts that each contain the same number of '1' characters,
// modulo 1e9+7.
//
// Arithmetic is done in int64 because (n-1)*(n-2) and cnt1*cnt2 can reach
// roughly 1e10 for len(s) up to 1e5, which would overflow Go's int on
// 32-bit platforms.
func numWays(s string) int {
    const mod int64 = 1_000_000_007
    n := int64(len(s))
    // Fewer than three characters cannot form three non-empty parts.
    if n < 3 {
        return 0
    }
    var ones int64
    for _, c := range s {
        if c == '1' {
            ones++
        }
    }
    if ones == 0 {
        // Any two of the n-1 gaps between characters work: C(n-1, 2).
        return int((n - 1) * (n - 2) / 2 % mod)
    }
    if ones%3 != 0 {
        return 0
    }
    k := ones / 3
    // cnt1 / cnt2 count positions whose prefix holds exactly k / 2k ones,
    // i.e. the valid places to make the first / second cut.
    var cnt1, cnt2, t int64
    for _, c := range s {
        if c == '1' {
            t++
        }
        if t == k {
            cnt1++
        } else if t == 2*k {
            cnt2++
        }
    }
    return int(cnt1 * cnt2 % mod)
}
Java
class Solution {
    /**
     * Counts the ways to split {@code s} into three non-empty parts that each
     * contain the same number of '1' characters.
     *
     * <p>Characters are read with {@code charAt} rather than {@code toCharArray}
     * so no copy of the string is made and auxiliary space stays O(1).
     *
     * @param s binary string made of '0' and '1'
     * @return the number of valid splits modulo 1_000_000_007
     */
    public int numWays(String s) {
        final int mod = 1_000_000_007;
        int n = s.length();
        // Fewer than three characters cannot form three non-empty parts.
        if (n < 3) return 0;
        int ones = 0;
        for (int i = 0; i < n; i++) {
            if (s.charAt(i) == '1') ones++;
        }
        // No ones: any two of the n-1 gaps work, C(n-1, 2); long avoids int overflow.
        if (ones == 0) return (int) ((long) (n - 1) * (n - 2) / 2 % mod);
        if (ones % 3 != 0) return 0;
        int k = ones / 3;
        // cnt1 / cnt2: prefixes holding exactly k / 2k ones (first / second cut positions).
        long cnt1 = 0, cnt2 = 0;
        int t = 0;
        for (int i = 0; i < n; i++) {
            if (s.charAt(i) == '1') t++;
            if (t == k) cnt1++;
            else if (t == 2 * k) cnt2++;
        }
        return (int) (cnt1 * cnt2 % mod);
    }
}
Kotlin
class Solution {
    /**
     * Returns the number of ways to cut [s] into three non-empty pieces holding
     * equal counts of '1', modulo 1_000_000_007.
     */
    fun numWays(s: String): Int {
        val mod = 1_000_000_007L
        val len = s.length.toLong()
        val total = s.count { ch -> ch == '1' }
        when {
            // All zeros: choose any two of the len - 1 gaps, C(len - 1, 2).
            total == 0 -> return ((len - 1) * (len - 2) / 2 % mod).toInt()
            // Ones cannot be divided evenly into three parts.
            total % 3 != 0 -> return 0
        }
        val share = total / 3
        var seen = 0
        var first = 0L   // prefixes containing exactly `share` ones
        var second = 0L  // prefixes containing exactly 2 * `share` ones
        s.forEach { ch ->
            if (ch == '1') seen++
            when (seen) {
                share -> first++
                2 * share -> second++
            }
        }
        return (first * second % mod).toInt()
    }
}
Python
class Solution:
    def numWays(self, s: str) -> int:
        """Count splits of ``s`` into three non-empty parts with equal numbers of '1'.

        Returns the count modulo 10**9 + 7.
        """
        MOD = 10**9 + 7
        total = s.count('1')
        if not total:
            # No ones: choose any 2 of the len(s) - 1 gaps, i.e. C(len(s) - 1, 2).
            size = len(s)
            return (size - 1) * (size - 2) // 2 % MOD
        share, rem = divmod(total, 3)
        if rem:
            return 0
        # first / second: prefixes holding exactly share / 2*share ones.
        first = second = 0
        seen = 0
        for ch in s:
            if ch == '1':
                seen += 1
            if seen == share:
                first += 1
            elif seen == 2 * share:
                second += 1
        return first * second % MOD
Rust
impl Solution {
    pub fn num_ways(s: String) -> i32 {
        let n = s.len();
        let modv = 1_000_000_007;
        let ones = s.chars().filter(|&c| c == '1').count();
        if ones == 0 {
            return ((n-1)*(n-2)/2 % modv) as i32;
        }
        if ones % 3 != 0 {
            return 0;
        }
        let k = ones / 3;
        let (mut cnt1, mut cnt2, mut t) = (0, 0, 0);
        for c in s.chars() {
            if c == '1' { t += 1; }
            if t == k { cnt1 += 1; }
            else if t == 2*k { cnt2 += 1; }
        }
        ((cnt1 as i64 * cnt2 as i64) % modv) as i32
    }
}
TypeScript
class Solution {
    /**
     * Counts the ways to split `s` into three non-empty parts that each contain
     * the same number of '1' characters.
     *
     * Ones are counted by iterating the string directly instead of spreading it
     * into an array, so auxiliary space stays O(1).
     *
     * @param s - binary string made of '0' and '1'
     * @returns the number of valid splits modulo 1_000_000_007
     */
    numWays(s: string): number {
        const mod = 1_000_000_007
        const n = s.length
        // Fewer than three characters cannot form three non-empty parts.
        if (n < 3) return 0
        let ones = 0
        for (const c of s) {
            if (c === '1') ones++
        }
        // No ones: any two of the n - 1 gaps work, C(n - 1, 2). (n - 1)(n - 2) is
        // always even and about 1e10 at most for n <= 1e5, so it is exact as a double.
        if (ones === 0) return ((n - 1) * (n - 2) / 2) % mod
        if (ones % 3 !== 0) return 0
        const k = ones / 3
        // cnt1 / cnt2: prefixes holding exactly k / 2k ones (first / second cut positions).
        let cnt1 = 0
        let cnt2 = 0
        let t = 0
        for (const c of s) {
            if (c === '1') t++
            if (t === k) cnt1++
            else if (t === 2 * k) cnt2++
        }
        // cnt1 * cnt2 <= (n / 2)^2, far below Number.MAX_SAFE_INTEGER, so the product is exact.
        return (cnt1 * cnt2) % mod
    }
}

Complexity

  • ⏰ Time complexity: O(n), where n is the length of s. We scan the string a constant number of times.
  • 🧺 Space complexity: O(1), only counters are used.