递推化简+线段树区间维护,P6477 [NOI Online #2 提高组] 子序列问题

一、题目

1.1题目背景

2s 512M

1.2题目描述

给定一个长度为 n n n 的正整数序列 A 1 A_1 A1, A 2 A_2 A2, ⋯ \cdots , A n A_n An。定义一个函数 f ( l , r ) f(l,r) f(l,r) 表示:序列中下标在 [ l , r ] [l,r] [l,r] 范围内的子区间中,不同的整数个数。换句话说, f ( l , r ) f(l,r) f(l,r) 就是集合 { A l , A l + 1 , ⋯   , A r } \{A_l,A_{l+1},\cdots,A_r\} {Al,Al+1,,Ar} 的大小,这里的集合是不可重集,即集合中的元素互不相等。

现在,请你求出 ∑ l = 1 n ∑ r = l n ( f ( l , r ) ) 2 \sum_{l=1}^n\sum_{r=l}^n (f(l,r))^2 l=1nr=ln(f(l,r))2。由于答案可能很大,请输出答案对 1 0 9 + 7 10^9 +7 109+7 取模的结果。

1.3输入格式

第一行一个正整数 n n n,表示序列的长度。

第二行 n n n 个正整数,相邻两个正整数用空格隔开,表示序列 A 1 A_1 A1, A 2 A_2 A2, ⋯ \cdots , A n A_n An

1.4输出格式

仅一行一个非负整数,表示答案对 1 0 9 + 7 10^9+7 109+7 取模的结果。

1.5样例 #1

样例输入 #1

4
2 1 3 2

样例输出 #1

43

样例 #2

样例输入 #2

3
1 1 1

样例输出 #2

6

1.6提示

对于 10 % 10\% 10% 的数据,满足 1 ≤ n ≤ 10 1 \leq n \leq 10 1n10

对于 30 % 30\% 30% 的数据,满足 1 ≤ n ≤ 100 1 \leq n \leq 100 1n100

对于 50 % 50\% 50% 的数据,满足 1 ≤ n ≤ 1 0 3 1\leq n \leq 10^3 1n103

对于 70 % 70\% 70% 的数据,满足 1 ≤ n ≤ 1 0 5 1 \leq n \leq 10^5 1n105

对于 100 % 100\% 100% 的数据,满足 1 ≤ n ≤ 1 0 6 1\leq n\leq 10^6 1n106,集合中每个数的范围是 [ 1 , 1 0 9 ] [1,10^9] [1,109]

1.7原题链接

https://www.luogu.com.cn/problem/P6477


二、解题报告

1、思路分析

1e6数据量,必须想出O(N)或者O(NlogN)的解法,不然肯定过不了

我们发现a[r]对于f(l , r)的影响为:即a[r]上一次出现位置为last[a[r]],那么f(i , r)都+1,i >= last[a[r]] + 1,其它f值都不变

这样我们似乎可以得出某种递推关系,我们令g® = Σf(l , r)^2

那么g® - g(r - 1) = Σf(l , r) ^ 2 - f(l , r - 1) ^ 2,其中last[a[r]] + 1 <= l <= r

进一步化简:g(r ) - g(r - 1) = Σf(l , r) * 2 + r - last[a[r]]

这样一来,我们就把平方和转化为线性和

我们只需要用线段树存储f(l , r),其中l >= 1,即可,然后每求一次g®都对[last[a[r]] + 1 , r]区间+1

每次求g®只需要知道g(r - 1)和线段树对应区间和,这一步是O(logn),枚举右端点是O(n),整体O(nlogn)

2、复杂度

时间复杂度:O(nlogn) 空间复杂度:O(n)

3、代码详解

#include 
#include 
#include 
#include 
using namespace std;
#define lc p << 1
#define rc p << 1 | 1
#define int long long
const int N = 1e6 + 5, MOD = 1e9 + 7;
int a[N], b[N], last[N], d[N];
struct Node
{
    int l, r, s, add;
} tr[N << 2];

void build(int p, int l, int r)
{
    tr[p] = {l, r, 0, 0};
    if (l == r)
        return;
    int mid = (l + r) >> 1;
    build(lc, l, mid), build(rc, mid + 1, r);
}

void pushup(int p)
{
    tr[p].s = tr[lc].s + tr[rc].s;
}

void pushdown(int p)
{
    if (tr[p].add)
    {
        tr[lc].s += (tr[lc].r - tr[lc].l + 1) * tr[p].add, tr[rc].s += (tr[rc].r - tr[rc].l + 1) * tr[p].add;
        tr[lc].add += tr[p].add, tr[rc].add += tr[p].add;
        tr[p].add = 0;
    }
}

void update(int p, int l, int r, int k)
{
    if (l <= tr[p].l && tr[p].r <= r)
    {
        tr[p].s += (tr[p].r - tr[p].l + 1) * k;
        tr[p].add += k;
        return;
    }
    int mid = (tr[p].l + tr[p].r) >> 1;
    pushdown(p);
    if (l <= mid)
        update(lc, l, r, k);
    if (r > mid)
        update(rc, l, r, k);
    pushup(p);
}

int query(int p, int l, int r)
{
    if (l <= tr[p].l && r >= tr[p].r)
        return tr[p].s;
    int mid = (tr[p].l + tr[p].r) >> 1, ret = 0;
    pushdown(p);
    if (l <= mid)
        ret += query(lc, l, r);
    if (r > mid)
        ret += query(rc, l, r);
    return ret;
}

int read()
{
    int s = 0, w = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9')
        w *= (ch == '-' ? -1 : 1), ch = getchar();
    while (ch >= '0' && ch <= '9')
        s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
    return s * w;
}

void write(int x)
{
    if (x < 0)
        putchar('-');
    if (x > 9)
        write(x / 10);
    putchar((x % 10) ^ 48);
}

signed main()
{
    // freopen("in.txt", "r", stdin);
    int n = read();
    for (int i = 1; i <= n; i++)
        a[i] = read(), b[i] = a[i];
    sort(b + 1, b + n + 1);
    int m = unique(b + 1, b + n + 1) - b - 1;
    for (int i = 1, k; i <= n; i++)
    {
        k = lower_bound(b + 1, b + m + 1, a[i]) - b;
        d[i] = last[k];
        last[k] = i;
    }
    int ans = 0;
    build(1, 1, n);
    for (int i = 1, cur = 0; i <= n; i++)
    {
        cur += i - d[i] + (query(1, d[i] + 1, i) << 1);
        cur %= MOD;
        ans = (ans + cur) % MOD;
        update(1, d[i] + 1, i, 1);
    }
    write(ans);
    return 0;
}

你可能感兴趣的:(OJ刷题解题报告,算法,c++,数据结构,线段树)