收集所有金币可获得的最大积分

2025-01-23

题目:

有一棵由 n 个节点组成的无向树,以 0 为根节点,节点编号从 0n - 1 。给你一个长度为 n - 1 的二维 整数 数组 edges ,其中 edges[i] = [ai, bi] 表示在树上的节点 aibi 之间存在一条边。另给你一个下标从 0 开始、长度为 n 的数组 coins 和一个整数 k ,其中 coins[i] 表示节点 i 处的金币数量。

从根节点开始,你必须收集所有金币。要想收集节点上的金币,必须先收集该节点的祖先节点上的金币。

节点 i 上的金币可以用下述方法之一进行收集:

  • 收集所有金币,得到共计 coins[i] - k 点积分。如果 coins[i] - k 是负数,你将会失去 abs(coins[i] - k) 点积分。
  • 收集所有金币,得到共计 floor(coins[i] / 2) 点积分。如果采用这种方法,节点 i 子树中所有节点 j 的金币数 coins[j] 将会减少至 floor(coins[j] / 2)

返回收集 所有 树节点的金币之后可以获得的最大积分。

示例 1:

img

输入:edges = [[0,1],[1,2],[2,3]], coins = [10,10,3,3], k = 5
输出:11                        
解释:
使用第一种方法收集节点 0 上的所有金币。总积分 = 10 - 5 = 5 。
使用第一种方法收集节点 1 上的所有金币。总积分 = 5 + (10 - 5) = 10 。
使用第二种方法收集节点 2 上的所有金币。所以节点 3 上的金币将会变为 floor(3 / 2) = 1 ,总积分 = 10 + floor(3 / 2) = 11 。
使用第二种方法收集节点 3 上的所有金币。总积分 =  11 + floor(1 / 2) = 11.
可以证明收集所有节点上的金币能获得的最大积分是 11 。 

示例 2:

img

输入:edges = [[0,1],[0,2]], coins = [8,4,4], k = 0
输出:16
解释:
使用第一种方法收集所有节点上的金币,因此,总积分 = (8 - 0) + (4 - 0) + (4 - 0) = 16 。

提示:

  • n == coins.length
  • 2 <= n <= 105
  • 0 <= coins[i] <= 104
  • edges.length == n - 1
  • 0 <= edges[i][0], edges[i][1] < n
  • 0 <= k <= 104

思路:

floor(coins[i] / 2) 等价于 coins[i] >> 1。
右移运算是可以叠加的,即 (x >> 1) >> 1 等于 x >> 2。
我们可以在递归的过程中,额外记录从根节点递归到当前节点的过程中,一共执行了多少次右移,也就是子树中的每个节点值需要右移的次数。
故定义 dfs(i,j) 表示递归到以 i 为根的子树,在上面已经执行了 j 次右移的前提下,我们在这棵子树中最多可以得到多少积分。

用「选或不选」来思考,即是否执行右移:
	不右移:答案为 (coins[i] >> j)−k 加上 i 的每个子树 ch 的 dfs(ch,j)。
	右移:答案为 coins[i] >> (j+1) 加上 i 的每个子树 ch 的 dfs(ch,j+1)。

两种情况取最大值,得
dfs(i,j) = max 
(coins[i] >> j)−k+∑ch dfs(ch,j) 
(coins[i] >> (j+1))+∑ch dfs(ch,j+1)
递归入口:dfs(0,0)。其中 i=0 表示根节点。一开始没有执行右移,所以 j=0。

细节
一个数最多右移多少次,就变成 0 了?
设 w 是 coins[i] 的二进制长度,那么 coins[i] 右移 w 次后就是 0 了。
在本题的数据范围下,w≤14。
所以如果在递归过程中发现 j+1=14,就不执行右移,因为此时 dfs(ch,j+1) 子树中的每个节点值都要右移 14 次,算出的结果一定是 0。既然都知道递归的结果了,那就不需要递归了。

此外,为避免错把父亲当作儿子,可以额外传入 fa 表示父节点,遍历 i 的邻居时,跳过邻居节点是 fa 的情况。

代码:

class Solution {
public:
    int maximumPoints(vector<vector<int>>& edges, vector<int>& coins, int k) {
        int n = coins.size();
        vector<vector<int>> g(n);
        for (auto& e: edges) {
            int x = e[0], y = e[1];
            g[x].push_back(y);
            g[y].push_back(x);
        }

        array<int, 14> init_val;
        ranges::fill(init_val, -1); // -1 表示没有计算过
        vector memo(n, init_val);
        auto dfs = [&](this auto&& dfs, int i, int j, int fa) {
            int& res = memo[i][j]; // 注意这里是引用
            if (res != -1) { // 之前计算过
                return res;
            }
            int res1 = (coins[i] >> j) - k;
            int res2 = coins[i] >> (j + 1);
            for (int ch : g[i]) {
                if (ch == fa) continue;
                res1 += dfs(ch, j, i); // 不右移
                if (j < 13) { // j+1 >= 14 相当于 res2 += 0,无需递归
                    res2 += dfs(ch, j + 1, i); // 右移
                }
            }
            return res = max(res1, res2); // 记忆化
        };
        return dfs(0, 0, -1);
    }
};