2016多校联合第5场部分题解 HDU5780,5781,5783,5784,5785,5787,5791,5792

5780

gcd(xa1,xb1) 这个式子可以把里面拆开成多项式,变式为 (x1)gcd(1+x+...+xa1,1+x+...+xb1)

然后由于 gcd(1+x+...+xa1,1+x+...+xb1)=(1+x+...+xgcd(a,b))

然后再带入化简求和在化简得到 gcd(xa1,xb1)=xgcd(a,b)1

对于求 gcd=d 的对数,可以考虑欧拉函数,然后对数就是 2(phi[1]+...+phi[nd])1 。预处理出来这些东西,直接枚举 d 会超时。考虑 n/d 有很多共同的变量,那么我们只需要枚举这 sqrt(n) 个不同的变量,然后x求个等比公式就可以在sqrt(n)的时间内搞定了

//
//  Created by Running Photon
//  Copyright (c) 2015 Running Photon. All rights reserved.
//
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x,begin())
#define ll long long
#define CLR(x) memset(x, 0, sizeof x)
using namespace std;
const int inf = 0x3f3f3f3f;
const ll MOD = 1e9 + 7;
const int maxn = 1e6 + 50;
const int maxv = 1e3 + 10;
const double eps = 1e-9;

ll phi[maxn];
void exgcd(ll a, ll b, ll &d, ll &x, ll &y) {
    if(!b) {d = a; x = 1; y = 0;}
    else {exgcd(b, a%b, d, y, x); y -= x * (a / b);}
}
ll inv(ll a, ll n) {
    ll d, x, y;
    exgcd(a, n, d, x, y);
    return d == 1 ? (x + n) % n : -1;
}

void init() {
    phi[1] = 1;
    phi[0] = 0;
    for(int i = 2; i <= 1e6 + 5; i++) {
        if(!phi[i]) {
            for(int j = i; j <= 1e6 + 5; j += i) {
                if(!phi[j]) phi[j] = j;
                phi[j] = phi[j] / i * (i - 1);
            }
        }
    }
    for(int i = 1; i <= 1e6 + 5; i++) {
        phi[i] += phi[i-1];
    }
    for(int i = 1; i <= 1e6; i++) {
        phi[i] = (2 * phi[i] - 1 + MOD) % MOD;
    }
}
ll Pow(ll a, ll n) {
    ll ret = 1;
    while(n) {
        if(n & 1) ret = ret * a % MOD;
        a = a * a % MOD;
        n >>= 1;
    }
    return ret;
}
ll calc(ll a, ll b, ll x) {
    ll ret = 0;
    ret = (Pow(x, b+1) - Pow(x, a) + MOD) % MOD;
    ret = ret * inv(x - 1, MOD) % MOD;
    // printf("ret = %lld\n", ret);
    return ret;
}
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
//  ios_base::sync_with_stdio(0);
    init();
    int T;
    scanf("%d", &T);
    while(T--) {
        ll x, n;
        scanf("%lld%lld", &x, &n);
        ll res = MOD - n * n % MOD;
        if(x == 1) {puts("0"); continue;}
        for(int i = 1, la = 0; i <= n; i = la + 1) {
            la = n / (n / i);
            // printf("a = %d  b = %d\n", i, la);
            res = (res + calc(i, la, x) * phi[n/i] % MOD) % MOD;
        }
        printf("%lld\n", res);
    }
    return 0;
}

HDU5781

题目原型是一个选楼层扔鸡蛋找到敲好不碎的楼层。这里取钱和扔鸡蛋一样,被警告一次等同于碎一个蛋。设 f(i,j) 表示当前可能的钱为 (0,i) 还剩 j 次被警告的机会找到准确楼层的最小期望。

对于当前状态,显然可以考虑枚举当前我们取的钱,从 1 开始( 0 无意义)到 i ,然后根据题目条件,钱数是均匀分布的,那么对于当前决策有两种可能,被警告或者不被警告。若被警告,说明钱数小于 k ,这样的概率是 ki+1 ,若不被警告,说明钱数大于等于 k ,概率为 i+1ki+1

