给你一个下标从 0 开始的整数数组 nums
和一个 正 整数 k
。
你可以对数组执行以下操作 任意次 :
- 选择两个互不相同的下标
i
和j
,同时 将nums[i]
更新为(nums[i] AND nums[j])
且将nums[j]
更新为(nums[i] OR nums[j])
,OR
表示按位 或 运算,AND
表示按位 与 运算。
你需要从最终的数组里选择 k
个元素,并计算它们的 平方 之和。
请你返回你可以得到的 最大 平方和。
由于答案可能会很大,将答案对 10^9 + 7
取余 后返回。
示例 1:
输入:nums = [2,6,5,8], k = 2
输出:261
解释:我们可以对数组执行以下操作:
- 选择 i = 0 和 j = 3 ,同时将 nums[0] 变为 (2 AND 8) = 0 且 nums[3] 变为 (2 OR 8) = 10 ,结果数组为 nums = [0,6,5,10] 。
- 选择 i = 2 和 j = 3 ,同时将 nums[2] 变为 (5 AND 10) = 0 且 nums[3] 变为 (5 OR 10) = 15 ,结果数组为 nums = [0,6,0,15] 。
从最终数组里选择元素 15 和 6 ,平方和为 152 + 62 = 261 。
261 是可以得到的最大结果。
示例 2:
输入:nums = [4,5,4,7], k = 3
输出:90
解释:不需要执行任何操作。
选择元素 7 ,5 和 4 ,平方和为 72 + 52 + 42 = 90 。
90 是可以得到的最大结果。
提示:
1 <= k <= nums.length <= 10^5
1 <= nums[i] <= 10^9
解法 位运算+贪心+哈希表
一个直接的感受是,如果将一个数和其他更多的数按位与,则结果会越来越小;如果将一个数和其他更多的数按位或,则结果会越来越大。
现在对 n u m s [ i ] nums[i] nums[i] 与 n u m s [ j ] nums[j] nums[j] 同时更新,对于同一个比特位,由于 AND 和 OR 不会改变都为 0 0 0 和都为 1 1 1 的情况,所以操作等价于:把一个数的 0 0 0 和另一个数的同一个比特位上的 1 1 1 交换。
假设交换前两个数是 x , y x,y x,y ,且 x > y x > y x>y 。则把小的数上的 1 1 1 给大的数,假设交换后 x x x 增加了 d d d ,则 y y y 也减少了 d d d 。
- 交换前: x 2 + y 2 x^2 + y^2 x2+y2
- 交换后: ( x + d ) 2 + ( y − d ) 2 = x 2 + y 2 + 2 d ( x − y ) + 2 d 2 > x 2 + y 2 (x+d)^2 + (y - d)^2 = x^2 + y^2 + 2d(x - y) + 2d^2 > x^2 +y^2 (x+d)2+(y−d)2=x2+y2+2d(x−y)+2d2>x2+y2
这说明应该通过交换,让一个数越大越好。相当于。
由于可以操作任意次,那么一定可以「组装」出尽量大的数:做法如下:
- 对每个比特位,统计 n u m s nums nums 数组中的元素在这个比特位上有多少个 1 1 1 ,记到一个长至多为 30 30 30 的 c n t cnt cnt 数组中(因为 1 0 9 < 2 30 10^9 < 2^{30} 109<230)
- 循环 k k k 次。每次循环,组装一个数(记为 x x x):遍历 c n t cnt cnt ,只要 c n t [ i ] > 0 cnt[i] > 0 cnt[i]>0 就将其减一,同时将 2 i 2^i 2i 加到 x x x 中。这样相当于把 1 1 1 尽量聚集在一个数中。
- 把 x 2 x^2 x2 加到答案中。
class Solution {
public:
int maxSum(vector<int>& nums, int k) {
int cnt[31] {};
for (int num : nums)
for (int i = 0; i < 31; ++i)
cnt[i] += ((num >> i) & 1);
long long ans = 0;
const int MOD = 1e9 + 7;
while (k--) {
int x = 0;
for (int i = 0; i < 31; ++i) {
if (cnt[i]) {
x |= 1 << i;
cnt[i]--;
}
}
ans = (ans + (long long)x * x) % MOD;
}
return ans;
}
};
复杂度分析:
- 时间复杂度: O ( n log U ) \mathcal{O}(n\log U) O(nlogU) ,其中 n n n 为 nums \textit{nums} nums 的长度, U = max ( nums ) U=\max(\textit{nums}) U=max(nums) 。
- 空间复杂度: O ( log U ) \mathcal{O}(\log U) O(logU) 。