[SDOI2018] bzoj 5332 & luogu 4619 旧试题 - 数论

出题人“不优秀的三元环枚举也可以通过”
然而之前自己写了一发,我不计算答案只枚举三元环就跑了半分钟……

答案不会爆longlong,中间不用取模。
先统计自环的情况会很方便后面讨论。
然后就是各种地方都要卡常。
一个结论是无向图给边定向为从度数小的点指向度数大的点,每个点的出度是根号边数级别的。
判断一条边能不能连可以先枚举gcd,然后再搞,可以发现这样复杂度是 O(nlg2n) O ( n l g 2 n ) 的。
代码(在bz上因为没有c++11会CE,开map会T,所以bz的标题只是骗访问量的XD):

#include
#include
#include
#include
#include
#include
#include
#define clr(a,n) memset(a,0,sizeof(int)*((n)+1))
#define hv(u,v) ((u)*(n+1ll)+(v))
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define calc(x,y,z) (1ll*F[A/(x)]*F[B/(y)]*F[C/(z)])
#define fir first
#define sec second
#define mp make_pair
#define mod 1000000007
#define lint long long
#define gc getchar()
#define N 100010
#define debug(x) cerr<<#x<<"="<
#define sp <<" "
#define ln <
using namespace std;
typedef pair<int,int> pii;
inline int inn()
{
    int x,ch;while((ch=gc)<'0'||ch>'9');
    x=ch^'0';while((ch=gc)>='0'&&ch<='9')
        x=(x<<1)+(x<<3)+(ch^'0');return x;
}
inline int gcd(int a,int b) { return a?gcd(b%a,a):b; }
struct node{
    int u,v,w;
    node(int _u=0,int _v=0,int _w=0) { u=_u,v=_v,w=_w; }
};
unordered_mapint> e;
vector es;vector g[N];
int deg[N],F[N],mu[N],p[N],notp[N];

inline int prelude(int n)
{
    notp[1]=1,mu[1]=1;
    for(int i=2,c=0;i<=n;i++)
    {
        if(!notp[i]) p[++c]=i,mu[i]=-1;
        for(int j=1,x;j<=c&&1ll*i*p[j]<=n;j++)
        {
            notp[x=i*p[j]]=1;
            if(i%p[j]) mu[x]=-mu[i];
            else { mu[x]=0;break; }
        }
    }
    rep(i,1,n) rep(j,1,n/i) F[i*j]++;
    rep(i,1,n) F[i]+=F[i-1];
    return 0;
}
int main()
{
    prelude(N-1);
    for(int Tc=inn();Tc;Tc--)
    {
        int A=inn(),B=inn(),C=inn(),n=max(max(A,B),C);
        lint ans=0;

        es.clear(),clr(deg,n),e.clear();
        rep(i,1,n) g[i].clear();

        rep(i,1,n) if(mu[i]) ans+=mu[i]*calc(i,i,i);

        rep(g,1,n) rep(i,1,n/g) if(mu[i*g])
            rep(j,i+1,n/g/i) if(mu[j*g]&&gcd(i,j)==1)
            {
                int u=i*g,v=j*g,w=i*j*g;
                ans+=mu[v]*(calc(u,w,w)+calc(w,u,w)+calc(w,w,u)),
                ans+=mu[u]*(calc(v,w,w)+calc(w,v,w)+calc(w,w,v)),
                es.push_back(node(u,v,w)),deg[u]++,deg[v]++;
            }

        for(int i=0;i<(int)es.size();i++)
        {
            int &u=es[i].u,&v=es[i].v,w=es[i].w;
            if(deg[u]>deg[v]) swap(u,v);
            if(deg[u]==deg[v]&&u>v) swap(u,v);
            g[u].push_back(mp(v,w)),e.insert(mp(hv(u,v),w));
        }

        for(int i=0;i<(int)es.size();i++)
        {
            int u=es[i].u,w=es[i].v,uw=es[i].w;lint val;
            for(int j=0,v;j<(int)g[u].size();j++)
            {
                if(!e.count(val=hv(v=g[u][j].fir,w))) continue;
                int uv=g[u][j].sec,vw=e[val],t=mu[u]*mu[v]*mu[w];
                ans+=t*(calc(uv,vw,uw)+calc(uw,vw,uv)),
                ans+=t*(calc(uv,uw,vw)+calc(vw,uw,uv)),
                ans+=t*(calc(uw,uv,vw)+calc(vw,uv,uw));
            }
        }

        printf("%lld\n",ans%mod);
    }
    return 0;
}

这是以前写的:

#include
#include
#include
#include
#include
#include
#define mod 1000000007
#define gc getchar()
#define N 600010
#define M 12000010
#define hv(x,y) ((lint)(x)*N+(y))
#define lint long long
#define pb push_back
#define debug(x) cerr<<#x<<"="<
#define sp <<" "
#define ln <
using namespace std;
inline int inn()
{
    int x,ch;while((ch=gc)<'0'||ch>'9');
    x=ch^'0';while((ch=gc)>='0'&&ch<='9')
        x=(x<<1)+(x<<3)+(ch^'0');return x;
}
int u[M],v[M],ec,pri[N],mu[N],f[N],P[N],Q[N],R[N],ans[5];
bool notp[N],vis[N];vector<int> in[N],d[N];int val[N],A[5];
unordered_mapint> m;int cnt,du[N],lis[N],b[N],bel[N];
inline int inv(int x,int k=mod-2,int ans=1)
{   for(;k;k>>=1,x=(lint)x*x%mod) (k&1)?ans=(lint)ans*x%mod:0;return ans;   }
inline int add_edge(int x,int y,int z)
{   return ec++,u[ec]=x,v[ec]=y,x[du]++,y[du]++,in[y].pb(x),m.insert(make_pair(hv(x,y),z)),0;   }
inline int sol(int x,int s) { return (!s)?0:((s>0)?x:(mod-x)%mod); }
inline int F(int x,int y)
{
    int bx=bel[x],by=bel[y];if(bx>by) swap(bx,by);
    return f[A[(!bx)?(by==1?0:2):1]/m[hv(x,y)]];
}

