作者推荐
本文涉及的基础知识点
数据结构 双堆
LeetCode3013. 将数组分成最小总代价的子数组 II
给你一个下标从 0 开始长度为 n 的整数数组 nums 和两个 正 整数 k 和 dist 。
一个数组的 代价 是数组中的 第一个 元素。比方说,[1,2,3] 的代价为 1 ,[3,4,1] 的代价为 3 。
你需要将 nums 分割成 k 个 连续且互不相交 的子数组,满足 第二 个子数组与第 k 个子数组中第一个元素的下标距离 不超过 dist 。换句话说,如果你将 nums 分割成子数组 nums[0…(i1 - 1)], nums[i1…(i2 - 1)], …, nums[ik-1…(n - 1)] ,那么它需要满足 ik-1 - i1 <= dist 。
请你返回这些子数组的 最小 总代价。
示例 1:
输入:nums = [1,3,2,6,4,2], k = 3, dist = 3
输出:5
解释:将数组分割成 3 个子数组的最优方案是:[1,3] ,[2,6,4] 和 [2] 。这是一个合法分割,因为 ik-1 - i1 等于 5 - 2 = 3 ,等于 dist 。总代价为 nums[0] + nums[2] + nums[5] ,也就是 1 + 2 + 2 = 5 。
5 是分割成 3 个子数组的最小总代价。
示例 2:
输入:nums = [10,1,2,2,2,1], k = 4, dist = 3
输出:15
解释:将数组分割成 4 个子数组的最优方案是:[10] ,[1] ,[2] 和 [2,2,1] 。这是一个合法分割,因为 ik-1 - i1 等于 3 - 1 = 2 ,小于 dist 。总代价为 nums[0] + nums[1] + nums[2] + nums[3] ,也就是 10 + 1 + 2 + 2 = 15 。
分割 [10] ,[1] ,[2,2,2] 和 [1] 不是一个合法分割,因为 ik-1 和 i1 的差为 5 - 1 = 4 ,大于 dist 。
15 是分割成 4 个子数组的最小总代价。
示例 3:
输入:nums = [10,8,18,9], k = 3, dist = 1
输出:36
解释:将数组分割成 4 个子数组的最优方案是:[10] ,[8] 和 [18,9] 。这是一个合法分割,因为 ik-1 - i1 等于 2 - 1 = 1 ,等于 dist 。总代价为 nums[0] + nums[1] + nums[2] ,也就是 10 + 8 + 18 = 36 。
分割 [10] ,[8,18] 和 [9] 不是一个合法分割,因为 ik-1 和 i1 的差为 3 - 1 = 2 ,大于 dist 。
36 是分割成 3 个子数组的最小总代价。
提示:
3 <= n <= 105
1 <= nums[i] <= 109
3 <= k <= n
k - 2 <= dist <= n - 2
分析
本题等效于:nums[0]必选, 从nums[left,left+dist]中选择k-1个数,使得和最小。
设计容器:存放dist+1个数,方便读取k-1个最小数的和。读、写的时间复杂度都是:O(logn)。
标准做法是双堆(优先队列),用双mulset好理解。
代码
核心代码
class CTopK { public: CTopK(int k):m_iK(k) { } void Add(int num) { m_setK1.emplace(num); OnAdd(num); Do(); } void Erase(int num) { auto it1 = m_setOther.find(num); if (m_setOther.end() != it1 ) { m_setOther.erase(it1); } else { auto it2 = m_setK1.find(num); if (m_setK1.end() != it2) { OnErase(num); m_setK1.erase(it2); } } Do(); while ((m_setK1.size() < m_iK) && m_setOther.size()) { m_setK1.emplace(*m_setOther.begin()); OnAdd(*m_setOther.begin()); m_setOther.erase(m_setOther.begin()); } while (m_setK1.size() && m_setOther.size() && (*m_setK1.rbegin() > *m_setOther.begin())) { int tmp = *m_setK1.rbegin(); OnErase(tmp); m_setK1.erase(std::prev(m_setK1.end())); m_setK1.emplace(*m_setOther.begin()); OnAdd(*m_setOther.begin()); m_setOther.erase(m_setOther.begin()); m_setOther.emplace(tmp); } } protected: virtual void OnAdd(int num) = 0; virtual void OnErase(int num) = 0; void Do() { while (m_setK1.size() > m_iK) { m_setOther.emplace(*m_setK1.rbegin()); OnErase(*m_setK1.rbegin()); m_setK1.erase(std::prev(m_setK1.end())); } } const int m_iK; std::multiset<int> m_setK1, m_setOther; }; class CMyTop : public CTopK { public: using CTopK::CTopK; // 通过 CTopK 继承 virtual void OnAdd(int num) override { m_llSum += num; } virtual void OnErase(int num) override { m_llSum -= num; } long long m_llSum; }; class Solution { public: long long minimumCost(vector<int>& nums, int k, int dist) { CMyTop top(k - 1); for (int i = 1; i <= 1+dist; i++) { top.Add(nums[i]); } long long llRet = top.m_llSum; for (int i = 2; i + k - 1 <= nums.size(); i++) { if (i + dist < nums.size()) { top.Add(nums[i + dist]); } top.Erase(nums[i - 1]); llRet = min(llRet, top.m_llSum); } return llRet + nums.front(); } };
测试用例
template<class T,class T2> void Assert(const T& t1, const T2& t2) { assert(t1 == t2); } template<class T> void Assert(const vector<T>& v1, const vector<T>& v2) { if (v1.size() != v2.size()) { assert(false); return; } for (int i = 0; i < v1.size(); i++) { Assert(v1[i], v2[i]); } } int main() { vector<int> nums; int k, dist; { Solution sln; nums = { 1,3,2,6,4,2 }, k = 3, dist = 3; auto res = sln.minimumCost(nums, k, dist); Assert(5, res); } { Solution sln; nums = { 10,1,2,2,2,1 }, k = 4, dist = 3; auto res = sln.minimumCost(nums, k, dist); Assert(15, res); } { Solution sln; nums = { 10,8,18,9 }, k = 3, dist = 1; auto res = sln.minimumCost(nums, k, dist); Assert(36, res); } }
优化
class CTop2 { public: CTop2(int k) :m_iK(k) { } void Add(int num) { if (m_top.empty() || (num <= *m_top.rbegin())) { m_top.emplace(num); m_llSum += num; } else { m_other.emplace(num); } Adust(); } void Sub(int num) { auto it1 = m_top.find(num); if (m_top.end() != it1) { m_top.erase(it1); m_llSum -= num; Adust(); return; } auto it2 = m_other.find(num); if (m_other.end() != it2) { m_other.erase(it2); } Adust(); } void Adust() { while ((m_top.size() < m_iK)&& m_other.size()) { m_top.emplace(*m_other.begin()); m_llSum += *m_other.begin(); m_other.erase(m_other.begin()); } while (m_top.size() > m_iK) { m_other.emplace(*m_top.rbegin()); m_llSum -= *m_top.rbegin(); m_top.erase(prev(m_top.end())); } } std::multiset<int> m_top, m_other; long long m_llSum = 0; const int m_iK; }; class Solution { public: long long minimumCost(vector<int>& nums, int k, int dist) { CTop2 top(k - 1); int i = 1; for (; i <= 1+dist; i++) { top.Add(nums[i]); } long long iRet = top.m_llSum; for (; i < nums.size(); i++) { top.Add(nums[i]); if (i - dist - 1 > 0) { top.Sub(nums[i - dist - 1]); } iRet = min(iRet, top.m_llSum); } return iRet + nums[0]; } };