这样 dp[i][j]=min(ki+1dp[k1][j1]i+1ki+1dp[ik][j])+1

直接计算是 n2m 但是想想因为最优策略,至少是接近2分的策略,那么钦定 m15 即可

//
//  Created by Running Photon
//  Copyright (c) 2015 Running Photon. All rights reserved.
//
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x,begin())
#define ll long long
#define CLR(x) memset(x, 0, sizeof x)
using namespace std;
const double inf = 1000000000000.0;
const int MOD = 1e9 + 7;
const int maxn = 1e6 + 10;
const int maxv = 2e3 + 10;
const double eps = 1e-9;

double dp[maxv][20];
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
//  ios_base::sync_with_stdio(0);
    int k, w;
    k = 2000; w = 15;
    w = min(w, 15);
    for(int i = 0; i <= k; i++)
    for(int j = 0; j <= w; j++)
    dp[i][j] = inf;
    CLR(dp[0]);
    for(int i = 1; i <= k; i++) {
        for(int j = 1; j <= w; j++) {
            for(int k = 1; k <= i; k++) {
                dp[i][j] = min(dp[i][j], (i-k+1.0)/(i+1.0)*dp[i-k][j] + k/(i+1.0)*dp[k-1][j-1] + 1.0);
            }
        }
    }
    while(scanf("%d%d", &k, &w) != EOF) {
        w = min(w, 15);
        printf("%.6f\n", dp[k][w]);
    }   

    return 0;
}

HDU5783

因为要求分割区间的各前缀和大于等于0,那么考虑倒着来,先遇到的0和正数显然独立成为一个区间,当遇到某些连续的负数就去要用前面的正数和0去填补了。

//
//  Created by Running Photon
//  Copyright (c) 2015 Running Photon. All rights reserved.
//
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x,begin())
#define ll long long
#define CLR(x) memset(x, 0, sizeof x)
using namespace std;
const int inf = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
const int maxn = 1e6 + 10;
const int maxv = 1e3 + 10;
const double eps = 1e-9;

ll a[maxn];
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
//  ios_base::sync_with_stdio(0);
    int n;
    while(scanf("%d", &n) != EOF) {
        for(int i = 1; i <= n; i++) {
            scanf("%lld", a + i);
        }
        int ans = 0;
        ll sum = 0;
        for(int i = n; i > 0; i--) {
            sum += a[i];
            if(sum >= 0) {
                ans++;
                sum = 0;
            }
        }
        printf("%d\n", ans);
    }
    return 0;
}

HDU5784

要求点集中锐角三角形的个数,一个锐角三角形 3 锐角,一个直角和钝角都是 2 锐角,搞出所有锐角个数,然后减去直角和钝角的两倍就是锐角三角形的锐角个数,然后除以3就是锐角三角形个数。

为了减小时间开销,可以枚举一个点,然后以它为中心,其他点做极角排序,先扫出锐角的个数,在扫出小于 PI 的角个数,一减就是钝角+直角个数。因为要往后扩 PI ,可以把原数组复制一遍,注意精度, eps 至少 1012

//
//  Created by Running Photon
//  Copyright (c) 2015 Running Photon. All rights reserved.
//
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x,begin())
#define ll long long
#define CLR(x) memset(x, 0, sizeof x)
using namespace std;
const int inf = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
const int maxn = 1e6 + 10;
const int maxv = 2e3 + 10;
const double eps = 1e-12;
const double PI = acos(-1);

ll x[maxv], y[maxv];

