Category: Data Structures, Persistent Segment Tree
Solution:
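See the official editorial for the basic idea. In short: every tentative distance is a sum of powers of two, so it is stored as a binary number in a persistent segment tree over bit positions. Adding an edge of weight 2^x sets the first 0 bit at or above position x and clears the run of 1 bits below it. Two distances are compared by using subtree hashes to locate their highest differing bit.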
Implementation:
// Shortest path where an edge with parameter x has weight 2^x.
// A tentative distance is a binary number stored as a persistent segment
// tree over bit positions [0, N]. Each node keeps the count of set bits in
// its range and a hash of those bits. Dijkstra runs on these trees and
// compares two distances by locating their highest differing bit via hashes.
// Input:  n m, then m lines "a b x" (undirected edge a-b of weight 2^x),
//         then "s t".
// Output: the distance mod 1e9+7, the number of vertices on the path, and
//         the path itself; or -1 if t is unreachable from s.
#include <algorithm>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>
using namespace std;

const int N = 2e5 + 500;         // highest representable bit position
int n, m, sz;                    // vertex count, edge count, nodes allocated so far
vector<pair<int, int> > G[N];    // G[u]: (neighbor, exponent x) pairs
long long B = 37;                // hash base
long long MOD = 1e9 + 7;         // modulus for POW and for the printed distance
long long POW[N + 1];            // POW[k] = B^k mod MOD
long long POWER[N + 1];          // POWER[k] = 2^k mod MOD (sized N+1: bit N exists)
int zero;                        // root of the all-zero tree (distance 0)

// Persistent-tree node pool; index 0 is an all-zero null node.
// A node is never modified once its construction has finished, so subtrees
// can be shared between versions instead of copied.
struct Node {
    int ls, rs;                // children
    int cnt;                   // number of set bits in this node's range
    // hash(left) + POW[len(left)] * hash(right). Unsigned so that wraparound
    // on overflow is well defined (signed overflow is undefined behaviour).
    unsigned long long hash;
} node[N * 40];

// Recompute node t's count and hash from its children; the left child
// covers [b, mid].
void pull(int t, int b, int mid) {
    const Node &lo = node[node[t].ls];
    const Node &hi = node[node[t].rs];
    node[t].cnt = lo.cnt + hi.cnt;
    node[t].hash = lo.hash + static_cast<unsigned long long>(POW[mid - b + 1]) * hi.hash;
}

// Build a fresh all-zero tree over [b, e]; its root index is stored in t.
void build(int &t, int b, int e) {
    t = ++sz;
    if (b == e) {
        node[t].hash = 0;
        node[t].cnt = 0;
        return;
    }
    int mid = (b + e) / 2;
    build(node[t].ls, b, mid);
    build(node[t].rs, mid + 1, e);
    pull(t, b, mid);
}

// New version of tree `pre` with bit `pos` set to 1; only the nodes on the
// root-to-leaf path are copied. The new root is stored in t.
void update(int &t, int pre, int b, int e, int pos) {
    t = ++sz;
    node[t] = node[pre];
    if (b == e) {
        node[t].hash = 1;
        node[t].cnt = 1;
        return;
    }
    int mid = (b + e) / 2;
    if (pos <= mid)
        update(node[t].ls, node[pre].ls, b, mid, pos);
    else
        update(node[t].rs, node[pre].rs, mid + 1, e, pos);
    pull(t, b, mid);
}

// New version of tree `pre` whose positions [i, j] are taken from tree `id2`
// (callers pass the all-zero tree, i.e. the range is cleared). Subtrees that
// lie entirely outside the range are shared from `pre`, and subtrees entirely
// inside it are shared from `id2`; only partially covered nodes are allocated.
// An empty range (i > j) returns `pre` itself.
void updateRange(int &t, int pre, int id2, int b, int e, int i, int j) {
    if (i > j || b > j || e < i) {
        t = pre;
        return;
    }
    if (b >= i && e <= j) {
        t = id2;
        return;
    }
    t = ++sz;
    node[t] = node[pre];
    int mid = (b + e) / 2;
    updateRange(node[t].ls, node[pre].ls, node[id2].ls, b, mid, i, j);
    updateRange(node[t].rs, node[pre].rs, node[id2].rs, mid + 1, e, i, j);
    pull(t, b, mid);
}

// Number of set bits of tree t inside [i, j] (t covers [b, e]).
int query(int t, int b, int e, int i, int j) {
    if (b > j || e < i) return 0;
    if (b >= i && e <= j) return node[t].cnt;
    int mid = (b + e) / 2;
    return query(node[t].ls, b, mid, i, j) + query(node[t].rs, mid + 1, e, i, j);
}

// Lowest position >= w whose bit is 0 in tree `root`, i.e. where the carry
// from adding 2^w stops. Binary search on the end of the run of 1s at w.
int find(int root, int w) {
    int low = w, high = N, ans = w - 1;
    while (low <= high) {
        int mid = (low + high) / 2;
        if (query(root, 0, N, w, mid) == mid - w + 1) {
            ans = mid;
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }
    return ans + 1;
}

// True iff the number in tree r1 is strictly greater than the one in r2
// (barring hash collisions). Walks toward the most significant differing
// bit: if the higher halves hash equal, the difference is in the lower half.
bool queryIn(int r1, int r2, int b, int e) {
    if (b == e) return node[r1].hash > node[r2].hash;
    int mid = (b + e) / 2;
    if (node[node[r1].rs].hash == node[node[r2].rs].hash)
        return queryIn(node[r1].ls, node[r2].ls, b, mid);
    return queryIn(node[r1].rs, node[r2].rs, mid + 1, e);
}

long long tot = 0;  // distance mod MOD, accumulated by Trv

// Add 2^i mod MOD to tot for every set bit i of tree t (t covers [b, e]).
void Trv(int t, int b, int e) {
    if (node[t].cnt == 0) return;  // no set bits below here
    if (b == e) {
        tot = (tot + POWER[b]) % MOD;
        return;
    }
    int mid = (b + e) / 2;
    Trv(node[t].ls, b, mid);
    Trv(node[t].rs, mid + 1, e);
}

struct Vertex {
    int id;  // graph vertex
    int r;   // root of its tentative-distance tree
    // Deliberately reversed: a < b means a's distance is LARGER, so the
    // max-heap priority_queue pops the smallest distance first.
    bool operator<(const Vertex &p) const { return queryIn(r, p.r, 0, N); }
} vlist[N];  // best tentative distance found so far per vertex
bool vis[N];  // vertex has been finalized
int F[N];     // predecessor on the best known path
bool see[N];  // vlist[v] holds a tentative distance

// Dijkstra from s; prints the distance mod MOD and the path to t, or -1.
void dij(int s, int t) {
    // Seed the source explicitly so vlist[t] is meaningful when s == t.
    vlist[s] = Vertex{s, zero};
    see[s] = true;
    priority_queue<Vertex> pq;
    pq.push(vlist[s]);
    while (!pq.empty()) {
        Vertex u = pq.top();
        pq.pop();
        if (vis[u.id]) continue;
        vis[u.id] = true;
        if (u.id == t) break;
        for (const auto &edge : G[u.id]) {
            int v = edge.first, w = edge.second;
            if (vis[v]) continue;
            // dist(u) + 2^w: set the first 0 bit at or above w, then clear
            // the run of 1s in [w, ho - 1] that the carry passed through.
            int ho = find(u.r, w);
            int withBit, sum;
            update(withBit, u.r, 0, N, ho);
            updateRange(sum, withBit, zero, 0, N, w, ho - 1);
            Vertex cand = {v, sum};
            if (!see[v] || vlist[v] < cand) {
                vlist[v] = cand;
                see[v] = true;
                F[v] = u.id;
                pq.push(cand);
            }
        }
    }
    if (!vis[t]) {
        printf("-1\n");
        return;
    }
    Trv(vlist[t].r, 0, N);
    printf("%lld\n", tot);
    vector<int> path;
    for (int x = t; x != s; x = F[x]) path.push_back(x);
    path.push_back(s);
    reverse(path.begin(), path.end());
    printf("%zu\n", path.size());
    for (int x : path) printf("%d ", x);
    printf("\n");
}

int main() {
    scanf("%d %d", &n, &m);
    POW[0] = POWER[0] = 1;
    for (int i = 1; i <= N; i++) {
        POW[i] = POW[i - 1] * B % MOD;
        POWER[i] = POWER[i - 1] * 2 % MOD;
    }
    for (int i = 0; i < m; i++) {
        int a, b, x;
        scanf("%d %d %d", &a, &b, &x);
        G[a].push_back(make_pair(b, x));
        G[b].push_back(make_pair(a, x));
    }
    build(zero, 0, N);
    int s, t;
    scanf("%d %d", &s, &t);
    dij(s, t);
    return 0;
}