原题链接
题意:
给定一棵边权都是0或1的树,求有效点对的数量。有效点对的定义为:从u到v,经过权值为1的边后不会再经过权值为0的边。
思路:
考虑树形DP。
dp[i][j]表示以i为根节点的子树里到i的路径的状态全为j的点的个数,cnt[j]表示该当前遍历的子树里到根节点的路径的状态为j的点的个数。
j的状态无非四种,都从父节点向子节点的路径角度上考虑:路径的边权全为0,路径的边权全为1,路径的边权先1后0,路径的边权先0后1。
状态转移时,我们来枚举i的子节点k,这样就相当于是枚举0和1的分界点(当然也可能为全0或全1的情况,这种说法不准确),那么问题就转化成了两部分,一是从i到k的路径,二是从k到叶子节点的路径,根据乘法原理,两者相乘即是答案。
在乘法原理时,dp数组和cnt数组的匹配也是个关键,比如dp[i][0]表示路径上的边权全为0,那么他可以匹配的就是cnt[0],cnt[1],cnt[3];
代码:
///#pragma GCC optimize(3) ///#pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline") ///#pragma GCC optimize(2) #include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<ll,ll>PLL; typedef pair<int,int>PII; typedef pair<double,double>PDD; #define I_int ll inline ll read() { ll x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-')f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=x*10+ch-'0'; ch=getchar(); } return x*f; } char F[200]; inline void out(I_int x) { if (x == 0) return (void) (putchar('0')); I_int tmp = x > 0 ? x : -x; if (x < 0) putchar('-'); int cnt = 0; while (tmp > 0) { F[cnt++] = tmp % 10 + '0'; tmp /= 10; } while (cnt > 0) putchar(F[--cnt]); //cout<<" "; } ll ksm(ll a,ll b,ll p) { ll res=1; while(b) { if(b&1)res=res*a%p; a=a*a%p; b>>=1; } return res; } const int inf=0x3f3f3f3f,mod=998244353; const ll INF = 0x3f3f3f3f3f3f3f3f; const int maxn=200000+100,maxm=3e5+7,N=1e6+7; const double PI = atan(1.0)*4; int h[maxn],idx; struct node{ int e,ne,w; }edge[maxn*2]; int n; void add(int u,int v,int w){ edge[idx]={v,h[u],w}; h[u]=idx++; } ll dp[maxn][4],cnt[4],res; void dfs(int u,int fa){ for(int i=h[u];~i;i=edge[i].ne){ memset(cnt,0,sizeof cnt); int j=edge[i].e,w=edge[i].w; if(j==fa) continue; dfs(j,u); if(w==0){///0 cnt[0]=dp[j][0]+1;///全0 cnt[1]=0;///全1 cnt[2]=0;///先1后0 cnt[3]=dp[j][1]+dp[j][3];///先0后1 } else{///1 cnt[0]=0;///全0 cnt[1]=dp[j][1]+1;///全1 cnt[2]=dp[j][0]+dp[j][2];///先1后0 cnt[3]=0;///先0后1 } res=res+dp[u][0]*(cnt[1]+cnt[3]+cnt[0]*2);///全0的和全1、先0后1,全0 res=res+dp[u][1]*(cnt[2]+cnt[0]+cnt[1]*2);///全1的和全0,先1后0,全1 res=res+dp[u][2]*(cnt[1]);///先1后0的和全1的 res=res+dp[u][3]*(cnt[0]);///先0后1的和全0的 for(int k=0;k<4;k++) dp[u][k]+=cnt[k]; } res=res+dp[u][0]*2+dp[u][1]*2+dp[u][2]+dp[u][3]; } int main() { n=read(); memset(h,-1,sizeof h); for(int i=1;i<n;i++){ int u=read(),v=read(),w=read(); add(u,v,w);add(v,u,w); } dfs(1,-1); printf("%lld\n",res); return 0; }