2023“钉耙编程”中国大学生算法设计超级联赛(1)Hide-And-Seek Game

题目大意

有一棵树,小 S \text{S} S 和小 R \text{R} R 在树上各有一条链。小 S \text{S} S 的链起点为 S a S_a Sa,终点为 T a T_a Ta;小 R \text{R} R 的链起点为 S b S_b Sb,终点为 T b T_b Tb

S \text{S} S 和小 R \text{R} R 在各自的链来回移动,一个单位时间移动一条边。求出两人最早相遇的位置,若不可能相遇,输出 − 1 -1 1

t ( 1 ≤ t ≤ 500 ) t(1\le t\le500) t(1t500) 组数据,树的结点个数 n n n 小于 3000 3000 3000

题解

先判断两条链是否有交点,没有就输出 − 1 -1 1

x 1 = dis ⁡ ( S a , k ) x_1=\operatorname{dis}(S_a,k) x1=dis(Sa,k) y 1 = dis ⁡ ( k , T a ) y_1=\operatorname{dis}(k,T_a) y1=dis(k,Ta) x 2 = dis ⁡ ( S b , k ) x_2=\operatorname{dis}(S_b,k) x2=dis(Sb,k) y 2 = dis ⁡ ( k , T b ) y_2=\operatorname{dis}(k,T_b) y2=dis(k,Tb)

由于 n n n 比较小。枚举两条链的交点 k k k,小 S S S 走到 k k k 的时间为 x 1 + 2 n ( x 1 + y 1 ) x_1+2n(x_1+y_1) x1+2n(x1+y1) x 1 + 2 n ( x 1 + y 1 ) + 2 y 1 x_1+2n(x_1+y_1)+2y_1 x1+2n(x1+y1)+2y1,小 R R R 走到 k k k 的时间为 x 2 + 2 m ( x 2 + y 2 ) x_2+2m(x_2+y_2) x2+2m(x2+y2) x 2 + 2 m ( x 2 + y 2 ) + 2 y 2 x_2+2m(x_2+y_2)+2y_2 x2+2m(x2+y2)+2y2,其中 n , m n,m n,m 是自然数。这样有 4 4 4 种可能,做 4 4 4 exgcd \text{exgcd} exgcd,各求出最小的非负 n , m n,m n,m,更新答案。

赛时用了 1h \text{1h} 1h 多打出来。代码长度 6k \text{6k} 6k

