题目传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=5109
题目分析:过了挺久终于把这个坑填了。一开始以为是一道很难的题,后来发现也不难想。
由于懒得打题解,直接引用出题人的题解好了(主要是来贴代码)QAQ:
虽然题目中给定的是无向图,但是实际上我们可以先从 S 出发求一遍最短路,然后问题变成了:“在有向无环图上,求有多少个满足条件的点对 A,B ,满足从 S 到 T 的所有路径一定经过 A,B 其中一点,并且不存在路径同时经过 A,B ”。
求解这到题目的一个关键点在于: 满足条件的点对 A,B 具有特点:从 S 到 A 的方案数 × 从 A 到 T 的方案数 + 从 S 到 B 的方案数 × 从 B 到 T 的方案数 = 从 S 到 T 的方案数。
所以在有向无环图上用动态规划求解路径条数,再去掉 A 可以到达 B 或 B 可以到达 A 的情况即可求解这到题目。
PS:方案数可能会爆掉怎么办?可以对方案数求余一个大整数,如果觉得不够的话可以求余两个大整数。
定义 F(X)= 从 S 到 X 的方案数 × 从 X 到 T 的方案数 = 从 S 经过 X 到达 T 的方案数,所以满足条件的点对 A,B 为:
- F(A)+F(B)=F(T)
- A 和 B 不能相互到达
对于条件 1 ,我们可以使用数据结构进行优化(使用
std::map
即可),而对于条件 2 ,我们可以使用bitset
位压 32 或者 64 位进行加速,使得最终时间和空间都能够承受。时间复杂度: O(nlogn+nmw) ,其中 w 是位压的字长。
我写代码的时候模了3个大质数,并且手写了个Hash表处理条件1,结果代码比标程长到不知哪里去了……
CODE:
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
const int maxn=50010;
const int maxm=800;
const unsigned long long Max=1ULL<<22;
const long long M[3]={998244353,1000000007,1998585857};
const long long Mod=55837;
const long long M1=1333331;
const long long M2=23252729;
const long long M3=19260817;
typedef long long LL;
typedef unsigned long long ULL;
struct edge
{
int obj,len;
edge *Next;
} e[maxn<<2];
edge *head[maxn];
edge *nhead[maxn];
int cur=-1;
int Heap[maxn];
int id[maxn];
LL dis[maxn];
int tail;
int pin[maxn];
int que[maxn];
int he,ta;
struct data
{
LL cnt1[3],cnt2[3];
int num;
} a[maxn];
ULL get[maxn][maxm];
struct Hash_data
{
LL val[3];
ULL Node[maxm];
int Num;
} Hash[Mod];
int Cnt[Max];
int n,m,s,t,sn;
LL ans=0;
void Add(edge **Head,int x,int y,int z)
{
cur++;
e[cur].obj=y;
e[cur].len=z;
e[cur].Next=Head[x];
Head[x]=e+cur;
}
int Delete()
{
int temp=Heap[1];
Heap[1]=Heap[tail];
tail--;
id[ Heap[1] ]=1;
int x=1;
while (1)
{
int y=x,Left=x<<1,Right=Left|1;
if ( Left<=tail && dis[ Heap[Left] ]if ( Right<=tail && dis[ Heap[Right] ]if (y==x) break;
swap(Heap[x],Heap[y]);
swap(id[ Heap[x] ],id[ Heap[y] ]);
x=y;
}
return temp;
}
void Update(int x)
{
while (x>1)
{
int y=x>>1;
if (dis[ Heap[y] ]<=dis[ Heap[x] ]) break;
swap(Heap[x],Heap[y]);
swap(id[ Heap[x] ],id[ Heap[y] ]);
x=y;
}
}
void Release(int x,int y,LL v)
{
if (dis[y]<=dis[x]+v) return;
dis[y]=dis[x]+v;
Update(id[y]);
}
void Dijkstra()
{
for (int i=1; i<=n; i++) dis[i]=1e15;
dis[s]=0;
tail=1;
Heap[1]=s;
id[s]=1;
for (int i=1; i<=n; i++) if (i!=s) Heap[++tail]=i,id[i]=tail;
for (int i=1; iint node=Delete();
for (edge *p=head[node]; p; p=p->Next)
Release(node,p->obj,p->len);
}
}
void Work(int x,int y)
{
y--;
int u=y/64,v=y%64;
ULL temp=1;
temp<<=v;
get[x][u]|=temp;
}
void Calc()
{
for (int i=1; i<=n; i++) a[i].num=i,Work(i,i);
sn=(n-1)/64;
he=0,ta=1;
que[1]=s;
for (int k=0; k<3; k++) a[s].cnt1[k]=1;
while (heint node=que[++he];
for (edge *p=nhead[node]; p; p=p->Next)
{
int to=p->obj;
for (int k=0; k<3; k++) a[to].cnt1[k]=(a[to].cnt1[k]+a[node].cnt1[k])%M[k];
pin[to]--;
if (!pin[to]) que[++ta]=to;
}
}
for (int k=0; k<3; k++) a[t].cnt2[k]=1;
for (int i=ta; i>=1; i--)
{
int node=que[i];
for (edge *p=nhead[node]; p; p=p->Next)
{
int to=p->obj;
for (int k=0; k<3; k++) a[node].cnt2[k]=(a[node].cnt2[k]+a[to].cnt2[k])%M[k];
}
}
for (int i=1; i<=n; i++)
for (int k=0; k<3; k++) a[i].cnt1[k]=(a[i].cnt1[k]*a[i].cnt2[k])%M[k];
for (int i=1; i<=ta; i++)
{
int node=que[i];
if ( a[node].cnt1[0] && a[node].cnt1[1] && a[node].cnt1[2] )
for (edge *p=nhead[node]; p; p=p->Next)
{
int to=p->obj;
if ( a[to].cnt1[0] && a[to].cnt1[1] && a[to].cnt1[2] )
for (int k=0; k<=sn; k++) get[to][k]|=get[node][k];
}
}
}
void Push(int x,LL v1,LL v2,LL v3)
{
LL y=(v1*M1+v2*M2+v3*M3)%Mod;
while ( Hash[y].val[0]!=-1 &&
( Hash[y].val[0]!=v1 || Hash[y].val[1]!=v2 || Hash[y].val[2]!=v3 ) )
y=(y+1)%Mod;
Hash[y].val[0]=v1;
Hash[y].val[1]=v2;
Hash[y].val[2]=v3;
Hash[y].Node[(x-1)/64]|=( 1ULL<<((x-1)%64) );
Hash[y].Num++;
}
int Get(ULL v)
{
int temp=0;
for (int i=0; i<3; i++) temp+=Cnt[v&(Max-1ULL)],v>>=22;
return temp;
}
void Check(int x,LL v1,LL v2,LL v3)
{
LL y=(v1*M1+v2*M2+v3*M3)%Mod;
while ( Hash[y].val[0]!=-1 &&
( Hash[y].val[0]!=v1 || Hash[y].val[1]!=v2 || Hash[y].val[2]!=v3 ) )
y=(y+1)%Mod;
if ( Hash[y].val[0]==-1 ) return;
ans+=Hash[y].Num;
for (int k=0; k<=sn; k++) ans-=( Get(get[x][k]&Hash[y].Node[k])<<1 );
}
void Solve()
{
for (int i=0; i0]=-1;
for (int i=1; i1;
for (int i=1; i<=n; i++) Push(i,a[i].cnt1[0],a[i].cnt1[1],a[i].cnt1[2]);
for (int i=1; i<=n; i++)
{
LL v1=(a[t].cnt1[0]-a[i].cnt1[0]+M[0])%M[0];
LL v2=(a[t].cnt1[1]-a[i].cnt1[1]+M[1])%M[1];
LL v3=(a[t].cnt1[2]-a[i].cnt1[2]+M[2])%M[2];
Check(i,v1,v2,v3);
if ( v1==a[i].cnt1[0] && v2==a[i].cnt1[1] && v3==a[i].cnt1[2] ) ans++;
}
ans>>=1;
}
int main()
{
freopen("chicken.in","r",stdin);
freopen("chicken.out","w",stdout);
scanf("%d%d%d%d",&n,&m,&s,&t);
for (int i=1; i<=n; i++) head[i]=nhead[i]=NULL;
for (int i=1; i<=m; i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
Add(head,u,v,w);
Add(head,v,u,w);
}
Dijkstra();
for (int i=1; i<=n; i++)
for (edge *p=head[i]; p; p=p->Next)
{
int to=p->obj;
if ( dis[i]+(long long)p->len==dis[to] )
Add(nhead,i,to,0),pin[to]++;
}
Calc();
Solve();
printf("%lld\n",ans);
return 0;
}