int comp(double x) {
    if(abs(x) < eps) return 0;
    return x < 0 ? -1 : 1;
}
int m;
ll calc(std::vector<double> &a, double delta) {
    ll ret = 0;
    for(int i = 0, j = 0, k = 0; i < m; i = k + 1) {
        while(k + 1 < m && comp(a[k+1] - a[i]) == 0) ++k;
        j = max(j, k);
        while(j < a.size() && comp(a[j] - a[i] - delta) < 0) ++j;
        ret += (k - i + 1) * (j - k - 1);
    }
    return ret;
}
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
//  ios_base::sync_with_stdio(0);
    int n;
    // cout << atan2(1, 2e9) * 180 << endl;
    while(scanf("%d", &n) != EOF) {
        for(int i = 1; i <= n; i++) {
            scanf("%lld%lld", &x[i], &y[i]);
        }
        ll rui = 0, dun = 0;
        for(int T = 1; T <= n; T++) {
            std::vector<double> a;
            for(int j = 1; j <= n; j++) {
                if(T == j) continue;
                a.push_back(atan2(y[j]-y[T], x[j]-x[T]));
            }
            m = a.size();
            for(int i = 0; i < m; i++) {
                a.push_back(PI * 2 + a[i]);
            }
            sort(ALL(a));
            int tot = 0, curAcute = 0;
            curAcute = calc(a, PI / 2);
            tot = calc(a, PI);
            rui += curAcute;
            dun += tot - curAcute;
        }
        printf("%lld\n", (rui - dun * 2) / 3);
    }

    return 0;
}

HDU5785

此题介绍两种做法。

一种是题解说的,用 manacher 处理出各个最长的回文串的区间,就是下标算的蛋疼。在考虑一个区间对答案的影响的时候,设这个区间为 [l,mid,r] , mid 是区间的中心,可以发现在 mid l,r 两端的时候, l+ri 的值都是当前坐标 i 对应的回文串的端点。在 [l,mid] 就是右端点, [mid,r] 就是左端点,在处理右端点数组的时候,就在 l 位置加上 l+r mid+1 位置减去 l+r 。并且考虑当前点被区间覆盖的个数 num 也在相应的地方加1和减1,然后求个前缀和便能算出当前点右端点贡献。详细细节见代码

//
//  Created by Running Photon
//  Copyright (c) 2015 Running Photon. All rights reserved.
//
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x,begin())
#define ll long long
#define CLR(x) memset(x, 0, sizeof x)
using namespace std;
const int inf = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
const int maxn = 2e6 + 10;
const int maxv = 1e3 + 10;
const double eps = 1e-9;

char a[maxn], s[maxn*2];
int n, p[maxn*2];
void init() {
    s[0] = '@'; s[1] = '#', n = 2;
    for(int i = 0; a[i]; i++) {
        s[n++] = a[i];
        s[n++] = '#';
    }
    s[n] = 0;
}
void kp() {
    init();
    int mx = 0, id;
    for(int i = 1; i < n; i++) {
        if(mx > i) p[i] = min(mx - i, p[id * 2 - i]);
        else p[i] = 1;
        while(s[i - p[i]] == s[i + p[i]]) p[i]++;
        if(i + p[i] > mx) {
            mx = i + p[i];
            id = i;
        }
    }
}
ll L[maxn], R[maxn];
ll lsum[maxn], rsum[maxn];
ll lnum[maxn], rnum[maxn];
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
//  ios_base::sync_with_stdio(0);
    while(scanf("%s", a) != EOF) {
        kp(); n--;
        CLR(lnum); CLR(rnum);
        CLR(L); CLR(R); CLR(lsum); CLR(rsum);
        // for(int i = 1; i <= n; i++) {
        //  printf("%c", s[i]);
        // }
        // puts("");
        for(int i = 2; i < n; i++) {
            if(i & 1) {
                int cur = i / 2;
                int l = cur + 1 - (p[i] - 1) / 2;
                int r = cur + (p[i] - 1) / 2;
                if(r <= l) continue;
                int back = cur + 1;
                // printf("%d %d %d %d\n", l, cur, back, r);
                lsum[back] += l + r, lnum[back]++; 
                lsum[r+1] -= l + r, lnum[r+1]--;
                rsum[l] += l + r, rnum[l]++;
                rsum[cur+1] -= l + r, rnum[cur+1]--;
            }
            else {
                int cur = i / 2;
                int l = cur - (p[i] - 2) / 2;
                int r = cur + (p[i] - 2) / 2;
                // printf("%d %d %d\n", l, cur, r);
                lsum[cur] += l + r, lnum[cur]++;
                lsum[r+1] -= l + r, lnum[r+1]--;
                rsum[l] += l + r, rnum[l]++;
                rsum[cur+1] -= l + r, rnum[cur+1]--;
            }
        }
        ll res = 0;
        n /= 2;
        for(int i = 1; i <= n; i++) {
            lsum[i] += lsum[i-1];
            rsum[i] += rsum[i-1];
            lnum[i] += lnum[i-1];
            rnum[i] += rnum[i-1];
        }
        // for(int i = 1; i <= n; i++) {
        //  printf("lsum[%d] = %lld\n", i, lsum[i] - i);
        //  printf("rsum[%d] = %lld\n", i, rsum[i] - i);
        // }
        for(int i = 1; i <= n; i++) {
            L[i] = (lsum[i] - i * lnum[i] % MOD + MOD) % MOD;
            R[i] = (rsum[i] - i * rnum[i] % MOD + MOD) % MOD;
        }
        for(int i = 1; i < n; i++) {
            res = (res + L[i] * R[i+1] % MOD) % MOD;
        }
        printf("%lld\n", res);
    }

    return 0;
}