#include
using namespace std;
const int INF=1e9;
const int N=5e3+1;
int n,m,sa,ta,sb,tb,ans,pos;
int head[N],nxt[N<<1],to[N<<1],cnt,fa[N],top[N],dep[N],son[N],sz[N],p[N],cnt1,bj[N],vis[N];
int tr[N<<2],la[N<<2];
void exgcd(int a,int b,int &d,int &x,int &y)
{
    if(!b)
        d=a,x=1,y=0;
    else{
        exgcd(b,a%b,d,y,x);
        y-=x*(a/b);
    }
}
void add(int u,int v)
{
    to[++cnt]=v;
    nxt[cnt]=head[u];
    head[u]=cnt;
}
void dfs1(int k,int f)
{
    fa[k]=f;
    dep[k]=dep[f]+1;
    sz[k]=1;
    for(int i=head[k];i;i=nxt[i]){
        if(to[i]!=f){
            dfs1(to[i],k);
            sz[k]+=sz[to[i]];
            if(sz[son[k]]<sz[to[i]]){
                son[k]=to[i];
            }
        }
    }
}
void dfs2(int k,int t)
{
    top[k]=t;
    p[k]=++cnt1;
    if(son[k]) dfs2(son[k],t);
    for(int i=head[k];i;i=nxt[i])
        if(to[i]!=son[k]&&to[i]!=fa[k])
            dfs2(to[i],to[i]);
}
int lca(int u,int v)
{
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        u=fa[top[u]];
    }
    return dep[u]>dep[v]?v:u;
}
void dfs(int u,int x)
{
    while(u!=x){
        bj[u]++;
        u=fa[u];
    }
}
int getdistance(int x,int y)
{
    int Lca=lca(x,y);
    return dep[x]+dep[y]-2*dep[Lca];
}
void getans(int k)
{
    //1
    int X1=getdistance(sa,k),Y1=getdistance(ta,k),X2=getdistance(sb,k),Y2=getdistance(tb,k);
    int a=2*(X1+Y1),b=2*(X2+Y2),d,x,y,c=X2-X1;
    exgcd(a,b,d,x,y);
    if(c%d==0){
        int aa=b/d;
        x*=c/d;
        x=(x%aa+aa)%aa;
        if(ans>X1+x*a){
            ans=X1+x*a;
            pos=k;
        }
        int bb=a/d;
        y*=c/d;
        y=(y%bb+bb)%bb;
        y-=bb;
        if(ans>X2-y*b){
            ans=X2-y*b;
            pos=k;
        }
    }
    //2
    c=X2-X1-2*Y1;
    exgcd(a,b,d,x,y);
    if(c%d==0){
        int aa=b/d;
        x*=c/d;
        x=(x%aa+aa)%aa;
        if(ans>X1+x*a+2*Y1){
            ans=X1+x*a+2*Y1;
            pos=k;
        }
        int bb=a/d;
        y*=c/d;
        y=(y%bb+bb)%bb;
        y-=bb;
        if(ans>X2-y*b){
            ans=X2-y*b;
            pos=k;
        }
    }
    //3
    c=X2-X1+2*Y2;
    exgcd(a,b,d,x,y);
    if(c%d==0){
        int aa=(b)/d;
        x*=c/d;
        x=(x%aa+aa)%aa;
        if(ans>X1+x*(a)){
            ans=X1+x*a;
            pos=k;
        }
        int bb=a/d;
        y*=c/d;
        y=(y%bb+bb)%bb;
        y-=bb;
        if(ans>X2-y*b+2*Y2){
            ans=X2-y*b+2*Y2;
            pos=k;
        }
    }
    //4
    c=X2-X1+2*Y2-2*Y1;
    exgcd(a,b,d,x,y);
    if(c%d==0){
        int aa=(b)/d;
        x*=c/d;
        x=(x%aa+aa)%aa;
        if(ans>X1+x*a+2*Y1){
            ans=X1+x*a+2*Y1;
            pos=k;
        }
        int bb=(a)/d;
        y*=c/d;
        y=(y%bb+bb)%bb;
        y-=bb;
        if(ans>X2-y*b+2*Y2){
            ans=X2-y*b+2*Y2;
            pos=k;
        }
    }
}
void bfs()
{
    queue<int> q;
    q.push(1);
    vis[1]=1;
    while(q.size()){
        int k=q.front();
        q.pop();
        if(bj[k]==2){
            getans(k);
        }
        for(int i=head[k];i;i=nxt[i]){
            if(!vis[to[i]]){
                vis[to[i]]=1;
                q.push(to[i]);
            }
        }
    }
}
int main()
{
    int t;
    cin>>t;
    while(t--){
        cnt=cnt1=0;
        memset(head,0,sizeof(head));
        memset(nxt,0,sizeof(nxt));
        memset(vis,0,sizeof(vis));
        memset(fa,0,sizeof(fa));
        memset(dep,0,sizeof(dep));
        memset(son,0,sizeof(son));
        memset(sz,0,sizeof(sz));
        scanf("%d%d",&n,&m);
        for(int i=1,u,v;i<n;i++){
            scanf("%d%d",&u,&v);
            add(u,v),add(v,u);
        }
        dfs1(1,0);
        dfs2(1,1);
        while(m--){
            ans=1e9;
            memset(bj,0,sizeof(bj));
            memset(vis,0,sizeof(vis));
            scanf("%d%d%d%d",&sa,&ta,&sb,&tb);
            int lcaa=lca(sa,ta),lcab=lca(sb,tb);
            dfs(sa,lcaa),dfs(ta,lcaa);bj[lcaa]++;
            dfs(sb,lcab),dfs(tb,lcab);bj[lcab]++;
            bfs();
            if(ans==1e9) puts("-1");
            else printf("%d\n",pos);
        }
    }
}

你可能感兴趣的:(算法)