inline int gcd(int x,int y) { return x?gcd(y%x,x):y; }
inline lint lcm(int x,int y) { return (lint)x*y/gcd(x,y); }
inline int mol(lint x) { return x%=mod,(x<0)?x+mod:x; }
lint p[5];

int main()
{
    mu[1]=1;int n=100000;
    for(int i=2;i<=n;i++)
    {
        if(!notp[i]) pri[++cnt]=i,mu[i]=-1;
        for(int j=1;j<=cnt;j++)
        {
            if((lint)pri[j]*i>n) break;
            int x=pri[j]*i;notp[x]=true;
            if(i%pri[j]) mu[x]=-mu[i];
            else { mu[x]=0;break; }
        }
    }
    for(int i=1;i<=n;i++)
        for(int j=i;j<=n;j+=i) d[j].push_back(i);
    for(int i=1;i<=n;i++)
        for(int s=1,t;s<=i;s=t+1)
            t=i/(i/s),(f[i]+=(t-s+1ll)*(i/s)%mod)%=mod;
//  for(int i=1;i<=30;i++) debug(i)sp,debug(mu[i])sp,debug(f[i])ln;
    for(int T=inn();T;T--)
    {
        int nc=0,t=0;ans[0]=ans[1]=ans[2]=ans[3]=0;
        A[0]=inn(),A[1]=inn(),A[2]=inn(),sort(A,A+3),swap(A[0],A[2]);

/*      int Ans=0;
        for(int i=1;i<=A[0];i++)
            for(int j=1;j<=A[0];j++)
                for(int k=1;k<=A[0];k++)
                    (Ans+=sol((lint)f[A[0]/lcm(i,j)]*f[A[1]/lcm(i,k)]%mod*f[A[2]/lcm(j,k)]%mod,mu[i]*mu[j]*mu[k]))%=mod;
        debug(Ans)ln;*/

        for(int i=1;i<=A[0];i++) if(mu[i]) val[P[i]=++nc]=i,bel[nc]=0;
        for(int i=1;i<=A[0];i++) if(mu[i]) val[Q[i]=++nc]=i,bel[nc]=1;
        for(int i=1;i<=A[0];i++) if(mu[i]) val[R[i]=++nc]=i,bel[nc]=2;
        memset(du,0,sizeof(int)*(nc+1)),m.clear(),ec=0;
        for(int i=1;i<=nc;i++) in[i].clear();
        for(int i=1;i<=nc/3;i++)
        {
            int x=val[i];t=0;
            for(int j=(int)d[x].size()-1;j>=0;j--)
                for(int k=d[x][j];k<=(lint)A[0]*d[x][j]/x;k+=d[x][j])
                    if(mu[k]&&!vis[k])
                    {
                        int w=(int)((lint)k*x/d[x][j]);vis[lis[++t]=k]=true;
                        add_edge(P[x],Q[k],w),add_edge(Q[x],R[k],w),add_edge(R[x],P[k],w);
                    }
            for(int j=1;j<=t;j++) vis[lis[j]]=false;
        }
        int s=max(1,(int)sqrt(ec/3+0.5)),bc=0;
        for(int i=1;i<=nc;i++) if(i[du]>s) b[++bc]=i;
        for(int i=1;i<=bc;i++) if(bel[b[i]]==0)
            for(int j=1;j<=bc;j++) if(bel[b[j]]==1)
                if(m.count(hv(b[i],b[j])))
                {
                    lint p3B=0ll;int x=b[i][val],y=b[j][val];
                    for(int k=1;k<=bc;k++)
                        if(m.count(hv(b[j],b[k]))&&m.count(hv(b[k],b[i]))) if(bel[b[k]]==2)
                            p3B+=sol((lint)F(b[j],b[k])*F(b[k],b[i])%mod,mu[x]*mu[y]*mu[b[k][val]]);
                    (ans[0]+=p3B%mod*F(b[i],b[j])%mod)%=mod;
                }
        for(int i=1;i<=ec;i++) if(u[i][du]<=s)
        {
            p[1]=p[2]=p[3]=0ll;
            for(int j=0,x=u[i],y=v[i],z;j<(int)in[x].size();j++)
            {
                if(!m.count(hv(y,z=in[x][j]))) continue;
                int sgn=mu[x[val]]*mu[y[val]]*mu[z[val]];
                int ds=(x[du]<=s)+(y[du]<=s)+(z[du]<=s);
                p[ds]=mol(p[ds]+sgn*(lint)F(y,z)*F(z,x));
            }
            int w=F(u[i],v[i]);
            ans[1]=mol(ans[1]+p[1]%mod*w);
            ans[2]=mol(ans[2]+p[2]%mod*w);
            ans[3]=mol(ans[3]+p[3]%mod*w);
        }
        printf("%lld\n",((lint)ans[3]*inv(3)+(lint)ans[2]*inv(2)+ans[1]+ans[0])%mod);
    }
    return 0;
}

你可能感兴趣的:(SDOI,数论,BZOJ)