另一种是用回文树处理,方便了很多,但是注意空间限制。我们设当前处理的点是 i1 i ,那么对答案的贡献就是 ab(ia)(i+b1) 其中 a,b 分别是以 i1 结尾的各个回文串长度,以 i 开头的各个回文串长度。

cnta,suma,cntb,sumb 分别是以i结尾的回文串个数,长度和,以i开头的回文串个数,长度和。

拆开后为

cntacntbi2+i(sumbcntb)cntaicntbsumasumasumb

回文树可以统计出以i下标结尾的回文串个数,回文串长度和,顺着倒着分别搞一次, cnta,cntb,suma,sumb 都出来了带入计算即可。

//
//  Created by Running Photon
//  Copyright (c) 2015 Running Photon. All rights reserved.
//
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x,begin())
#define ll long long
#define CLR(x) memset(x, 0, sizeof x)
using namespace std;
const int inf = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
const int maxn = 1e6 + 10;
const int maxv = 18 + 10;
const double eps = 1e-9;

const int SIZE = 26;
struct PalindromicTree {
    //nxt指针,指向的串为当前串加上同一个字符构成
    int nxt[maxn][SIZE];
    //fail指针,失配后跳转到fail指向的节点,指向的节点为当前节点的最长公共回文后缀
    int fail[maxn];
    int cnt[maxn];
    int num[maxn];
    int len[maxn];
    char S[maxn];
    int last;
    int n;
    int p;
    int newNode(int l) {
        for(int i = 0; i < SIZE; i++) nxt[p][i] = 0;
        cnt[p] = num[p] = 0;
        len[p] = l;
        return p++;
    }
    void init() {
        p = 0;
        newNode(0);
        newNode(-1);
        last = 0;
        n = 0;
        S[n] = -1;
        fail[0] = 1;
    }
    int getFail(int x) {
        while(S[n - len[x] - 1] != S[n]) x = fail[x];
        return x;
    }
    pair  add(int c) {
        c -= 'a';
        S[++n] = c;
        int cur = getFail(last);
        if(!nxt[cur][c]) {
            int now = newNode(len[cur] + 2);
            fail[now] = nxt[getFail(fail[cur])][c];
            nxt[cur][c] = now;
            num[now] = num[fail[now]] + len[now];
            if(num[now] >= MOD) num[now] -= MOD;
            cnt[now] = cnt[fail[now]] + 1;
        }
        last = nxt[cur][c];
        return {cnt[last], num[last]};
    }
}pt;
char s[maxn];
int Lcnt[maxn], Lnum[maxn];
ll mul(ll a, ll b) {
    return a * b % MOD;
}
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    // freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
//  ios_base::sync_with_stdio(0);
    while(scanf("%s", s + 1) != EOF) {
        int n = strlen(s+1);
        pt.init();
        for(int i = 1; i <= n; i++) {
            auto ret = pt.add(s[i]);
            Lcnt[i] = ret.first;
            Lnum[i] = ret.second;
        }
        pt.init();
        ll ans = 0;
        for(int i = n; i > 1; i--) {
            auto ret = pt.add(s[i]);
            ll sufcnt = ret.first;
            ll sufnum = ret.second;
            ll L = mul(mul(i, i), mul(sufcnt, Lcnt[i-1]));
            ll M = mul(Lcnt[i-1], sufnum - sufcnt) - mul(sufcnt, Lnum[i-1]) + MOD;
            M %= MOD;
            M = mul(M, i);
            ll R = mul(sufnum - sufcnt, Lnum[i-1]);
            ans += L + M - R;
            ans = (ans + MOD) % MOD;
        }
        printf("%lld\n", ans);
    }

    return 0;
}

