找出所有稳定的二进制数组II

2024-08-07

题目:

给你 3 个正整数 zeroonelimit 。一个 二进制数组 arr 如果满足以下条件,那么我们称它是 稳定的

  • 0 在 arr 中出现次数 恰好zero
  • 1 在 arr 中出现次数 恰好one
  • arr中每个长度超过limit的子数组都同时包含 0 和 1 。

请你返回 稳定 二进制数组的 数目。由于答案可能很大,将它对 109 + 7 取余 后返回。

示例 1:

输入:zero = 1, one = 1, limit = 2

输出:2

解释:两个稳定的二进制数组为 [1,0][0,1] ,两个数组都有一个 0 和一个 1 ,且没有子数组长度大于 2 。

示例 2:

输入:zero = 1, one = 2, limit = 1

输出:1

解释:唯一稳定的二进制数组是 [1,0,1] 。二进制数组 [1,1,0][0,1,1] 都有长度为 2 且元素全都相同的子数组,所以它们不稳定。

示例 3:

输入:zero = 3, one = 3, limit = 2

输出:14

解释:

所有稳定的二进制数组包括 [0,0,1,0,1,1][0,0,1,1,0,1][0,1,0,0,1,1][0,1,0,1,0,1][0,1,0,1,1,0][0,1,1,0,0,1][0,1,1,0,1,0][1,0,0,1,0,1][1,0,0,1,1,0][1,0,1,0,0,1][1,0,1,0,1,0][1,0,1,1,0,0][1,1,0,0,1,0][1,1,0,1,0,0]

提示:

  • 1 <= zero, one, limit <= 1000

记忆化搜索

思路:

根据稳定数组的前两个条件,可知稳定数组的长度为 zero + one。第三个条件可知,稳定数组不存在长度为 limit + 1 的全 0 或全 1 子数组。

接下来我们分解问题,包含 zero 个 0 和 one 个 1 的稳定数组,末位元素可能为 1,也可能为 0。
	如果末位元素为 1,我们需要知道有多少个包含 zero 个 0 和 one − 1 个 1 的稳定数组,再去掉“由于添加了一个 1 而使得原来的稳定数组变得不稳定”的情况。那么有哪些情况会使得原来稳定的数组变得不稳定呢?即原来的稳定数组的末尾连续 1 的个数正好为 limit 个。在这种情况下,添加一个 1 会使得原来稳定的数组变得不稳定。这种情况出现的次数,即为包含 zero 个 0 和 one − 1 − limit 个 1,且末位元素为 0 的稳定数组的个数。
	如果末位元素为 0,我们需要知道有多少个包含 zero − 1 个 0 和 one 个 1 的稳定数组,再去掉“由于添加了一个 0 而使得原来的稳定数组变得不稳定”的情况。

这样一来,我们就将问题分解为子问题了,可以用动态规划求解。用函数 dp(zero, one, lastBit),来求解包含 zero 个 0 和 one 个 1,并且末位元素为 lastBit 的稳定数组的个数,其中 lastBit 为 0 或 1。根据上面的讨论,可以得到递推公式:
	dp(zero, one, 0) = dp(zero − 1, one, 0) +dp(zero − 1, one, 1) − dp(zero − 1 − limit, one, 1)
	dp(zero, one, 1) = dp(zero, one − 1, 0) + dp(zero, one − 1, 1) − dp(zero,one − 1 − limit, 0)。

另外考虑边界情况。如果 zero 为 0,那么当 lastBit 为 1 或者 one 大于 limit 时,不存在这样的稳定数组,返回 0,否则返回 1。如果 zero 为 1,也有对应的结论。

我们用记忆化搜索的方式来计算结果,记录所有的中间状态,最终返回 dp(zero, one, 0) + dp(zero, one, 1) 取模后的结果。

代码:

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        int mod = 1e9 + 7;
        vector<vector<vector<int>>> memo(zero + 1, vector<vector<int>>(one + 1, vector<int>(2, -1)));

        function<int(int, int, int)> dp = [&](int zero, int one, int lastBit) -> int {
            if (zero == 0) {
                return (lastBit == 0 || one > limit) ? 0 : 1;
            } else if (one == 0) {
                return (lastBit == 1 || zero > limit) ? 0 : 1;
            }

            if (memo[zero][one][lastBit] == -1) {
                int res = 0;
                if (lastBit == 0) {
                    res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1)) % mod;
                    if (zero > limit) {
                        res = (res - dp(zero - limit - 1, one, 1) + mod) % mod;
                    }
                } else {
                    res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % mod;
                    if (one > limit) {
                        res = (res - dp(zero, one - limit - 1, 0) + mod) % mod;
                    }
                }
                memo[zero][one][lastBit] = res % mod;
            }
            return memo[zero][one][lastBit];
        };

        return (dp(zero, one, 0) + dp(zero, one, 1)) % mod;
    }
};