找出所有稳定的二进制数组I

2024-08-06

题目:

给你 3 个正整数 zeroonelimit 。一个 二进制数组 arr 如果满足以下条件,那么我们称它是 稳定的

  • 0 在 arr 中出现次数 恰好zero
  • 1 在 arr 中出现次数 恰好one
  • arr中每个长度超过limit的子数组都同时包含 0 和 1 。

请你返回 稳定 二进制数组的 数目。由于答案可能很大,将它对 109 + 7 取余 后返回。

示例 1:

输入:zero = 1, one = 1, limit = 2

输出:2

解释:两个稳定的二进制数组为 [1,0][0,1] ,两个数组都有一个 0 和一个 1 ,且没有子数组长度大于 2 。

示例 2:

输入:zero = 1, one = 2, limit = 1

输出:1

解释:唯一稳定的二进制数组是 [1,0,1] 。二进制数组 [1,1,0][0,1,1] 都有长度为 2 且元素全都相同的子数组,所以它们不稳定。

示例 3:

输入:zero = 3, one = 3, limit = 2

输出:14

解释:

所有稳定的二进制数组包括 [0,0,1,0,1,1][0,0,1,1,0,1][0,1,0,0,1,1][0,1,0,1,0,1][0,1,0,1,1,0][0,1,1,0,0,1][0,1,1,0,1,0][1,0,0,1,0,1][1,0,0,1,1,0][1,0,1,0,0,1][1,0,1,0,1,0][1,0,1,1,0,0][1,1,0,0,1,0][1,1,0,1,0,0]

提示:

  • 1 <= zero, one, limit <= 200

暴力回溯所有可能arr

思路:

通过回溯算法,求解所有满足zero 和 one 的数组 arr
然后利用前缀和判断其中的每个数组是否满足 limit 条件
只要长度为 limit + 1 的子数组(利用滑动窗口)和在 [1, limit] 中就是满足条件的

代码:

class Solution {
public:
    vector<int> path;
    int ans = 0;
    int limit;
    int mod = 1e9 + 7;
    void backTrack(int zero, int one) {
        if (zero == 0 && one == 0) {
            int n = path.size();
            vector<int> preSum(n, 0);
            int cur = 0;
            for (int i = 0; i < n; i++) {
                cur += path[i];
                preSum[i] = cur;
                if (i >= limit) {
                    int judge = preSum[i] - preSum[i - limit] + path[i - limit];
                    // 出现全 0 或者 全 1 的情况,直接返回,这里就是剪枝优化
                    if (judge == 0 || judge == limit + 1) {
                        return;
                    }
                }
            }
            ans = (ans + 1) % mod;
            return;
        }
        // 下一位填 0
        if (zero) {
            zero--;
            path.push_back(0);
            backTrack(zero, one);
            path.pop_back();
            zero++;
        }
        // 下一位填 1
        if (one) {
            one--;
            path.push_back(1);
            backTrack(zero, one);
            path.pop_back();
            one++;
        }
    }
    int numberOfStableArrays(int zero, int one, int limit) {
        // 解法1:通过回溯算法,求解所有满足zero 和 one 的数组 arr
        //       然后利用前缀和判断其中的每个数组是否满足 limit 条件
        //       只要长度为 limit + 1 的子数组(利用滑动窗口)和在 [1, limit] 中就是满足条件的
        this->limit = limit;
        backTrack(zero, one);
        return ans;
    }
};

三维动态规划

思路:

这个转移方程过于难想了,感兴趣自己看吧,真想不到
    
题目要求二进制数组 arr 中每个长度超过 limit 的子数组同时包含 0 和 1,这个条件等价于二进制数组 arr 中每个长度等于 limit+1 的子数组都同时包含 0 和 1(读者可以思考一下证明过程)。

按照题目要求,我们需要将 zero 个 0 和 one 个 1 依次填入二进制数组 arr,使用 dp0[i][j] 表示已经填入 i 个 0 和 j 个 1,并且最后填入的数字为 0 的可行方案数目,dp1[i][j] 表示已经填入 i 个 0 和 j 个 1,并且最后填入的数字为 1 的可行方案数目。对于 dp0[i][j],我们分析一下它的转换方程:

	当 j = 0 且 i ∈ [0, min(zero,limit)] 时:我们可以不断地填入 0,所以 dp0[i][j]=1。
	当 i = 0,或者 j = 0 且 i ∈/ [0, min(zero, limit)] 时:我们没法构造可行的方案,所以 dp0[i][j] = 0。
 	当 i > 0 且 j > 0 时:dp0[i][j] 可以分别由 dp0[i−1][j] 和 dp1[i−1][j] 转移而来,分别考虑两种情况:
		对于 dp1[i−1][j]:显然可以通过在 dp1[i−1][j] 对应的所有填入方案后再填入一个 0 得到对应的可行方案。
		对于 dp0[i−1][j]:当 i ≤ limit 时,显然可以通过在 dp1[i−1][j] 对应的所有填入方案后再填入一个 0 得到对应的可行方案;当 i > limit 时,我们需要去除一些不可行的方案数。因为 dp0[i−1][j] 对应的所有填入方案都是可行的,而只有一种情况会在额外填入一个 0 时,变成不可行,即先前已经连续填入 limit 个 0,对应的方案数为 dp1[i−limit−1][j]。

根据以上分析,我们有 dp0[i][j] 的转移方程:
	dp0[i][j] = 1, i ∈ [0, min(zero, limit)], j = 0
			 = dp1[i - 1][j] + dp0[i - 1][j] - dp1[i - limit - 1][j], i > limit, j > 0
              = dp1[i - 1][j] + dp0[i - 1][j], i ∈ [0, limit], j > 0
              = 0, otherwise
              
同理,我们也可以获得 dp1[i][j] 的转移方程: 
	dp1[i][j] = 1, j ∈ [0, min(one, limit)], i = 0
			 = dp0[i][j - 1] + dp1[i][j - 1] - dp0[i][j - limit - 1], j > limit, i > 0
              = dp0[i][j - 1] + dp1[i][j - 1], j ∈ [0, limit], i > 0
              = 0, otherwise
              
最后,稳定二进制数组的数目等于 dp0[zero][one]+dp1[zero][one]。

代码:

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        vector<vector<vector<long long>>> dp(zero + 1, vector<vector<long long>>(one + 1, vector<long long>(2)));
        long long mod = 1e9 + 7;
        for (int i = 0; i <= min(zero, limit); i++) {
            dp[i][0][0] = 1;
        }
        for (int j = 0; j <= min(one, limit); j++) {
            dp[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                if (i > limit) {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % mod + mod) % mod;
                if (j > limit) {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % mod + mod) % mod;
            }
        }
        return (dp[zero][one][0] + dp[zero][one][1]) % mod;
    }
};