给你一个下标从 0 开始的整数数组 nums 和一个  整数 k 。

你可以对数组执行以下操作 任意次 :

  • 选择两个互不相同的下标 i 和 j ,同时 将 nums[i] 更新为 (nums[i] AND nums[j]) 且将 nums[j] 更新为 (nums[i] OR nums[j]) ,OR 表示按位  运算,AND 表示按位  运算。

你需要从最终的数组里选择 k 个元素,并计算它们的 平方 之和。

请你返回你可以得到的 最大 平方和。

由于答案可能会很大,将答案对 10^9 + 7 取余 后返回。

示例 1:

输入:nums = [2,6,5,8], k = 2
输出:261
解释:我们可以对数组执行以下操作:
- 选择 i = 0 和 j = 3 ,同时将 nums[0] 变为 (2 AND 8) = 0 且 nums[3] 变为 (2 OR 8) = 10 ,结果数组为 nums = [0,6,5,10]- 选择 i = 2 和 j = 3 ,同时将 nums[2] 变为 (5 AND 10) = 0 且 nums[3] 变为 (5 OR 10) = 15 ,结果数组为 nums = [0,6,0,15] 。
从最终数组里选择元素 156 ,平方和为 152 + 62 = 261261 是可以得到的最大结果。

示例 2:

输入:nums = [4,5,4,7], k = 3
输出:90
解释:不需要执行任何操作。
选择元素 754 ,平方和为 72 + 52 + 42 = 9090 是可以得到的最大结果。

提示:

  • 1 <= k <= nums.length <= 10^5
  • 1 <= nums[i] <= 10^9

解法 位运算+贪心+哈希表

一个直接的感受是,如果将一个数和其他更多的数按位与,则结果会越来越小;如果将一个数和其他更多的数按位或,则结果会越来越大。

现在对 n u m s [ i ] nums[i] nums[i] n u m s [ j ] nums[j] nums[j] 同时更新,对于同一个比特位,由于 AND 和 OR 不会改变都为 0 0 0 和都为 1 1 1 的情况,所以操作等价于:把一个数的 0 0 0 和另一个数的同一个比特位上的 1 1 1 交换

假设交换前两个数是 x , y x,y x,y ,且 x > y x > y x>y 。则把小的数上的 1 1 1 给大的数,假设交换后 x x x 增加了 d d d ,则 y y y 也减少了 d d d

  • 交换前: x 2 + y 2 x^2 + y^2 x2+y2
  • 交换后: ( x + d ) 2 + ( y − d ) 2 = x 2 + y 2 + 2 d ( x − y ) + 2 d 2 > x 2 + y 2 (x+d)^2 + (y - d)^2 = x^2 + y^2 + 2d(x - y) + 2d^2 > x^2 +y^2 (x+d)2+(yd)2=x2+y2+2d(xy)+2d2>x2+y2

这说明应该通过交换,让一个数越大越好。相当于。

由于可以操作任意次,那么一定可以「组装」出尽量大的数:做法如下:

  1. 对每个比特位,统计 n u m s nums nums 数组中的元素在这个比特位上有多少个 1 1 1 ,记到一个长至多为 30 30 30 c n t cnt cnt 数组中(因为 1 0 9 < 2 30 10^9 < 2^{30} 109<230
  2. 循环 k k k 次。每次循环,组装一个数(记为 x x x):遍历 c n t cnt cnt ,只要 c n t [ i ] > 0 cnt[i] > 0 cnt[i]>0 就将其减一,同时将 2 i 2^i 2i 加到 x x x 中。这样相当于把 1 1 1 尽量聚集在一个数中。
  3. x 2 x^2 x2 加到答案中。
class Solution {
public:
    int maxSum(vector<int>& nums, int k) {
        int cnt[31] {};
        for (int num : nums)
            for (int i = 0; i < 31; ++i)
                cnt[i] += ((num >> i) & 1);
        long long ans = 0;
        const int MOD = 1e9 + 7;
        while (k--) {
            int x = 0;
            for (int i = 0; i < 31; ++i) {
                if (cnt[i]) {
                    x |= 1 << i;
                    cnt[i]--;
                }
            }
            ans = (ans + (long long)x * x) % MOD;
        }
        return ans;
    }
};

复杂度分析:

  • 时间复杂度: O ( n log ⁡ U ) \mathcal{O}(n\log U) O(nlogU) ,其中 n n n nums \textit{nums} nums 的长度, U = max ⁡ ( nums ) U=\max(\textit{nums}) U=max(nums)
  • 空间复杂度: O ( log ⁡ U ) \mathcal{O}(\log U) O(logU)
10-19 15:53