作者推荐
本博文涉及知识点
深度优先搜索 树 图论 分类讨论
LeetCode2973. 树中每个节点放置的金币数目
给你一棵 n 个节点的 无向 树,节点编号为 0 到 n - 1 ,树的根节点在节点 0 处。同时给你一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ai, bi] 表示树中节点 ai 和 bi 之间有一条边。
给你一个长度为 n 下标从 0 开始的整数数组 cost ,其中 cost[i] 是第 i 个节点的 开销 。
你需要在树中每个节点都放置金币,在节点 i 处的金币数目计算方法如下:
如果节点 i 对应的子树中的节点数目小于 3 ,那么放 1 个金币。
否则,计算节点 i 对应的子树内 3 个不同节点的开销乘积的 最大值 ,并在节点 i 处放置对应数目的金币。如果最大乘积是 负数 ,那么放置 0 个金币。
请你返回一个长度为 n 的数组 coin ,coin[i]是节点 i 处的金币数目。
示例 1:
输入:edges = [[0,1],[0,2],[0,3],[0,4],[0,5]], cost = [1,2,3,4,5,6]
输出:[120,1,1,1,1,1]
解释:在节点 0 处放置 6 * 5 * 4 = 120 个金币。所有其他节点都是叶子节点,子树中只有 1 个节点,所以其他每个节点都放 1 个金币。
示例 2:
输入:edges = [[0,1],[0,2],[1,3],[1,4],[1,5],[2,6],[2,7],[2,8]], cost = [1,4,2,3,5,7,8,-4,2]
输出:[280,140,32,1,1,1,1,1,1]
解释:每个节点放置的金币数分别为:
- 节点 0 处放置 8 * 7 * 5 = 280 个金币。
- 节点 1 处放置 7 * 5 * 4 = 140 个金币。
- 节点 2 处放置 8 * 2 * 2 = 32 个金币。
- 其他节点都是叶子节点,子树内节点数目为 1 ,所以其他每个节点都放 1 个金币。
示例 3:
输入:edges = [[0,1],[0,2]], cost = [1,2,-2]
输出:[0,1,1]
解释:节点 1 和 2 都是叶子节点,子树内节点数目为 1 ,各放置 1 个金币。节点 0 处唯一的开销乘积是 2 * 1 * -2 = -4 。所以在节点 0 处放置 0 个金币。
提示:
2 <= n <= 2 * 104
edges.length == n - 1
edges[i].length == 2
0 <= ai, bi < n
cost.length == n
1 <= |cost[i]| <= 104
edges 一定是一棵合法的树。
分类讨论
情况表面上很多,时间上只有4情况:
{ 1 不足 3 个节点 最大的三个正数的乘积 至少 3 个正数节点 1 个正数和 2 个负数的乘积 至少一个正数节点, 2 个负数节点 0 o t h e r \begin{cases} 1 &不足3个节点 \\ 最大的三个正数的乘积 & 至少3个正数节点\\ 1个正数和2个负数的乘积 & 至少一个正数节点,2个负数节点\\ 0 & other \\ \end{cases}⎩⎨⎧1最大的三个正数的乘积1个正数和2个负数的乘积0不足3个节点至少3个正数节点至少一个正数节点,2个负数节点other
正数节点都只需要记录3个节点,2个不够。
3个负数节点,0个正数节点。值是0。
2个负数节点,0个正数节点。值是1。
注意:cost[i]不会为0。
代码
核心代码
class CNeiBo2 { public: CNeiBo2(int n, bool bDirect, int iBase = 0) :m_iN(n), m_bDirect(bDirect), m_iBase(iBase) { m_vNeiB.resize(n); } CNeiBo2(int n, vector<vector<int>>& edges, bool bDirect, int iBase = 0) :m_iN(n), m_bDirect(bDirect), m_iBase(iBase) { m_vNeiB.resize(n); for (const auto& v : edges) { m_vNeiB[v[0] - iBase].emplace_back(v[1] - iBase); if (!bDirect) { m_vNeiB[v[1] - iBase].emplace_back(v[0] - iBase); } } } inline void Add(int iNode1, int iNode2) { iNode1 -= m_iBase; iNode2 -= m_iBase; m_vNeiB[iNode1].emplace_back(iNode2); if (!m_bDirect) { m_vNeiB[iNode2].emplace_back(iNode1); } } const int m_iN; const bool m_bDirect; const int m_iBase; vector<vector<int>> m_vNeiB; }; class Solution { public: vector<long long> placedCoins(vector<vector<int>>& edges, vector<int>& cost) { m_vAns.resize(cost.size()); m_cost = cost; CNeiBo2 neiBo(cost.size(), edges, false); std::priority_queue<int> maxHeap; std::priority_queue<int, vector<int>, greater<int> > minHeap; DFS(maxHeap, minHeap, neiBo.m_vNeiB, 0, -1); return m_vAns; } void DFS(std::priority_queue<int>& maxHeap, std::priority_queue<int, vector<int>, greater<int> >& minHeap,const vector<vector<int>>& neiBo, int cur, int par) { if (m_cost[cur] >= 0) { minHeap.emplace(m_cost[cur]); } else { maxHeap.emplace(m_cost[cur]); } for (const auto& next : neiBo[cur]) { if (next == par) { continue; } std::priority_queue<int> maxHeap1; std::priority_queue<int, vector<int>, greater<int> > minHeap1; DFS(maxHeap1,minHeap1,neiBo, next, cur); Union(maxHeap, maxHeap1); Union(minHeap, minHeap1); } auto Cal = [&]() { if (maxHeap.size() + minHeap.size() <3 ) { return 1LL; } long long llRet = 0; auto v1 = ToVector(minHeap); auto v2 = ToVector(maxHeap); if (3 == minHeap.size()) { llRet =max(llRet, (long long)v1[0] * v1[1] * v1[2]); } if (minHeap.size()&& (maxHeap.size() >= 2)) { if (v2.size() > 2) { v2.erase(v2.begin()); } llRet = max(llRet, (long long)v1.back() * v2[0] * v2[1]); } return llRet; }; m_vAns[cur] = Cal(); } protected: template<class T> vector<int> ToVector(T heap) { vector<int> v; while (heap.size()) { v.emplace_back(heap.top()); heap.pop(); } T heap2(v.begin(), v.end()); heap2.swap(heap); return v; } template<class T> void Union(T& heap1, T& heap2) { while (heap2.size()) { heap1.emplace(heap2.top()); heap2.pop(); } while (heap1.size() > 3) { heap1.pop(); } } vector<long long> m_vAns; vector<int> m_cost; };
测试用例
template<class T,class T2> void Assert(const T& t1, const T2& t2) { assert(t1 == t2); } template<class T> void Assert(const vector<T>& v1, const vector<T>& v2) { if (v1.size() != v2.size()) { assert(false); return; } for (int i = 0; i < v1.size(); i++) { Assert(v1[i], v2[i]); } } int main() { vector<vector<int>> edges; vector<int> cost; { Solution sln; edges = { {0,1},{0,2},{2,3} }, cost = { 10000, -10000, 10000, -10000 }; auto res = sln.placedCoins(edges, cost); Assert({ 1000000000000,1,1,1 }, res); } { Solution sln; edges = { {0,1},{0,2},{0,3},{0,4},{0,5} }, cost = { 1,2,3,4,5,6 }; auto res = sln.placedCoins(edges, cost); Assert({ 120,1,1,1,1,1 }, res); } { Solution sln; edges = { {0,1},{0,2},{1,3},{1,4},{1,5},{2,6},{2,7},{2,8} }, cost = { 1,4,2,3,5,7,8,-4,2 }; auto res = sln.placedCoins(edges, cost); Assert({ 280,140,32,1,1,1,1,1,1 }, res); } { Solution sln; edges = { {0,1},{0,2} }, cost = { 1,2,-2 }; auto res = sln.placedCoins(edges, cost); Assert({ 0,1,1 }, res); } { Solution sln; edges = { {0,1},{0,2},{0,3},{0,4},{0,5},{0,6},{0,7},{0,8},{0,9},{0,10},{0,11},{0,12},{0,13},{0,14},{0,15},{0,16},{0,17},{0,18},{0,19},{0,20},{0,21},{0,22},{0,23},{0,24},{0,25},{0,26},{0,27},{0,28},{0,29},{0,30},{0,31},{0,32},{0,33},{0,34},{0,35},{0,36},{0,37},{0,38},{0,39},{0,40},{0,41},{0,42},{0,43},{0,44},{0,45},{0,46},{0,47},{0,48},{0,49},{0,50},{0,51},{0,52},{0,53},{0,54},{0,55},{0,56},{0,57},{0,58},{0,59},{0,60},{0,61},{0,62},{0,63},{0,64},{0,65},{0,66},{0,67},{0,68},{0,69},{0,70},{0,71},{0,72},{0,73},{0,74},{0,75},{0,76},{0,77},{0,78},{0,79},{0,80},{0,81},{0,82},{0,83},{0,84},{0,85},{0,86},{0,87},{0,88},{0,89},{0,90},{0,91},{0,92},{0,93},{0,94},{0,95},{0,96},{0,97},{0,98},{0,99} }; cost={-5959, 602, -6457, 7055, -1462, 6347, 7226, -8422, -6088, 2997, -7909, 6433, 5217, 3294, -3792, 7463, 8538, -3811, 5009, 151, 5659, 4458, -1702, -1877, 2799, 9861, -9668, -1765, 2181, -8128, 7046, 9529, 6202, -8026, 6464, 1345, 121, 1922, 7274, -1227, -9914, 3025, 1046, -9368, -7368, 6205, -6342, 8091, -6732, -7620, 3276, 5136, 6871, 4823, -1885, -4005, -3974, -2725, -3845, -8508, 7201, -9566, -7236, -3386, 4021, 6793, -8759, 5066, 5879, -5171, 1011, 1242, 8536, -8405, -9646, -214, 2251, -9934, -8820, 6206, 1006, 1318, -9712, 7230, 5608, -4601, 9185, 346, 3056, 8913, -2454, -3445, -4295, 4802, -8852, -6121, -4538, -5580, -9246, -6462}; auto res = sln.placedCoins(edges, cost); sort(cost.begin(), cost.end()); Assert({ 971167251036, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, res); } }
第二版
class CNeiBo2
{
public:
CNeiBo2(int n, bool bDirect, int iBase = 0) :m_iN(n), m_bDirect(bDirect), m_iBase(iBase)
{
m_vNeiB.resize(n);
}
CNeiBo2(int n, vector<vector>& edges, bool bDirect, int iBase = 0) :m_iN(n), m_bDirect(bDirect), m_iBase(iBase)
{
m_vNeiB.resize(n);
for (const auto& v : edges)
{
m_vNeiB[v[0] - iBase].emplace_back(v[1] - iBase);
if (!bDirect)
{
m_vNeiB[v[1] - iBase].emplace_back(v[0] - iBase);
}
}
}
inline void Add(int iNode1, int iNode2)
{
iNode1 -= m_iBase;
iNode2 -= m_iBase;
m_vNeiB[iNode1].emplace_back(iNode2);
if (!m_bDirect)
{
m_vNeiB[iNode2].emplace_back(iNode1);
}
}
const int m_iN;
const bool m_bDirect;
const int m_iBase;
vector<vector> m_vNeiB;
};
class Solution {
public:
vector placedCoins(vector<vector>& edges, vector& cost) {
m_cost = cost;
m_vAns.resize(cost.size());
CNeiBo2 neiBo(cost.size(), edges, false);
multiset<int, greater> more0;
multiset less0;
DFS(more0, less0, neiBo.m_vNeiB, 0, -1);
return m_vAns;
}
void DFS(multiset<int, greater>& more0, multiset& less0, vector<vector>& neiBo, int cur, int par)
{
if (m_cost[cur] > 0)
{
more0.emplace(m_cost[cur]);
}
else
{
less0.emplace(m_cost[cur]);
}
for (const auto& next : neiBo[cur])
{
if (next == par)
{
continue;
}
multiset<int, greater> more01;
multiset less01;
DFS(more01, less01, neiBo, next, cur);
Union(more0, more01);
Union(less0, less01);
}
long long& llRet = m_vAns[cur];
if (more0.size() + less0.size() < 3)
{
llRet = 1;
return;
}
if (more0.size() >= 3)
{
auto it = more0.begin();
llRet = max(llRet, (long long)*(it++) * *(it++) * (it++));
}
if (more0.size() && (less0.size() >= 2))
{
llRet = max(llRet, (long long)(more0.begin()) * *(less0.begin()) * *(std::next(less0.begin())));
}
};
template
void Union(T& set1, const T& set2)
{
for (const auto& n : set2)
{
set1.emplace(n);
}
while (set1.size() > 3)
{
set1.erase(prev(set1.end()));
}
}
vector m_cost;
vector m_vAns;
};