此题用倍增在线查询更好,我只是练练Tarjan算法的手,倍增时间复杂度为O ( n l o n g n ) , 查 询 O ( l o g n ) , t a r j a n 离 线 l c a 算 法 O ( n + q ) O(nlong_n),查询O(log_n),tarjan离线lca算法O(n+q)O(nlongn),查询O(logn),tarjan离线lca算法O(n+q)
Tarjan算法原理:求出每两个点的LCA之后,树上两点的最距离d i s [ a ] [ b ] = d e p [ a ] + d e p [ b ] − 2 ∗ d e p [ L C A ( a , b ) ] dis[a][b]=dep[a]+dep[b]-2*dep[LCA(a,b)]dis[a][b]=dep[a]+dep[b]−2∗dep[LCA(a,b)],O ( 1 ) O(1)O(1)求出
而且有vis数组 fa要不要无所谓
#include<bits/stdc++.h> using namespace std; const int N=4e4+10; typedef pair<int,int> PII; vector<int>g[N]; struct node { int x; int y; int z; }; vector<node>to[N]; int vis[N]; int f[N]; int ans[N]; int find(int x) { if(f[x]==x) return x; return f[x]=find(f[x]); } void Tarjan(int u,int fa) { vis[u]=1; for(auto x:g[u]) { if(x==fa) continue; if(!vis[x]) { Tarjan(x,u); f[x]=u; } } for(auto [x,y,z]:to[u]) { if(vis[x]==2) { int t=find(x); if(t==u) // u是x的lca { if(z==1) { ans[y]=1; }else ans[y]=2; }else ans[y]=0; } } vis[u]=2;//不划分成三个状态也行 } int main() { int n,m; cin>>n; for(int i=1;i<N;i++) f[i]=i; int root=0; for(int i=1;i<=n;i++) { int a,b; cin>>a>>b; if(b==-1) { root=a; } else { g[a].push_back(b); g[b].push_back(a); } } cin>>m; for(int i=1;i<=m;i++) { int a,b; cin>>a>>b; to[a].push_back({b,i,1}); to[b].push_back({a,i,2}); } Tarjan(root,-1); for(int i=1;i<=m;i++) { cout<<ans[i]<<endl; } }
O(n)求距离
#include<bits/stdc++.h> using namespace std; const int N=1e4+10,M=2e4+19; typedef pair<int,int> PII; int d[N];//i点离根节点的距离 int n,m; vector<PII>g[N]; vector<PII>res[M]; int f[N]; int ans[M]; int st[N]; int find(int x) { if(f[x]==x) return x; return f[x]=find(f[x]); } void dfs(int u,int fa) { for(auto [x,y]:g[u]) { if(x==fa) continue; d[x]=d[u]+y; dfs(x,u); } } void Tarjan(int u) { st[u]=1; for(auto [x,y]:g[u]) { if(!st[x]) { Tarjan(x); f[x]=u; } } for(auto [x,y]:res[u]) { if(st[x]==2) { int anc=find(x); ans[y]=d[x]+d[u]-2*d[anc]; } } st[u]=2; } int main() { int n; cin>>n>>m; for(int i=1;i<=n;i++) f[i]=i; for(int i=1;i<=n-1;i++) { int x,y,k; cin>>x>>y>>k; g[x].push_back({y,k}); g[y].push_back({x,k}); } for(int i=1;i<=m;i++) { int x,y; cin>>x>>y; res[x].push_back({y,i}); res[y].push_back({x,i}); } d[1]=0; dfs(1,-1); Tarjan(1); for(int i=1;i<=m;i++) { cout<<ans[i]<<endl; } return 0; }