2023“钉耙编程”中国大学生算法设计超级联赛(1) A - Hide-And-Seek Game
有一棵有 n n n个节点的树,小 S S S和小 R R R在树上各有一条链。小 S S S的链的起点为 S a S_a Sa,终点为 T a T_a Ta;小 R R R的链起点为 S b S_b Sb,终点为 T b T_b Tb。
小 S S S和小 R R R在各自的链来回移动(就是从起点到终点,再回到起点,进行多次来回),一个单位时间移动一条边。求出两人最早相遇的位置(这个位置必须是一个点,而不是一条边),若不可能相遇,输出 − 1 -1 −1。
有 t t t组数据。
1 ≤ t ≤ 500 , 2 ≤ n , m ≤ 3 × 1 0 3 1\leq t\leq 500,2\leq n,m\leq 3\times 10^3 1≤t≤500,2≤n,m≤3×103。
数据保证 n n n值超过 400 400 400的数据组数不超过 20 20 20。
数据保证 m m m值超过 400 400 400的数据组数不超过 20 20 20。
因为 n n n比较小,所以我们可以枚举两条链上的每一个点。
对于同时在两条链上的点 x x x,我们可以得出从小 S S S到达 x x x的时间为 2 k 1 ⋅ d i s ( S a , T a ) + d i s ( S a , x ) 2k_1\cdot dis(S_a,T_a)+dis(S_a,x) 2k1⋅dis(Sa,Ta)+dis(Sa,x)或 2 k 2 ⋅ d i s ( S a , T a ) + d i s ( T a , x ) 2k_2\cdot dis(S_a,T_a)+dis(T_a,x) 2k2⋅dis(Sa,Ta)+dis(Ta,x),其中 k 1 , k 2 k_1,k_2 k1,k2都为非负整数。
同理,小 R R R到达 x x x的时间为 2 k 3 ⋅ d i s ( S a , T a ) + d i s ( S a , x ) 2k_3\cdot dis(S_a,T_a)+dis(S_a,x) 2k3⋅dis(Sa,Ta)+dis(Sa,x)或 2 k 4 ⋅ d i s ( S a , T a ) + d i s ( T a , x ) 2k_4\cdot dis(S_a,T_a)+dis(T_a,x) 2k4⋅dis(Sa,Ta)+dis(Ta,x),其中 k 3 , k 4 k_3,k_4 k3,k4都为非负整数。
两两联立成形为 a x + b x = c ax+bx=c ax+bx=c的二元一次方程,然后用扩展欧几里得算法求出最小非负整数解即可。
时间复杂度为 O ( n m log n ) O(nm\log n) O(nmlogn)。因为 n , m n,m n,m较大的数据组数比较少,而且跑不满,时限有 5 s 5s 5s,所以是可以过的。
#include
using namespace std;
int tq,n,m,s1,s2,t1,t2,dt,fa[5005],dep[5005],dfn[5005],low[5005];
int ans,bz,ls,lt,tot=0,d[10005],l[10005],r[10005],z[5005],w[5005][2];
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void dfs(int u,int f){
fa[u]=f;
dep[u]=dep[f]+1;
dfn[u]=++dt;
for(int i=r[u];i;i=l[i]){
if(d[i]==f) continue;
dfs(d[i],u);
}
low[u]=dt;
}
int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
while(dep[x]>dep[y]) x=fa[x];
while(x!=y){
x=fa[x];y=fa[y];
}
return x;
}
bool in(int x,int y){
return dfn[y]<=dfn[x]&&dfn[x]<=low[y];
}
void exgcd(int &x,int &y,int &d,int a,int b){
if(b==0){
x=1;y=0;d=a;
return;
}
exgcd(x,y,d,b,a%b);
int t=x;x=y;y=t-a/b*y;
}
void dd(int now,int w1,int w2){
int a=2*ls,b=-2*lt,c=w2-w1,x,y,d,vk;
c=(c%(2*lt)+2*lt)%(2*lt);
exgcd(x,y,d,a,b);
if(c%d) return;
x*=c/d;y*=c/d;
int p=a/d,q=b/d;
if(x<0){
vk=ceil((-1.0*x)/q);
x+=vk*q;y-=vk*p;
}
else if(x>=0){
vk=x/q;
x-=vk*q;y+=vk*p;
}
if(a*x+w1<ans){
ans=a*x+w1;
bz=now;
}
}
int main()
{
scanf("%d",&tq);
while(tq--){
scanf("%d%d",&n,&m);
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
dfs(1,0);
for(int o=1;o<=m;o++){
scanf("%d%d%d%d",&s1,&t1,&s2,&t2);
if(s1==s2){
printf("%d\n",s1);
continue;
}
int vs=lca(s1,t1),vt=lca(s2,t2);
if(dep[vs]>dep[vt]){
swap(s1,s2);
swap(t1,t2);
swap(vs,vt);
}
if(!in(s1,vt)&&!in(t1,vt)){
printf("-1\n");
continue;
}
ans=1e9;bz=-1;
ls=dep[s1]+dep[t1]-2*dep[vs],lt=dep[s2]+dep[t2]-2*dep[vt];
for(int p=s1;;p=fa[p]){
w[p][0]=dep[s1]-dep[p];
w[p][1]=2*ls-(dep[s1]-dep[p]);
z[p]=o;
if(p==vs) break;
}
for(int p=t1;p!=vs;p=fa[p]){
w[p][0]=ls-(dep[t1]-dep[p]);
w[p][1]=ls+(dep[t1]-dep[p]);
z[p]=o;
}
for(int p=s2;;p=fa[p]){
if(z[p]==o){
int k1=dep[s2]-dep[p],k2=2*lt-(dep[s2]-dep[p]);
dd(p,w[p][0],k1);dd(p,w[p][0],k2);
dd(p,w[p][1],k1);dd(p,w[p][1],k2);
}
if(p==vt) break;
}
for(int p=t2;p!=vt;p=fa[p]){
if(z[p]==o){
int k1=lt-(dep[t2]-dep[p]),k2=lt+(dep[t2]-dep[p]);
dd(p,w[p][0],k1);dd(p,w[p][0],k2);
dd(p,w[p][1],k1);dd(p,w[p][1],k2);
}
}
printf("%d\n",bz);
}
tot=dt=0;
for(int i=1;i<=n;i++){
r[i]=fa[i]=dep[i]=z[i]=0;
dfn[i]=low[i]=0;
}
}
return 0;
}