题目:
给你 3 个正整数 zero
,one
和 limit
。一个 二进制数组 arr
如果满足以下条件,那么我们称它是 稳定的 :
- 0 在
arr
中出现次数 恰好 为zero
。 - 1 在
arr
中出现次数 恰好 为one
。 arr
中每个长度超过limit
的子数组都同时包含 0 和 1 。
请你返回 稳定 二进制数组的 总 数目。由于答案可能很大,将它对 109 + 7
取余 后返回。
示例 1:
输入:zero = 1, one = 1, limit = 2
输出:2
解释:两个稳定的二进制数组为 [1,0]
和 [0,1]
,两个数组都有一个 0 和一个 1 ,且没有子数组长度大于 2 。
示例 2:
输入:zero = 1, one = 2, limit = 1
输出:1
解释:唯一稳定的二进制数组是 [1,0,1]
。二进制数组 [1,1,0]
和 [0,1,1]
都有长度为 2 且元素全都相同的子数组,所以它们不稳定。
示例 3:
输入:zero = 3, one = 3, limit = 2
输出:14
解释:
所有稳定的二进制数组包括 [0,0,1,0,1,1]
,[0,0,1,1,0,1]
,[0,1,0,0,1,1]
,[0,1,0,1,0,1]
,[0,1,0,1,1,0]
,[0,1,1,0,0,1]
,[0,1,1,0,1,0]
,[1,0,0,1,0,1]
,[1,0,0,1,1,0]
,[1,0,1,0,0,1]
,[1,0,1,0,1,0]
,[1,0,1,1,0,0]
,[1,1,0,0,1,0]
和 [1,1,0,1,0,0]
。
提示:
1 <= zero, one, limit <= 1000
记忆化搜索
思路:
根据稳定数组的前两个条件,可知稳定数组的长度为 zero + one。第三个条件可知,稳定数组不存在长度为 limit + 1 的全 0 或全 1 子数组。
接下来我们分解问题,包含 zero 个 0 和 one 个 1 的稳定数组,末位元素可能为 1,也可能为 0。
如果末位元素为 1,我们需要知道有多少个包含 zero 个 0 和 one − 1 个 1 的稳定数组,再去掉“由于添加了一个 1 而使得原来的稳定数组变得不稳定”的情况。那么有哪些情况会使得原来稳定的数组变得不稳定呢?即原来的稳定数组的末尾连续 1 的个数正好为 limit 个。在这种情况下,添加一个 1 会使得原来稳定的数组变得不稳定。这种情况出现的次数,即为包含 zero 个 0 和 one − 1 − limit 个 1,且末位元素为 0 的稳定数组的个数。
如果末位元素为 0,我们需要知道有多少个包含 zero − 1 个 0 和 one 个 1 的稳定数组,再去掉“由于添加了一个 0 而使得原来的稳定数组变得不稳定”的情况。
这样一来,我们就将问题分解为子问题了,可以用动态规划求解。用函数 dp(zero, one, lastBit),来求解包含 zero 个 0 和 one 个 1,并且末位元素为 lastBit 的稳定数组的个数,其中 lastBit 为 0 或 1。根据上面的讨论,可以得到递推公式:
dp(zero, one, 0) = dp(zero − 1, one, 0) +dp(zero − 1, one, 1) − dp(zero − 1 − limit, one, 1)
dp(zero, one, 1) = dp(zero, one − 1, 0) + dp(zero, one − 1, 1) − dp(zero,one − 1 − limit, 0)。
另外考虑边界情况。如果 zero 为 0,那么当 lastBit 为 1 或者 one 大于 limit 时,不存在这样的稳定数组,返回 0,否则返回 1。如果 zero 为 1,也有对应的结论。
我们用记忆化搜索的方式来计算结果,记录所有的中间状态,最终返回 dp(zero, one, 0) + dp(zero, one, 1) 取模后的结果。
代码:
class Solution {
public:
int numberOfStableArrays(int zero, int one, int limit) {
int mod = 1e9 + 7;
vector<vector<vector<int>>> memo(zero + 1, vector<vector<int>>(one + 1, vector<int>(2, -1)));
function<int(int, int, int)> dp = [&](int zero, int one, int lastBit) -> int {
if (zero == 0) {
return (lastBit == 0 || one > limit) ? 0 : 1;
} else if (one == 0) {
return (lastBit == 1 || zero > limit) ? 0 : 1;
}
if (memo[zero][one][lastBit] == -1) {
int res = 0;
if (lastBit == 0) {
res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1)) % mod;
if (zero > limit) {
res = (res - dp(zero - limit - 1, one, 1) + mod) % mod;
}
} else {
res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % mod;
if (one > limit) {
res = (res - dp(zero, one - limit - 1, 0) + mod) % mod;
}
}
memo[zero][one][lastBit] = res % mod;
}
return memo[zero][one][lastBit];
};
return (dp(zero, one, 0) + dp(zero, one, 1)) % mod;
}
};