题意:f(i)=i的幂次之和。 求(N+1-i)*f(i)之和。
思路:可以推论得对于一个素数p^k,其贡献是ans=(N+1)[N/(P^k)]+P^k(1+2+3...N/(P^k));
我们分两部分统计答案即可,在p<=sqrt(N)时,可以暴力(阶乘那样一直除)统计答案。 p>sqrt(N)时,我们可以利用min25的消息得到。
因为p>sqrt(N),这个时候k=1,所以贡献为(N+1)*(N/p)+p*(1+2+...N/p);我们把N/p相同的拉出来即可,而这个东西正好就是min25的基本操作。
N/p有根号级别个,我们把素数个数前缀和保存到h[]里,素数和的前缀和保存到g[]里,就不难得到区间素数个数,以及区间素数和。
wa点:写出来容易爆炸long long。
前缀和公式:x*(x+1)/2;由于x*(x+1)可能爆炸ll,所以我们得到x=x%Mod再进行计算,这个时候就不能/2了,需要*2的逆元。
ans:ans没必要每次维护到[0,Mod)这个范围内,我们最后一次性处理就好了,毕竟每次加减的都是[0,Mod)这个范围的数,最后结果一定再long long范围内。这样会
快一些。
#include#define ll long long #define rep(i,a,b) for(int i=a;i<=b;i++) using namespace std; const int maxn=1000010; const int Mod=998244353,inv2=(Mod+1)/2; int qpow(int a,int x){ int res=1; while(x){ if(x&1) res=1LL*res*a%Mod; x>>=1; a=1LL*a*a%Mod; } return res; } struct min25 //很多部分需要long,不要搞错了 { ll p[maxn],sp1[maxn],N; ll g1[maxn],h[maxn];int Sqr,ind1[maxn],ind2[maxn],num,tot; ll w[maxn]; bool vis[maxn]; int MOD(int x){ if(x>=Mod) x-=Mod;return x; } void prime() //得到素数,sp1,sp2 { rep(i,2,Sqr){ if(!vis[i]){ p[++num]=i; sp1[num]=MOD(sp1[num-1]+i); } for(int j=1;j<=num&&p[j]*i<=Sqr;j++){ vis[p[j]*i]=1; if(i%p[j]==0) break; } } } void getind() { for(ll i=1;i<=N;i++){ ll now=N/i,j=N/now,t=now%Mod; w[++tot]=now; g1[tot]=MOD(t*(t+1)/2%Mod+Mod-1); //因为我们全部都不考虑1。 h[tot]=t-1; if(g1[tot]<0) g1[tot]+=Mod; if(h[tot]<0) h[tot]+=Mod; if(now<=Sqr) ind1[now]=tot; else ind2[j]=tot; i=j; } } void getg() { rep(i,1,num){ //注意w里面的东西是递减的,所以可以滚动 for(int j=1;j<=tot&&p[i]<=w[j]/p[i];j++){ ll now=w[j]/p[i]; int k=now<=Sqr?ind1[now]:ind2[N/now]; g1[j]=MOD(g1[j]-1LL*p[i]*(g1[k]-sp1[i-1]+Mod)%Mod+Mod); h[j]=(h[j]-(h[k]-i+1+Mod)%Mod+Mod)%Mod; } } } ll get(ll x) { return x*(x+1)%Mod*inv2%Mod; }//x需要实现%Mod,不然爆ll,所以也不能用/2,而用逆元。 void solve(ll n) { N=n; Sqr=sqrt(N); prime(); //筛根号部分素数。 ll ans=0; rep(i,1,num){ for(ll e=p[i];e<=N;e*=p[i]){ ans+=(N+1)%Mod*(N/e)%Mod-e%Mod*get(N/e%Mod)%Mod; } } getind(); getg(); rep(i,1,Sqr-1){ ans+=(N+1)%Mod*i*((h[i]-h[i+1]+Mod)%Mod)%Mod; //多少个素数满足N/p=i ans-=1LL*get(i)*((g1[i]-g1[i+1]+Mod)%Mod)%Mod; //素数之和 } ans%=Mod; if(ans<0) ans+=Mod; printf("%lld\n",ans%Mod); } }T; int main() { ll N; scanf("%lld",&N); T.solve(N); return 0; }
,很多地方容易爆long long。