拓扑序计数+树形DP
题目
链接:https://ac.nowcoder.com/acm/contest/38630/F
思路
每个公司是一棵树,有n家公司,可以将这n家公司连到一个虚拟的根上。总共的排队方案就等于这个棵的排队方案树。为了满足排队是顺序的,所以我们要求的就是这棵树的拓扑序个数。用树形DP来求解。
f[u]
: 以u
为根的子树的拓扑序数
sz[u]
: 以u
为根的子树的大小(节点的数量)
如何计算一个棵的拓扑序数?
我们先来看只有两个子树的情况:
如上图所示,1号节点一定是第1为,那么剩下还有3位置,有三种可能:
- 1 2 3 4
- 1 3 2 4
- 1 2 4 3
这就用到了概率论的知识,有三个盒子,从中要选2个盒子给2、4
放置,并且是按照顺序的2 4
放置,因为是拓扑序,所以每个子树的相对顺序是一定的,所以最后计算的结果就是:
$$ f[u]=f[v1]⋅f[v2]⋅C(sz[v1]+sz[v2],sz[v1]) $$
当树为二叉树时,将两个子树v1,v2
进行合并:即先把各子树的方案数乘起来算出总方案,然后考虑各子树元素的相对排列顺序,即在总的节点个数中选sz[v1]
排sz[v1]
个位置,剩下的顺序就固定了,保证每颗子树的相对拓扑序不变。
如何计算组合数
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
typedef priority_queue<int, vector<int>, less<int>> Q;
#define x first
#define y second
#define endl '\n'
#define ppb pop_back
#define pb push_back
#define pf push_front
#define YES cout << "YES" << endl
#define Yes cout << "Yes" << endl
#define yes cout << "yes" << endl
#define NO cout << "NO" << endl
#define No cout << "No" << endl
#define no cout << "no" << endl
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
#define mset(x, a) memset(x, a, sizeof(x))
#define rep(i, l, r) for (LL i = l; i <= (r); ++i)
#define per(i, r, l) for (LL i = r; i >= (l); --i)
const int N = 1e5 + 10, inf = 0x3f3f3f3f, mod = 1e9 + 7;
vector<int> v[N];
int n;
int sz[N];
ll fac[N];
ll inv[N];
ll f[N];
ll qsm(ll a, ll b)
{
ll ans = 1;
while (b)
{
if (b & 1)
ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans;
}
void dfs(int u)
{
f[u] = 1;
sz[u] = 1;
for (int i = 0; i < v[u].size(); i++)
{
int j = v[u][i];
dfs(j);
sz[u] += sz[j];
f[u] = (f[u] * f[j]) % mod * inv[sz[j]] % mod;
}
f[u] = f[u] * fac[sz[u] - 1] % mod;
}
void solve()
{
fac[0] = 1;
inv[0] = 1;
for (int i = 1; i < N; i++)
{
fac[i] = fac[i - 1] * i % mod;
inv[i] = inv[i - 1] * qsm(i, mod - 2) % mod;
}
cin >> n;
ll ans = 1;
int cnt = 0;
for (int i = 1; i <= n; i++)
{
int c;
cin >> c;
v[0].pb(cnt + 1);
for (int j = 2; j <= c; j++)
{
int u;
cin >> u;
v[cnt + u].pb(cnt + j);
}
cnt += c;
}
dfs(0);
cout << f[0];
}
signed main()
{
#ifdef Xin
freopen("in.in", "r", stdin);
freopen("out.out", "w", stdout);
#endif
int T = 1;
while (T--)
solve();
return 0;
}
感谢大佬的文章: