题目描述
这是 LeetCode 上的 1711. 大餐计数 ,难度为 中等。
Tag : 「哈希表」、「位运算」
大餐 是指 恰好包含两道不同餐品 的一餐,其美味程度之和等于 2 的幂。
你可以搭配 任意 两道餐品做一顿大餐。
给你一个整数数组 deliciousness ,其中 deliciousness[i] 是第 i 道餐品的美味程度,返回你可以用数组中的餐品做出的不同 大餐 的数量。结果需要对 10^9109 + 7 取余。
注意,只要餐品下标不同,就可以认为是不同的餐品,即便它们的美味程度相同。
示例 1:
输入:deliciousness = [1,3,5,7,9] 输出:4 解释:大餐的美味程度组合为 (1,3) 、(1,7) 、(3,5) 和 (7,9) 。 它们各自的美味程度之和分别为 4 、8 、8 和 16 ,都是 2 的幂。 复制代码
示例 2:
输入:deliciousness = [1,1,1,3,3,3,7] 输出:15 解释:大餐的美味程度组合为 3 种 (1,1) ,9 种 (1,3) ,和 3 种 (1,7) 。 复制代码
提示:
- 1 <= deliciousness.length <= 10^5105
- 0 <= deliciousness[i] <= 2^{20}220
枚举前一个数(TLE)
一个朴素的想法是,从前往后遍历 deliciousnessdeliciousness 中的所有数,当遍历到下标 ii 的时候,回头检查下标小于 ii 的数是否能够与 deliciousness[i]deliciousness[i] 相加形成 22 的幂。
这样的做法是 O(n^2)O(n2) 的,防止同样的数值被重复计算,我们可以使用「哈希表」记录某个数出现了多少次,但这并不改变算法仍然是 O(n^2)O(n2) 的。
而且我们需要一个 check
方法来判断某个数是否为 22 的幂:
- 朴素的做法是对 xx 应用试除法,当然因为精度问题,我们需要使用乘法实现试除;
- 另一个比较优秀的做法是利用位运算找到符合「大于等于 xx」的最近的 22 的幂,然后判断是否与 xx 相同。
两种做法差距有多大呢?方法一的复杂度为 O(\log{n})O(logn),方法二为 O(1)O(1)。
根据数据范围 0 <= deliciousness[i] <= 2^{20}0<=deliciousness[i]<=220,方法一最多也就是执行不超过 2222 次循环。
显然,采用何种判断 22 的幂的做法不是关键,在 OJ 判定上也只是分别卡在 60/7060/70 和 62/7062/70 的 TLE 上。
但通过这样的分析,我们可以发现「枚举前一个数」的做法是与 nn 相关的,而枚举「可能出现的 22 的幂」则是有明确的范围,这引导出我们的解法二。
Java 代码:
class Solution { int mod = (int)1e9+7; public int countPairs(int[] ds) { int n = ds.length; long ans = 0; Map<Integer, Integer> map = new HashMap<>(); for (int i = 0; i < n; i++) { int x = ds[i]; for (int other : map.keySet()) { if (check(other + x)) ans += map.get(other); } map.put(x, map.getOrDefault(x, 0) + 1); } return (int)(ans % mod); } boolean check(long x) { // 方法一 // long cur = 1; // while (cur < x) { // cur = cur * 2; // } // return cur == x; // 方法二 return getVal(x) == x; } long getVal(long x) { long n = x - 1; n |= n >>> 1; n |= n >>> 2; n |= n >>> 4; n |= n >>> 8; n |= n >>> 16; return n < 0 ? 1 : n + 1; } } 复制代码
Python3 代码:
class Solution: mod = 10 ** 9 + 7 def countPairs(self, deliciousness: List[int]) -> int: n = len(deliciousness) ans = 0 hashmap = Counter() for i in range(n): x = deliciousness[i] for other in hashmap: if self.check(other+x): ans += hashmap[other] hashmap[x] += 1 return ans % self.mod def check(self, x): """ # 方法一 cur = 1 while cur < x: cur *= 2 return cur == x """ # 方法二 return self.getVal(x) == x def getVal(self, x): n = x - 1 # java中 >>>:无符号右移。无论是正数还是负数,高位通通补0。 Python不需要 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 return 1 if n < 0 else n + 1 复制代码
- 时间复杂度:O(n^2)O(n2)
- 空间复杂度:O(n)O(n)
枚举 2 的幂(容斥原理)
根据对朴素解法的分析,我们可以先使用「哈希表」对所有在 deliciousnessdeliciousness 出现过的数进行统计。
然后对于每个数 xx,检查所有可能出现的 22 的幂 ii,再从「哈希表」中反查 t = i - xt=i−x 是否存在,并实现计数。
一些细节:如果哈希表中存在 t = i - xt=i−x,并且 t = xt=x,这时候方案数应该是 (cnts[x] - 1) * cnts[x](cnts[x]−1)∗cnts[x];其余一般情况则是 cnts[t] * cnts[x]cnts[t]∗cnts[x]。
同时,这样的计数方式,我们对于二元组 (x, t)(x,t) 会分别计数两次(遍历 xx 和 遍历 tt),因此最后要利用容斥原理,对重复计数的进行减半操作。
Java 代码:
class Solution { int mod = (int)1e9+7; int max = 1 << 22; public int countPairs(int[] ds) { Map<Integer, Integer> map = new HashMap<>(); for (int d : ds) map.put(d, map.getOrDefault(d, 0) + 1); long ans = 0; for (int x : map.keySet()) { for (int i = 1; i < max; i <<= 1) { int t = i - x; if (map.containsKey(t)) { if (t == x) ans += (map.get(x) - 1) * 1L * map.get(x); else ans += map.get(x) * 1L * map.get(t); } } } ans >>= 1; return (int)(ans % mod); } } 复制代码
Python3 代码:
class Solution: mod = 10 ** 9 + 7 maximum = 1 << 22 def countPairs(self, deliciousness: List[int]) -> int: hashmap = Counter(deliciousness) ans = 0 for x in hashmap: i = 1 while i < self.maximum: t = i - x if t in hashmap: if t == x: ans += (hashmap[x] - 1) * hashmap[x] else: ans += hashmap[x] * hashmap[t] i <<= 1 ans >>= 1 return ans % self.mod 复制代码
- 时间复杂度:根据数据范围,令 CC 为 2^{21}221。复杂度为 O(n * \log{C})O(n∗logC)
- 空间复杂度:O(n)O(n)
枚举 2 的幂(边遍历边统计)
当然,我们也可以采取「一边遍历一边统计」的方式,这样取余操作就可以放在遍历逻辑中去做,也就顺便实现了不使用 longlong 来计数(以及不使用 %
实现取余)。
Java 代码:
class Solution { int mod = (int)1e9+7; int max = 1 << 22; public int countPairs(int[] ds) { Map<Integer, Integer> map = new HashMap<>(); int ans = 0; for (int x : ds) { for (int i = 1; i < max; i <<= 1) { int t = i - x; if (map.containsKey(t)) { ans += map.get(t); if (ans >= mod) ans -= mod; } } map.put(x, map.getOrDefault(x, 0) + 1); } return ans; } } 复制代码
Python3 代码:
class Solution: mod = 10 ** 9 + 7 maximum = 1 << 22 def countPairs(self, deliciousness: List[int]) -> int: hashmap = defaultdict(int) ans = 0 for x in deliciousness: i = 1 while i < self.maximum: t = i - x if t in hashmap: ans += hashmap[t] if ans >= self.mod: ans -= self.mod i <<= 1 hashmap[x] += 1 return ans 复制代码
- 时间复杂度:根据数据范围,令 CC 为 2^{21}221。复杂度为 O(n * \log{C})O(n∗logC)
- 空间复杂度:O(n)O(n)
最后
这是我们「刷穿 LeetCode」系列文章的第 No.1711
篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。
在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。
为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:github.com/SharingSour… 。
在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。