HDU5787

明显的数位 dp 。题解很机智,直接 map vector 方便了很多。因为 K 最大才是5,显然保存前 4 位,往下走的时候要看当前位能不能填数字 i 。这样边界问题比较蛋疼,前导零处理起来比较麻烦。于是我又开了一维表示当前需要的与之前比较的位数。最后为了只 memset 一次又开了一维表示当前处理的 K 是多少。。具体细节见代码

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
#define ls id<<1,l,mid
#define rs id<<1|1,mid+1,r
#define OFF(x) memset(x,-1,sizeof x)
#define CLR(x) memset(x,0,sizeof x)
#define MEM(x) memset(x,0x3f,sizeof x)
typedef long long ll;
typedef pair<int,int> pii;
const int maxn = 1e4 + 50;
const int maxm = 1e6 + 50;
const double eps = 1e-10;
const int max_index = 62;
const int inf = 0x3f3f3f3f ;
int MOD = 1e5;


ll dp[19][maxn][6][4];
int digit[20];
ll L, R, K;
int all;
ll dfs(int pos, int num, int bit, int limit) {
    if(pos == -1) return 1;
    if(!limit && dp[pos][num][bit][K-2] != -1) return dp[pos][num][bit][K-2];
    int up = limit ? digit[pos] : 9;
    ll ret = 0;
    int tmp = num;
    int viss[10];
    CLR(viss);
    for(int i = 0; i < bit; i++) {
        viss[tmp % 10]++;
        tmp /= 10;
    }
    for(int i = 0; i <= up; i++) {
        if(viss[i]) {
            continue;
        }
        int nbit = bit;
        if(bit == 0 && i == 0) nbit = 0;
        else nbit++;
        if(nbit >= K) nbit--;
        ret += dfs(pos-1, (num*10+i)%MOD, nbit, i == up && limit);
    }
    if(!limit) dp[pos][num][bit][K-2] = ret;
    return ret;
}
ll calc(ll x) {
    int pos = 0;
    while(x) {
        digit[pos++] = x % 10;
        x /= 10;
    }
    all = pos - 1;
    return dfs(pos-1, 0, 0, 1);
}
int main () {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
    // cout << sizeof(dp) / 1024 << endl;
    memset(dp, -1, sizeof dp);
    while(scanf("%lld%lld%lld", &L, &R, &K) != EOF) {
        MOD = 1;
        for(int i = 1; i < K; i++) MOD *= 10;
        printf("%lld\n", calc(R) - calc(L - 1));
    }

    return 0;
}

HDU5791

dp 水题, dp[i][j] 表示 s1 串前 i 个字符, s2 串前 j 个字符匹配的序列的总个数。

显然有 dp[i][j]=dp[i1][j]+dp[i][j1]dp[i1][j1] 因为此处 dp[i1][j1] dp[i1][j],dp[i][j1] 包含,多算了一次所以需要减去。然后若 s1[i]==s2[j] 说明可以以 i,j 为底构建新的序列,这时候 dp[i][j]+=dp[i1][j1]+1

//
//  Created by Running Photon
//  Copyright (c) 2015 Running Photon. All rights reserved.
//
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x,begin())
#define ll long long
#define CLR(x) memset(x, 0, sizeof x)
using namespace std;
const int inf = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
const int maxn = 1e6 + 10;
const int maxv = 1e3 + 10;
const double eps = 1e-9;

