P3707 [SDOI2017] 相关分析 Solution

Description

给定序列 $x=(x_1,x_2,\cdots,x_n)$,$y=(y_1,y_2,\cdots,y_n)$,有 $m$ 个操作,分三种:

  1. $\operatorname{query}(l,r)$:求 $\dfrac{\sum_{i=l}^r (x_i-\bar x)(y_i-\bar y)}{\sum_{i=l}^r (x_i-\bar x)^2}$,其中 $\bar x$ 为 $x_{l\cdots r}$ 的平均值,$\bar y$ 为 $y_{l\cdots r}$ 的平均值,保证分母不为 $0$。
  2. $\operatorname{add}(l,r,s,t)$:对所有 $i \in [l,r]$ 执行 $x_i \leftarrow \textcolor{red}{x_i} + s,\; y_i \leftarrow \textcolor{red}{y_i} + t$。
  3. $\operatorname{modify}(l,r,s,t)$:对所有 $i \in [l,r]$ 执行 $x_i \leftarrow \textcolor{red}{i} + s,\; y_i \leftarrow \textcolor{red}{i} + t$。

Limitations

$1 \le n,m \le 10^5$
$|s|,|t| \le 10^9$
$0 \le |x_i|,|y_i| \le 10^5$
$1\text{s},\ 125\text{MB}$

Solution

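发现要求的式子难以直接维护,考虑化简。记 $k=r-l+1$,由 $\bar x=\frac{1}{k}\sum x_i$、$\bar y=\frac{1}{k}\sum y_i$ 展开得 $\sum (x_i-\bar x)(y_i-\bar y)=\sum x_iy_i-\bar y\sum x_i-\bar x\sum y_i+k\bar x\bar y=\sum x_iy_i-\frac{\sum x_i\sum y_i}{k}$,分母同理,于是所求为 $\dfrac{\sum x_iy_i-\frac{\sum x_i \sum y_i}{r-l+1}}{\sum x_i^2 - \frac{(\sum x_i)^2}{r-l+1}}$。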
现在只需维护 $\sum x_i,\;\sum y_i,\;\sum x_i^2,\;\sum x_iy_i$,可以上线段树。
考虑 pushdown 如何写,发现后两个不好搞,同样化简式子:

  1. $\sum (x_i+s)(y_i+t)=\sum x_iy_i+s\sum y_i+t\sum x_i+st(r-l+1)$
  2. $\sum (x_i+s)^2=\sum x_i^2+2s\sum x_i+s^2(r-l+1)$

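然后写的时候注意顺序!!上面两式用到的都是修改前的 $\sum x_i$ 与 $\sum y_i$,所以要先更新 $\sum x_iy_i$ 和 $\sum x_i^2$,再更新 $\sum x_i$ 和 $\sum y_i$。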
还需要一个 $x_i \leftarrow i,\ y_i \leftarrow i$ 的覆盖标记,写的时候需要用公式 $1^2+2^2+\cdots+n^2=\dfrac{n(n+1)(2n+1)}{6}$。modify 操作可以看成先覆盖为 $i$、再整体加 $(s,t)$;覆盖会清空已有的加法标记,所以下传时先处理覆盖标记,再下传加法标记。

剩下的不必多说,和普通的是一样的。
注意全部要开 double:经过多次 add 后 $|x_i|$ 可达 $10^{14}$ 量级,$\sum x_i^2$ 等量会爆 long long。

Code

$4.58\text{KB},\ 1.53\text{s},\ 26.24\text{MB}\;\texttt{(in total, C++ 20 with O2)}$

// Problem: P3707 [SDOI2017] 相关分析
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3707
// Memory Limit: 125 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <cstdio>
#include <iomanip>
#include <iostream>
#include <vector>
using namespace std;

// Integer shorthands.
using i64   = long long;
using i128  = __int128;            // GCC/Clang extension
using ui64  = unsigned long long;
using ui128 = unsigned __int128;   // GCC/Clang extension

// Floating-point shorthands; the suffix is the usual storage size in bytes
// (the size of long double is platform-dependent).
using f4  = float;
using f8  = double;
using f16 = long double;

/// Raises `a` to `b` when `b` is strictly larger (compared with operator<).
/// Returns true iff `a` was modified.
template<class T>
bool chmax(T &a, const T &b){
	const bool grows = a < b;
	if (grows) a = b;
	return grows;
}

/// Lowers `a` to `b` when `b` is strictly smaller (compared with operator>).
/// Returns true iff `a` was modified.
template<class T>
bool chmin(T &a, const T &b){
	const bool shrinks = a > b;
	if (shrinks) a = b;
	return shrinks;
}

/// Segment-tree node covering positions [l, r] (0-based).
///
/// sumX, sumY, sumXX, sumXY are the range sums of x_i, y_i, x_i^2 and x_i*y_i
/// and already include every tag recorded on this node. The tags describe
/// work still owed to the children, applied in this order:
///   cover      -> reset every x_i and y_i to i + 1 (the 1-based index);
///   tagX, tagY -> then add tagX to every x_i and tagY to every y_i.
///
/// Every member has a default initializer so a default-constructed Node is
/// fully initialized, never partially indeterminate.
struct Node {
    int l = 0, r = 0;
    f8 sumX = 0, sumY = 0, sumXX = 0, sumXY = 0, tagX = 0, tagY = 0;
    bool cover = false;
};

using Tree = vector<Node>;
// Child indices in a 0-rooted implicit binary tree: children of u are 2u+1, 2u+2.
int ls(int u) { return u + u + 1; }
int rs(int u) { return ls(u) + 1; }

/// Writes into `res` the sums of two adjacent ranges `le` and `ri`.
/// Only the four sum fields are written; l, r, the tags and cover of `res`
/// are left untouched.
void merge(Node& res, const Node& le, const Node& ri) {
    res.sumXY = le.sumXY + ri.sumXY;
    res.sumXX = le.sumXX + ri.sumXX;
    res.sumY = le.sumY + ri.sumY;
    res.sumX = le.sumX + ri.sumX;
}

/// Recomputes node u's range sums from its two children.
void pushup(Tree& tr, int u) {
    const Node& left = tr[ls(u)];
    const Node& right = tr[rs(u)];
    merge(tr[u], left, right);
}

void build(Tree& tr, int u, int l, int r, vector<f8>& X, vector<f8>& Y) {
    tr[u].l = l;
    tr[u].r = r;
    if (l == r) {
        tr[u].sumX = X[l];
        tr[u].sumY = Y[l];
        tr[u].sumXX = X[l] * X[l];
        tr[u].sumXY = X[l] * Y[l];
        return;
    }
    
    int mid = (l + r) >> 1;
    build(tr, ls(u), l, mid, X, Y);
    build(tr, rs(u), mid + 1, r, X, Y);
    pushup(tr, u);
}

/// Sum of squares 1^2 + 2^2 + ... + a^2 = a(a+1)(2a+1)/6, evaluated in
/// floating point (exact for integer a while the product stays below 2^53).
f8 sqsum(f8 a) {
    const f8 product = a * (a + 1) * (2 * a + 1);
    return product / 6;
}

/// Resets every x_i and y_i in node u's range to i + 1 (the 1-based index),
/// discards the node's pending additions and sets `cover` so pushdown()
/// forwards the reset to the children.
void fix(Tree& tr, int u) {
    Node& nd = tr[u];
    const f8 lo = nd.l, hi = nd.r;
    // Sum of the arithmetic series (lo + 1) + ... + (hi + 1).
    const f8 linear = (lo + hi + 2) * (hi - lo + 1) / 2;
    // Sum of squares (lo + 1)^2 + ... + (hi + 1)^2 as a difference of prefixes.
    const f8 squares = sqsum(hi + 1) - sqsum(lo);
    nd.sumY = linear;
    nd.sumX = linear;
    nd.sumXY = squares;
    nd.sumXX = squares;
    nd.tagY = 0;
    nd.tagX = 0;
    nd.cover = true;
}

/// Adds tagX to every x_i and tagY to every y_i in node u's range: updates
/// the node's sums immediately and accumulates the tags for its children.
void apply(Tree& tr, int u, f8 tagX, f8 tagY) {
    Node& nd = tr[u];
    const int len = nd.r - nd.l + 1;
    nd.tagX += tagX;
    nd.tagY += tagY;

    // Both quadratic sums depend on the sums BEFORE the shift, so read them first:
    //   sum (x+s)(y+t) = sum xy + s*sum y + t*sum x + s*t*len
    //   sum (x+s)^2    = sum x^2 + 2s*sum x + s^2*len
    const f8 oldX = nd.sumX;
    const f8 oldY = nd.sumY;
    nd.sumXY += tagY * oldX + tagX * oldY + tagX * tagY * len;
    nd.sumXX += 2 * tagX * oldX + tagX * tagX * len;
    nd.sumX = oldX + tagX * len;
    nd.sumY = oldY + tagY * len;
}

/// Forwards node u's pending lazy work to its two children: first the
/// reset-to-index cover (if set), then the pending additions. Only call on
/// internal nodes; leaves have no children.
void pushdown(Tree& tr, int u) {
    Node& nd = tr[u];
    if (nd.cover) {
        // fix() also clears the children's own pending additions, which the
        // reset supersedes.
        fix(tr, ls(u));
        fix(tr, rs(u));
        nd.cover = false;
    }

    // Skip the two child updates entirely when no addition is pending.
    if (nd.tagX != 0 || nd.tagY != 0) {
        apply(tr, ls(u), nd.tagX, nd.tagY);
        apply(tr, rs(u), nd.tagX, nd.tagY);
        nd.tagX = nd.tagY = 0;
    }
}

/// Range add: x_i += X and y_i += Y for every position i in
/// [l, r] intersected with node u's range.
void add(Tree& tr, int u, int l, int r, f8 X, f8 Y) {
    const int lo = tr[u].l, hi = tr[u].r;
    if (l <= lo && hi <= r) {
        // Fully covered: record the tag here and stop descending.
        apply(tr, u, X, Y);
        return;
    }

    pushdown(tr, u);
    const int mid = (lo + hi) >> 1;
    if (l <= mid) add(tr, ls(u), l, r, X, Y);
    if (mid < r) add(tr, rs(u), l, r, X, Y);
    pushup(tr, u);
}

/// Range assignment: for every position i in [l, r] intersected with node
/// u's range, set x_i = (i + 1) + X and y_i = (i + 1) + Y, where i + 1 is the
/// 1-based index. Done as a reset-to-index (fix) followed by an add (apply).
void update(Tree& tr, int u, int l, int r, f8 X, f8 Y) {
    const int lo = tr[u].l, hi = tr[u].r;
    const bool covered = l <= lo && hi <= r;
    if (covered) {
        fix(tr, u);
        apply(tr, u, X, Y);
        return;
    }

    pushdown(tr, u);
    const int mid = (lo + hi) >> 1;
    if (l <= mid) update(tr, ls(u), l, r, X, Y);
    if (mid < r) update(tr, rs(u), l, r, X, Y);
    pushup(tr, u);
}

/// Returns a Node whose four sum fields hold the range sums over
/// [l, r] intersected with node u's range. For a merged result only the sum
/// fields are meaningful; for a fully covered node it is a copy of that node.
Node query(Tree& tr, int u, int l, int r) {
    if (l <= tr[u].l && tr[u].r <= r) {
        return tr[u];
    }
    const int mid = (tr[u].l + tr[u].r) >> 1;
    pushdown(tr, u);

    if (r <= mid) {
        return query(tr, ls(u), l, r);
    }
    if (l > mid) {
        return query(tr, rs(u), l, r);
    }

    // Value-initialize so no field of the returned copy is indeterminate.
    Node res{};
    const Node le = query(tr, ls(u), l, r);
    const Node ri = query(tr, rs(u), l, r);
    merge(res, le, ri);
    return res;
}

signed main() {
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	
	int n, m;
	cin >> n >> m;
	
	vector<f8> X(n), Y(n);
	for (int i = 0; i < n; i++) {
	    cin >> X[i];
	}
	for (int i = 0; i < n; i++) {
	    cin >> Y[i];
	}
	
	Tree seg(n << 2);
	build(seg, 0, 0, n - 1, X, Y);
	
	auto get = [&](int l, int r) {
	    Node res = query(seg, 0, l, r);
	    int len = r - l + 1;
	    f8 num = res.sumXY - res.sumX * res.sumY / len;
	    f8 den = res.sumXX - res.sumX * res.sumX / len;
	    return num / den;
	};
	
	for (int i = 0; i < m; i++) {
	    int op, l, r;
	    cin >> op >> l >> r;
	    l--, r--;
	    if (op == 1) {
	        printf("%.10lf\n", get(l, r));
	    }
	    else if (op == 2) {
	        f8 s, t;
	        cin >> s >> t;
	        add(seg, 0, l, r, s, t);
	    }
	    else {
	        f8 s, t;
	        cin >> s >> t;
	        update(seg, 0, l, r, s, t);
	    }
	};
	
	return 0;
}

你可能感兴趣的:(洛谷题解,算法,c++,数据结构)