PE 439 Sum of sum of divisors | 51nod 1220 约数之和

题目:https://projecteuler.net/problem=439以及http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1220

题意:
d(x) 表示 x 的所有因子之和,求 Ni=1Nj=1d(ij)mod109+7
N109

题解:
[x] 为布尔表达式 x 的值,即若 x 为真则 [x]=1 ,否则 [x]=0
那么有 d(x)=xi=1[i|x]i
于是问题可以写成

i=1Nj=1Nd(ij)=i=1Nj=1Nd=1ij[d|ij]d

考虑如何简化掉 j ,我们可以计算每个 d 对应多少个 j 。因为 d|ij 等价于 dgcd(i,d)|j ,所以上式可以化为
=i=1Nd=1iNNgcd(i,d)dd

再考虑如何将 gcd(i,d) 化为互质的情况,令 d=gcd(i,d) ,则有
=d=1Ni=1Nd=1iN[gcd(i,d)=d]Ngcd(i,d)dd

由于 id dd 是整数,我们将上式的 i d idd idd 替换
=d=1Nid=1Nddd=1idN[gcd(id,dd)=1]Nddddd

再进行一个下标变换,将 id dd i d 替换
=d=1Ni=1Ndd=1iN[gcd(i,d)=1]Nddd=d=1Ndi=1Ndd=1N[gcd(i,d)=1]Ndd

可以看到 d d 之间没有约束,而且 i d 的关系比 d 的关系简单,不妨将 d i 交换求和顺序,考虑每个 i 对应哪些 d ,尝试将 d 约去
=i=1Nd=1Nidd=1N[gcd(i,d)=1]Ndd=i=1NNi(Ni+1)2d=1N[gcd(i,d)=1]Ndd

而对于处理 [gcd(i,d)=1] ,想必读到这里的你已经想出了处理它的方法,利用莫比乌斯函数的经典定理 Ni=1[i|N]μ(i)=[N=1] 可以得到
=i=1NNi(Ni+1)2d=1NNddd=1N[d|gcd(i,d)]μ(d)=i=1NNi(Ni+1)2d=1NNddd=1N[d|i][d|d]μ(d)

细心的你一定发现了 i d 互相无关,它们分别只和 d 有关,于是我们交换求和顺序,算每个 d 对应的 i d ,则有
=i=1NNi(Ni+1)2d=1NNddd=1N[d|i][d|d]μ(d)=d=1Nμ(d)id=1NdNidd(Nidd+1)2dd=1NdNdddddd=d=1Ndμ(d)i=1NdNid(Nid+1)2d=1NdNddd

到这里,我们已经得到了一个不错的式子,可以直接进行计算了,无需继续化简。(再化简下去 LATEX 都要炸啦~)
f(x)=xμ(x) g(x)=xi=1xi(xi+1)2 h(x)=xi=1xii ,则所求为 Ni=1f(i)g(Ni)h(Ni)
这个式子里用到的 g(x) h(x) O(N) 个,每个的计算复杂度为 O(x) ,所以把所有所需的 g h 算出来的复杂度为 O(Ni=1Ni)=O(N34) ,但是达不到这个上界。(实际上很容易改成 O(N23) 的,并且 g(x)=h(x)
如果可以快速计算出一段 f(x) 的值,就可以结合上面的做法解决这个问题,注意到 f(x) 是两个积性函数的乘积,所以 f(x) 也是积性函数,我们可以得到类似梅滕斯函数(莫比乌斯函数的前缀和函数)的计算方法,令 S(x)=xi=1f(i) ,则有
S(x)=1i=2xd=1i1[d|i]iμ(d)=1i=2xd=1xiidμ(d)=1i=2xiS(xi)

于是我们可以利用线性筛预处理出前 T 个值,那么计算复杂度变为 O(NTi=1Ni)=O(NT) ,当 T N23 的时候可以做到 O(N23)
于是总体复杂度为 O(N34)

代码:

#include 
#include 
#include 
using namespace std;
typedef long long LL;
const int maxn = 1000010, mod = 1000000007, inv2 = 500000004;
mapint> Hash;
int n, sqn, tot, prime[maxn], f[maxn], ans;
bool vis[maxn];
inline void inc(int &x, int y)
{
    x += y;
    if(x >= mod)
        x -= mod;
}
inline void dec(int &x, int y)
{
    x -= y;
    if(x < 0)
        x += mod;
}
inline int num1(int x)
{
    return x * (x + 1LL) % mod * inv2 % mod;
}
inline int num1(int L, int R)
{
    int ret = num1(R);
    dec(ret, num1(L - 1));
    return ret;
}
int calc_imu(int x)
{
    if(x <= sqn)
        return f[x];
    if(Hash.count(x))
        return Hash[x];
    int ret = 1;
    for(int i = 2, j; i <= x; i = j + 1)
    {
        j = x / (x / i);
        dec(ret, (LL)num1(i, j) * calc_imu(x / i) % mod);
    }
    return Hash[x] = ret;
}
inline int calc_imu(int L, int R)
{
    int ret = calc_imu(R);
    dec(ret, calc_imu(L - 1));
    return ret;
}
int calc_g(int n)
{
    int ret = 0;
    for(int i = 1, j; i <= n; i = j + 1)
    {
        j = n / (n / i);
        inc(ret, (j - i + 1LL) * num1(n / i) % mod);
    }
    return ret;
}
int calc_h(int n)
{
    int ret = 0;
    for(int i = 1, j; i <= n; i = j + 1)
    {
        j = n / (n / i);
        inc(ret, (LL)(n / i) * num1(i, j) % mod);
    }
    return ret;
}
int main()
{
    scanf("%d", &n);
    sqn = (int)ceil(pow(n, 2.0 / 3));
    f[1] = 1;
    for(int i = 2; i <= sqn; ++i)
    {
        if(!vis[i])
        {
            prime[tot++] = i;
            dec(f[i], 1);
        }
        for(int j = 0, k = sqn / i, o; j < tot && prime[j] <= k; ++j)
        {
            vis[o = i * prime[j]] = 1;
            if(i % prime[j] == 0)
            {
                f[o] = 0;
                break;
            }
            else
                dec(f[o], f[i]);
        }
        f[i] = (f[i - 1] + (LL)i * f[i]) % mod;
    }
    for(int i = 1, j; i <= n; i = j + 1)
    {
        j = n / (n / i);
        inc(ans, (LL)calc_imu(i, j) * calc_g(n / i) % mod * calc_g(n / i) % mod);
    }
    printf("%d\n", ans);
    return 0;
}

你可能感兴趣的:(51nod)