ll dp[maxv][maxv];
int a[maxv], b[maxv];
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
//  ios_base::sync_with_stdio(0);
    int n, m;
    while(scanf("%d%d", &n, &m) != EOF) {
        for(int i = 1; i <= n; i++) {
            scanf("%d", a + i);
        }
        for(int i = 1; i <= m; i++) {
            scanf("%d", b + i);
        }
        CLR(dp);
        for(int i = 1; i <= n; i++) {
            for(int j = 1; j <= m; j++) {
                dp[i][j] = dp[i-1][j] + dp[i][j-1] - dp[i-1][j-1];
                dp[i][j] = (dp[i][j] % MOD + MOD) % MOD;
                // printf("dp[%d][%d] = %lld\n", i, j, dp[i][j]);
                if(a[i] == b[j]) {
                    dp[i][j] = (dp[i][j] + dp[i-1][j-1] + 1) % MOD;
                }
                // printf("dp[%d][%d] = %lld\n", i, j, dp[i][j]);
            }
        }
        printf("%lld\n", dp[n][m]);
    }

    return 0;
}

HDU5792

要找出满足要求的四元组,可以先处理出所有顺序对个数,逆序对个数,然后去重。设 Aa,Ab,Ac,Ad 其中 a<b,c<d,Aa<Ab,Ac<Ad 重复就是当 Aa==Ac or Aa==Ad or Ab==Ac or Ab==Ad 四种情况。统计可以用 bit

//
//  Created by Running Photon
//  Copyright (c) 2015 Running Photon. All rights reserved.
//
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x,begin())
#define ll long long
#define CLR(x) memset(x, 0, sizeof x)
using namespace std;
const int inf = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
const int maxn = 1e5 + 10;
const int maxv = 1e3 + 10;
const double eps = 1e-9;


int lowbit(int x) {return x & (-x);}
int MAXN;
void updata(ll *a, int id, int val) {
    while(id <= MAXN) {
        a[id] += val;
        id += lowbit(id);
    }
}
ll query(ll *a, int id) {
    ll ret = 0;
    while(id) {
        ret += a[id];
        id -= lowbit(id);
    }
    return ret;
}
ll A[maxn];
ll dp[maxn];
ll dpv[maxn];
ll lbig[maxn], rbig[maxn], lsmall[maxn], rsmall[maxn];
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\Administrator\\Desktop\\in.txt", "r", stdin);
    freopen("C:\\Users\\Administrator\\Desktop\\out.txt","w",stdout);
#endif
//  ios_base::sync_with_stdio(0);
    int n;
    while(scanf("%d", &n) != EOF) {
        std::vector<int> xs;
        for(int i = 1; i <= n; i++) {
            scanf("%d", A + i);
            xs.push_back(A[i]);
        }
        sort(ALL(xs));
        xs.resize(unique(ALL(xs)) - xs.begin());
        MAXN = xs.size() + 5;
        for(int i = 0; i <= MAXN; i++) {
            dp[i] = dpv[i] = 0;
        }
        int SIZE = xs.size() + 2;
        ll small = 0;
        ll big = 0;
        CLR(lbig); CLR(rbig); CLR(lsmall); CLR(rsmall);
        for(int i = n; i > 0; i--) {
            int id = lower_bound(ALL(xs), A[i]) - xs.begin() + 1;
            ll L = query(dp, id - 1);
            ll R = query(dpv, SIZE - id - 1);
            updata(dp, id, 1);
            updata(dpv, SIZE - id, 1);
            rbig[i] = R;
            rsmall[i] = L;
            small += L;
            big += R;
        }
        // printf("small = %lld big = %lld\n", small, big);
        for(int i = 0; i <= MAXN; i++) {
            dp[i] = dpv[i] = 0;
        }
        for(int i = 1; i <= n; i++) {
            int id = lower_bound(ALL(xs), A[i]) - xs.begin() + 1;
            ll L = query(dp, id - 1);
            ll R = query(dpv, SIZE - id - 1);
            updata(dp, id, 1);
            updata(dpv, SIZE - id, 1);
            lbig[i] = R;
            lsmall[i] = L;
        }
        ll ans = small * big;
        for(int i = 1; i <= n; i++) {
            ans = ans-rbig[i]*rsmall[i]-lbig[i]*lsmall[i]-lsmall[i]*rsmall[i]-lbig[i]*rbig[i];
        }
        printf("%lld\n", ans);
    }

    return 0;
}

你可能感兴趣的:(杂题)