https://www.luogu.com.cn/problem/P2834
求
$$\sum_{i=1}^{n}\sum_{j=1}^{m}(n \bmod i)(m \bmod j)[i \neq j]$$
数据范围 $n,m\leq 10^9$
不妨令 $n\leq m$(否则交换 $n$ 与 $m$)。把 $[i\neq j]$ 拆成“所有数对”减去“$i=j$ 的数对”;由于 $i=j$ 时必有 $i\leq \min(n,m)=n$,原式等于
$$\sum_{i=1}^{n}\sum_{j=1}^{m}(n \bmod i)(m \bmod j)-\sum_{i=1}^{n}(n \bmod i)(m \bmod i)$$
利用 $a \bmod b = a - b\left\lfloor\frac{a}{b}\right\rfloor$ 转换一下 $\bmod$
$$\sum_{i=1}^{n}\sum_{j=1}^{m}\left(n-i\left\lfloor\frac{n}{i}\right\rfloor\right)\left(m-j\left\lfloor\frac{m}{j}\right\rfloor\right)-\sum_{i=1}^{n}\left(n-i\left\lfloor\frac{n}{i}\right\rfloor\right)\left(m-i\left\lfloor\frac{m}{i}\right\rfloor\right)$$
注意第一项中 $i,j$ 是不相关的,可以分开计算,再相乘
$$\sum_{i=1}^{n}\left(n-i\left\lfloor\frac{n}{i}\right\rfloor\right)\times\sum_{i=1}^{m}\left(m-i\left\lfloor\frac{m}{i}\right\rfloor\right)-\sum_{i=1}^{n}\left(nm-mi\left\lfloor\frac{n}{i}\right\rfloor-ni\left\lfloor\frac{m}{i}\right\rfloor+i^{2}\left\lfloor\frac{m}{i}\right\rfloor\left\lfloor\frac{n}{i}\right\rfloor\right)$$
然后有 $\sum_{i=1}^{n} i^2=\frac{n(n+1)(2n+1)}{6}$
整除分块就做完了,复杂度 $O(\sqrt{m})$
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>
// 64-bit signed integer used for every intermediate value in this file.
// A typedef (rather than a macro) so the name obeys normal scoping rules.
typedef long long LL;

// Prime modulus for all results; being prime is what lets main() obtain
// modular inverses as ksm(a, mod - 2) (Fermat's little theorem).
// A typed constant instead of a macro, so `mod` is a real object and is not
// textually substituted into unrelated identifiers.
const LL mod = 1000000007;

using namespace std;

// Program-wide state, all zero-initialised and filled in by main().
LL n;     // first input value; main() swaps so that n <= m
LL m;     // second input value
LL inv2;  // modular inverse of 2, set to ksm(2, mod - 2)
LL inv6;  // modular inverse of 6, set to ksm(6, mod - 2)
LL S1;    // sum over i in [1, n] of (n mod i), reduced modulo mod
LL S2;    // sum over i in [1, m] of (m mod i), reduced modulo mod
LL Ans;   // final answer, kept in [0, mod)
/**
 * Modular exponentiation by repeated squaring: returns x^y modulo `mod`.
 *
 * @param x  base; any 64-bit value. It is first reduced into [0, mod) so the
 *           squaring step cannot overflow a signed 64-bit integer.
 * @param y  exponent; expected to be non-negative. For y <= 0 the loop does
 *           not run and 1 is returned.
 * @return   x^y mod `mod`, in the range [0, mod).
 */
inline LL ksm(LL x,LL y)
{
    LL res = 1;

    // Normalise the base: an unreduced base >= ~3.04e9 would overflow on the
    // first x * x, and a negative base would leak negative remainders.
    x %= mod;
    if (x < 0) x += mod;

    // Walk the exponent's bits from least to most significant.
    // `y > 0` (rather than `y != 0`) guarantees termination: a negative y
    // shifted right arithmetically would stay at -1 forever.
    while (y > 0)
    {
        if (y & 1) res = res * x % mod;  // this bit contributes x^(2^k)
        x = x * x % mod;                 // advance to x^(2^(k+1))
        y >>= 1;
    }
    return res;
}
/**
 * Reads the next decimal integer from standard input.
 *
 * Every character before the first digit is skipped; if any skipped character
 * is '-', the result is negated. Digits are then consumed up to (and
 * including) the first non-digit character that follows them.
 *
 * @return the parsed value, or 0 if end-of-input is reached before any digit.
 */
inline LL read()
{
    // getchar() returns int so that EOF is distinguishable from every real
    // byte; storing it in a char loses that and can hand isdigit() a negative
    // value, which is undefined behaviour.
    int c = getchar();
    LL sign = 1;

    // Skip the non-digit prefix, remembering whether a minus sign was seen.
    // Stopping on EOF prevents spinning forever on truncated input.
    while (c != EOF && !isdigit(c))
    {
        if (c == '-') sign = -1;
        c = getchar();
    }

    // Accumulate the digit run in base 10.
    LL value = 0;
    while (c != EOF && isdigit(c))
    {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return sign * value;
}
// Returns 1 + 2 + ... + x = x(x+1)/2 reduced modulo `mod`.
// The division by 2 is carried out as a multiplication by the global inv2,
// which main() sets to ksm(2, mod - 2) before any call.
inline LL sum1(LL x)
{
    const LL twiceTheSum = x * (x + 1) % mod;
    return twiceTheSum * inv2 % mod;
}
// Returns 1^2 + 2^2 + ... + x^2 = x(x+1)(2x+1)/6 reduced modulo `mod`.
// The division by 6 is carried out as a multiplication by the global inv6,
// which main() sets to ksm(6, mod - 2) before any call.
inline LL sum2(LL x)
{
    const LL consecutive = x * (x + 1) % mod;
    const LL sixTimesTheSum = consecutive * (2 * x + 1) % mod;
    return sixTimesTheSum * inv6 % mod;
}
signed main()
{
n=read();m=read();
if(n>m) n^=m,m=n^m,n^=m;
inv2=ksm(2,mod-2);inv6=ksm(6,mod-2);
S1=n*n%mod;
for(register int l=1,r;l<=n;l=r+1)
{
r=n/(n/l);
S1-=(sum1(r)-sum1(l-1)+mod)%mod*(n/l)%mod;
S1=(S1+mod)%mod;
}
S2=m*m%mod;S2=(S2+mod)%mod;
for(register int l=1,r;l<=m;l=r+1)
{
r=m/(m/l);
S2-=(sum1(r)-sum1(l-1)+mod)%mod*(m/l)%mod;
S2=(S2+mod)%mod;
}
Ans=S1*S2%mod;
Ans-=n*n%mod*m%mod;Ans=(Ans+mod)%mod;
for(register int l=1,r;l<=n;l=r+1)
{
r=min(n/(n/l),m/(m/l));
Ans+=(sum1(r)-sum1(l-1)+mod)%mod*(m*(n/l)%mod+n*(m/l)%mod)%mod;Ans=(Ans+mod)%mod;
Ans-=(sum2(r)-sum2(l-1)+mod)%mod*(m/l)%mod*(n/l)%mod;Ans=(Ans+mod)%mod;
}
printf("%lld",Ans);
}