BZOJ
这道题的官方正解好像是bitset?
据说数据很水的说,随机数据要么不连通要么就只有一条最短路
dp一下即可,记录每个点被几条最短路覆盖。
#include
#include
#include
using namespace std;
typedef long long ll;
const int maxn=50010;
const ll INF=0x3f3f3f3f3f3f3f3f;
struct data{int v,nxt,w;};
int n,m,s,t,pos,cnt,inq[maxn],path[maxn],cov[maxn];
ll ans,dis[2][maxn],way[2][maxn];
queue<int> q;
struct hashtable{
static const ll ha=999917,maxe=50010;
ll E,lnk[ha],son[maxe+5],nxt[maxe+5],w[maxe+5];
ll top,stk[maxe+5];
void clear(){E=0;while(top) lnk[stk[top--]]=0;}
void add(ll x,ll y){son[++E]=y;nxt[E]=lnk[x];w[E]=0;lnk[x]=E;}
bool count(ll y)
{
ll x=y%ha;
for(int j=lnk[x];j;j=nxt[j]) if(y==son[j]) return true;
return false;
}
ll& operator [] (ll y)
{
ll x=y%ha;
for(int j=lnk[x];j;j=nxt[j]) if(y==son[j]) return w[j];
add(x,y);stk[++top]=x;return w[E];
}
}m1,m2,m3;
struct graph{
int p,head[maxn],pre[maxn];
data edge[maxn<<1];
inline void insert(int u,int v,int w)
{
edge[++p]=(data){v,head[u],w};head[u]=p;
edge[++p]=(data){u,head[v],w};head[v]=p;
}
void spfa(ll *dis,ll *cov,int s)
{
while(!q.empty()) q.pop();
memset(inq,0,sizeof(inq));
q.push(s);dis[s]=0;cov[s]=1;inq[s]=1;
while(!q.empty())
{
int x=q.front();q.pop();
inq[x]=0;
for(int i=head[x];i;i=edge[i].nxt)
{
if(dis[edge[i].v]>dis[x]+edge[i].w)
{
dis[edge[i].v]=dis[x]+edge[i].w;
cov[edge[i].v]=cov[x];pre[edge[i].v]=x;
if(!inq[edge[i].v]) inq[edge[i].v]=1,q.push(edge[i].v);
}
else if(dis[edge[i].v]==dis[x]+edge[i].w)
cov[edge[i].v]+=cov[x];
}
}
}
}g,h;
void input()
{
scanf("%d%d%d%d",&n,&m,&s,&t);
for(int i=1,u,v,w;i<=m;i++)
{
scanf("%d%d%d",&u,&v,&w);
g.insert(u,v,w);h.insert(u,v,w);
}
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
#endif
input();
memset(dis,0x3f,sizeof(dis));
g.spfa(dis[0],way[0],s);h.spfa(dis[1],way[1],t);
if(dis[0][t]==INF){printf("%lld\n",(ll)n*(n-1)/2);return 0;}
for(int i=1;i<=n;i++)
if(dis[0][i]+dis[1][i]!=dis[0][t])
way[0][i]=way[1][i]=0;
for(int i=1;i<=n;i++) m1[way[0][i]*way[1][i]]++;
pos=s;
while(pos)
{
path[++cnt]=pos;cov[cnt]=way[0][pos]*way[1][pos];
pos=h.pre[pos];
}
for(int i=1;i<=cnt;i++) ans+=m1[way[0][t]-cov[i]];
for(int i=1;i<=cnt;i++)
{m2[way[0][path[i]]*way[1][path[i]]]++;ans-=m2[way[0][t]-cov[i]];}
for(int i=1;i<=cnt;i++)
{ans-=m3[way[0][t]-cov[i]];m3[way[0][path[i]]*way[1][path[i]]]++;}
printf("%lld\n",ans);
return 0;
}