思路:
首先从起点跟终点分别跑一遍最短路,假设d i s 1 [ i ] 表示从1出发到i的最短路,d i s 2 [ i ]表示从n出发到i的最短路。
假设中间的两个特殊点为a , b,在a , b连边后新的最短路为d i s 1 [ a ] + d i s 2 [ b ] + 1。
如果暴力枚举特殊点的话,时间复杂度就会变成O ( k 2 ),显然不可取。
比赛时就卡这了,一直没想到优化的方法,赛后看了巨巨的代码才懂。
正常应该有两种类型的最短路:
1 − > a − > b − > n
1 − > b − > a − > n
假设最短路为第一种情况,那么会有:
d i s 1 [ a ] + d i s 2 [ b ] + 1 < d i s 1 [ b ] + 1 + d i s 2 [ a ]
移项与消掉重复项后得:
d i s 1 [ a ] − d i s 2 [ a ] < d i s 1 [ b ] − d i s 2 [ b ]
所以可以按这个排序后,枚举b并且维护最大的dis1[a],取最大值就好了,时间复杂度降为O ( k l o g k ).
要注意的是,向图中加边不会让最短路增加,所以说如果最大值大于原先的最短路,答案还是原先的最短路。
代码:
const int maxn=2e5+10,N=maxn*2,inf=0x3f3f3f3f; int n,m,k,a[maxn],dis1[maxn],dis2[maxn],st[maxn]; int h[N], w[N], e[N], ne[N], idx; int dist[N]; struct node{ int id,ds,dt; }b[maxn]; bool cmp(node a,node b){ return a.ds-a.dt<b.ds-b.dt; } void add(int a, int b, int c) { e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ; } void spfa(int s,int dist[]) { memset(st,0,sizeof st); dist[s] = 0; queue<int> q; q.push(s); st[s] = true; while (q.size()) { int t = q.front(); q.pop(); st[t] = false; for (int i = h[t]; i != -1; i = ne[i]) { int j = e[i]; if (dist[j] > dist[t] + w[i]) { dist[j] = dist[t] + w[i]; if (!st[j]) { q.push(j); st[j] = true; } } } } } int main(){ memset(h,-1,sizeof h); memset(dis1,0x3f,sizeof dis1); memset(dis2,0x3f,sizeof dis2); n=read,m=read,k=read; rep(i,1,k) a[i]=read; rep(i,1,m){ int x=read,y=read; add(x,y,1);add(y,x,1); } spfa(1,dis1);spfa(n,dis2); rep(i,1,k){ b[i]={a[i],dis1[a[i]],dis2[a[i]]}; } sort(b+1,b+1+k,cmp); int res=0; int las=b[1].ds; rep(i,2,k){ res=max(res,las+1+b[i].dt); las=max(las,b[i].ds); } if(res>dis1[n]) res=dis1[n]; write(res); return 0; }