存板子

主席树

#include
using namespace std;
// Persistent segment tree (主席树) over value ranks, used to answer
// "k-th smallest value in a[l..r]" queries.
#define SIZE 200010

// One node of the persistent tree. Nodes are allocated from the pool `tree`
// through the counter `cnt`; index 0 is never allocated and reads as an empty
// child with sum 0.
struct SegmentTree {
    int lc, rc;   // pool indices of the left / right child
    int sum;      // how many inserted values have a rank inside this node's range
    #define lc(x) tree[x].lc
    #define rc(x) tree[x].rc
    #define sum(x) tree[x].sum 
} tree[SIZE << 5];  // 32 * SIZE slots: the base tree plus one new root-to-leaf path per insert

// n: array length, m: number of distinct values, z: number of queries,
// p/q/k: current query (range p..q, k-th smallest), x: rank scratch,
// cnt: pool cursor, ans: rank produced by Modify().
int n, m, p, q, k, x, z, cnt, ans;
// root[i]: version containing a[1..i]; b: sorted distinct copy of a (after ls()).
int root[SIZE], a[SIZE], b[SIZE];

// Discretize the values: sort b[1..n], drop duplicates, and store the number
// of distinct values in m.
void ls() {
    int *const first = b + 1;
    int *const last = b + n + 1;
    sort(first, last);
    m = int(unique(first, last) - first);
}

// Allocate an all-zero tree covering ranks [l, r] and return its root index.
int build(int l, int r) {
    const int node = ++cnt;
    if (l != r) {
        const int mid = (l + r) / 2;
        lc(node) = build(l, mid);
        rc(node) = build(mid + 1, r);
        sum(node) = sum(lc(node)) + sum(rc(node));
    } else {
        sum(node) = 0;
    }
    return node;
}

// Persistent point update: build a new version of `now` that counts one more
// value at rank x. Only the nodes on the root-to-leaf path are copied.
int insert(int now, int l, int r, int x) {
    const int node = ++cnt;
    tree[node] = tree[now];  // start as a copy of the old node
    if (l == r) {
        ++sum(node);
        return node;
    }
    const int mid = (l + r) / 2;
    if (x > mid)
        rc(node) = insert(rc(now), mid + 1, r, x);
    else
        lc(node) = insert(lc(now), l, mid, x);
    sum(node) = sum(lc(node)) + sum(rc(node));
    return node;
}

// Walk two versions in lockstep (p = prefix version l-1, q = prefix version r).
// Their count difference describes the values in a[l..r]. Find the rank of the
// k-th smallest of those values and store it in the global `ans`.
void Modify(int p, int q, int l, int r, int k) {
    while (l != r) {
        const int mid = (l + r) / 2;
        const int inLeft = sum(lc(q)) - sum(lc(p));
        if (k <= inLeft) {
            p = lc(p), q = lc(q), r = mid;
        } else {
            k -= inLeft;
            p = rc(p), q = rc(q), l = mid + 1;
        }
    }
    ans = l;
}

int main() {
    scanf("%d%d", &n, &z);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        b[i] = a[i];
    }
    ls();                   // b[1..m] = sorted distinct values
    root[0] = build(1, m);  // version 0: nothing inserted
    for (int i = 1; i <= n; ++i) {
        x = lower_bound(b + 1, b + m + 1, a[i]) - b;  // 1-based rank of a[i]
        root[i] = insert(root[i - 1], 1, m, x);
    }
    for (int remaining = z; remaining > 0; --remaining) {
        scanf("%d%d%d", &p, &q, &k);
        Modify(root[p - 1], root[q], 1, m, k);  // k-th smallest in a[p..q]
        printf("%d\n", b[ans]);
    }
    return 0;
}

高斯消元

#include
using namespace std;
// Gaussian elimination with partial pivoting for an n x n linear system.
#define eps 1e-7       // pivots with absolute value below this are treated as zero
#define SIZE 200
double a[SIZE][SIZE];  // augmented matrix, 1-based: a[i][1..n] coefficients, a[i][n+1] right-hand side
double ans[SIZE];      // solution x[1..n]
int n;                 // number of equations / unknowns
int main() {
    scanf("%d", &n);
    for (int row = 1; row <= n; ++row)
        for (int col = 1; col <= n + 1; ++col)
            scanf("%lf", &a[row][col]);
    // Forward elimination with partial pivoting; each pivot row is scaled so its pivot is 1.
    for (int col = 1; col <= n; ++col) {
        int pivot = col;
        for (int row = col + 1; row <= n; ++row)
            if (fabs(a[row][col]) > fabs(a[pivot][col])) pivot = row;
        if (fabs(a[pivot][col]) < eps) {
            printf("No Solution");
            return 0;
        }
        if (pivot != col) swap(a[col], a[pivot]);
        const double lead = a[col][col];
        for (int j = col; j <= n + 1; ++j) a[col][j] /= lead;
        for (int row = col + 1; row <= n; ++row) {
            const double factor = a[row][col];
            for (int j = col; j <= n + 1; ++j) a[row][j] -= a[col][j] * factor;
        }
    }
    // Back substitution on the unit upper-triangular system.
    for (int row = n; row >= 1; --row) {
        ans[row] = a[row][n + 1];
        for (int j = row + 1; j <= n; ++j) ans[row] -= a[row][j] * ans[j];
    }
    for (int row = 1; row <= n; ++row) printf("%.2lf\n", ans[row]);
    return 0;
}

线性基

#include
using namespace std;
// Linear (XOR) basis over non-negative 63-bit integers.
// p[] is indexed by every bit position 0..62, so it needs at least 63 slots.
// The previous size of 60 made p[60..62] out of bounds.
#define SIZE 64
long long n, ans;            // n: count of inputs; ans: maximum XOR found
long long a[SIZE], p[SIZE];  // a[1..n]: inputs; p[i]: basis vector whose highest set bit is i (0 = empty)
// Insert x into the XOR basis p[]. Walk from bit 62 down: at the first bit where
// x still has a set bit, store x there if the slot is empty; otherwise reduce x
// by that basis vector and keep going.
void get(long long x) {
    int bit = 62;
    while (bit >= 0) {
        if (x >> bit) {
            if (p[bit] == 0) {
                p[bit] = x;
                return;
            }
            x ^= p[bit];
        }
        --bit;
    }
}

int main() {
    scanf("%lld", &n);
    for (int i = 1; i <= n; ++i) {
        scanf("%lld", &a[i]);
        get(a[i]);
    }
    // Greedy from the top bit: take a basis vector whenever it increases the XOR.
    for (int bit = 62; bit >= 0; --bit)
        ans = max(ans, ans ^ p[bit]);
    printf("%lld", ans);
    return 0;
}

康托展开

#include
using namespace std;
// Cantor expansion: 1-based rank of a permutation of 1..n among all permutations, modulo MOD.
#define ll long long
#define MOD 998244353
#define N 2000000
// n: permutation length; ans: rank - 1 accumulated modulo MOD;
// t: Fenwick tree counting values already seen; a: the permutation (1-based).
ll n,ans,t[N],a[N];

// Lowest set bit of x; this is the step size of a Fenwick tree.
int lowbit(int x) {
    return x & (~x + 1);
}

// Fenwick tree: record one occurrence of value x (1 <= x <= n).
void add(ll x)
{
    while (x <= n) {
        ++t[x];
        x += lowbit(x);
    }
}

// Fenwick tree: how many recorded values lie in [1, x].
ll query(ll x)
{
    ll total = 0;
    while (x != 0) {
        total += t[x];
        x -= lowbit(x);
    }
    return total;
}

int main()
{
    scanf("%lld", &n);
    for (int i = 1; i <= n; ++i)
        scanf("%lld", &a[i]);
    // Scan right to left; `fact` equals (n - i)! for the position i being processed.
    add(a[n]);
    int fact = 1;
    for (int i = n - 1; i >= 1; --i)
    {
        ans = (ans + fact * query(a[i])) % MOD;  // smaller values to the right of position i
        fact = fact * (n - i + 1) % MOD;
        add(a[i]);
    }
    printf("%lld\n", (ans + 1) % MOD);
    return 0;
}

splay

#include
using namespace std;
// Splay tree used as an implicit-key sequence that supports range reversal.
#define INF 0x7fffffff   // values of the two sentinels are -INF and INF
#define SIZE 100010
struct Splay {
    int lc, rc;      // children
    int fa;          // parent (0 when there is none)
    int val, tag;    // stored value; tag != 0 means the children still have to be swapped
    int cnt, size;   // copies of val held in this node; element count of the subtree
    #define lc(x) tree[x].lc
    #define rc(x) tree[x].rc
    #define fa(x) tree[x].fa
    #define val(x) tree[x].val
    #define tag(x) tree[x].tag
    #define cnt(x) tree[x].cnt
    #define size(x) tree[x].size
} tree[SIZE];
// tot: nodes allocated so far; cn: node touched by the last insert();
// len / ans: scratch state of find().
int n, m, opt, x, y, tot, root, cn, len, ans;
bool flag;      // set by find() once the target node is located
int q[SIZE];    // scratch stack used by splay() to push tags down from the top
// Allocate a fresh node holding one copy of `val` and return its index.
int New(int val) {
    const int id = ++tot;
    val(id) = val;
    cnt(id) = 1;
    size(id) = 1;
    return id;
}

void update(int x) {
    size(x) = cnt(x) + size(lc(x)) + size(rc(x));
}

// Start with two sentinels: node 1 (-INF) is the root and node 2 (+INF) is its right child.
void build() {
    New(-INF);
    New(INF);
    root = 1;
    rc(1) = 2;
    fa(2) = 1;
    update(root);
}
// Return 0 if x is its parent's left child, 1 otherwise.
int whichson(int x) {
    return lc(fa(x)) == x ? 0 : 1;
}
// Left rotation: x is the right child of p, and x is lifted above p.
// x's former left subtree q becomes p's right subtree.
// NOTE(review): when q == 0 this writes fa(0). When p has no parent (fa(p) == 0),
// whichson(p) inspects lc(0) and the child link is written into tree[0].
void lrotate(int x) {
    int p = fa(x), q = lc(x);
    if (whichson(p) == 0) lc(fa(p)) = x; else rc(fa(p)) = x;  // hang x where p used to hang
    fa(x) = fa(p), fa(p) = x, lc(x) = p, fa(q) = p, rc(p) = q;
    update(p), update(x);  // p is now below x, so refresh p first
}
// Right rotation: x is the left child of p, and x is lifted above p.
// x's former right subtree q becomes p's left subtree.
// NOTE(review): when q == 0 this writes fa(0). When p has no parent (fa(p) == 0),
// whichson(p) inspects lc(0) and the child link is written into tree[0].
void rrotate(int x) {
    int p = fa(x), q = rc(x);
    if (whichson(p) == 0) lc(fa(p)) = x; else rc(fa(p)) = x;  // hang x where p used to hang
    fa(x) = fa(p), fa(p) = x, rc(x) = p, fa(q) = p, lc(p) = q;
    update(p), update(x);  // p is now below x, so refresh p first
}
// Rotate x above its parent; the direction depends on which side x hangs on.
void rotate(int x) {
    if (whichson(x)) lrotate(x);
    else rrotate(x);
}
// Apply a pending reversal at x: swap its children and pass the flag down to them.
void pushdown(int x) {
    if (!tag(x)) return;
    tag(x) = 0;
    const int ls = lc(x), rs = rc(x);
    if (ls) tag(ls) ^= 1;
    if (rs) tag(rs) ^= 1;
    swap(lc(x), rc(x));
}
// Splay x upward until its parent is y (y == 0 lifts x all the way up).
// Reversal tags on the path from the top down to x are pushed first, so child
// sides are settled before any rotation. This function does not update `root`;
// callers assign it.
void splay(int x, int y) {
    int len = 0;
    for (int i = x; i; i = fa(i)) q[++len] = i;     // collect the path x .. top
    for (int i = len; i >= 1; i--) pushdown(q[i]);  // push tags from the top down
    while (fa(x) != y) {
        int p = fa(x), q = fa(p);  // this local q shadows the global stack q[]
        if (q != y) {
            if (whichson(x) == whichson(p)) rotate(p); // zig-zig: rotate the parent first
            else rotate(x); // zig-zag
        }
        rotate(x);
    }
}
// Binary-search-tree insert of `val` into the subtree referenced by x, whose
// parent is `fa`. A new value gets a new node; an existing value has its copy
// count bumped. Either way the touched node is stored in the global `cn`, and
// subtree sizes are refreshed on the way back up.
void insert(int fa, int &x, int val) {
    if (x == 0) {
        x = New(val);   // x is a reference, so this links the node into its parent
        fa(x) = fa;
        cn = x;
        return;
    }
    if (val == val(x)) {
        ++cnt(x), update(x);
        cn = x;
        return;
    }
    if (val < val(x))
        insert(x, lc(x), val);
    else
        insert(x, rc(x), val);
    update(x);
}
// Locate the node at 1-based in-order position k inside the subtree of x.
// Shared state: `len` counts in-order nodes already passed and `flag` marks
// success; callers reset both (len = 0, flag = false) before the first call.
// On success `ans` holds the node.
// NOTE(review): assumes 1 <= k <= number of nodes. Each node counts as one position here.
void find(int x, int k) {
    if (tag(x)) pushdown(x);                     // settle child order before using lc/rc
    if (size(lc(x)) + len >= k) find(lc(x), k);  // target is inside the left subtree
    else len += size(lc(x));                     // skip the whole left subtree
    ++len;                                       // count x itself
    if (len == k) {
        ans = x;
        flag = true;
    }
    if (flag) return;                            // found here or in the left subtree
    find(rc(x), k);
}
// Reverse elements l..r of the sequence.
// Tree positions are shifted by one because of the -INF sentinel at the front.
// The node just before l (tree position l) is splayed to the root, and the node
// just after r (tree position r + 2) is splayed below it. The left subtree of
// that second node then holds exactly elements l..r, and its reversal tag is toggled.
void reverse(int l, int r) {
    flag = false;
    len = 0, find(root, l - 1 + 1);
    int x = ans; splay(x, 0), root = x;
    flag = false;
    len = 0, find(root, r + 1 + 1); 
    int y = ans; splay(y, root);
    tag(lc(y)) ^= 1;
}
// Print the sequence by in-order traversal, skipping the two sentinel values.
void print(int x) {
    if (tag(x)) pushdown(x);
    if (lc(x)) print(lc(x));
    const int v = val(x);
    if (v != -INF && v != INF) printf("%d ", v);
    if (rc(x)) print(rc(x));
}
int main() {
    scanf("%d%d", &n, &m);
    build();
    // Insert 1..n in order, splaying each new node to the root.
    for (int value = 1; value <= n; ++value) {
        insert(0, root, value);
        splay(cn, 0);
        root = cn;
    }
    for (int op = 0; op < m; ++op) {
        scanf("%d%d", &x, &y);
        reverse(x, y);
    }
    print(root);
    return 0;
}

fhq treap(值)

// luogu-judger-enable-o2
#include
#include
#include
#include
#include
#include
#include
using namespace std;
// Non-rotating (fhq) treap keyed by value: insert, delete, rank, k-th, predecessor, successor.
#define MAXN 1000000
int n, opt, a, cnt, root, x, y, z;  // cnt: nodes allocated; x/y/z: split pieces used by main
struct node {
    int l, r, val, key, size;       // children, stored value, heap priority, subtree size
} t[MAXN];
// Read a decimal integer from stdin. Characters before the first digit are
// skipped, and any '-' among them makes the result negative.
// At end of input it returns 0 instead of looping forever. The character is
// held in an int so that EOF can be told apart from a real byte.
inline int read() {
    int s = 0, w = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}
// Allocate a node holding `val` with a random heap priority and return its index.
// The two rand() results are multiplied as unsigned values. The former
// int * int product is signed overflow (undefined behaviour) whenever RAND_MAX is
// large, e.g. 2^31 - 1 on glibc.
inline int New(int val) {
    const unsigned mixed = static_cast<unsigned>(rand()) * static_cast<unsigned>(rand());
    t[++cnt].val = val;
    t[cnt].key = static_cast<int>(mixed);
    t[cnt].size = 1;
    return cnt;
}
inline void update(int now) {
    t[now].size = t[t[now].l].size + t[t[now].r].size + 1;
}
// Split the treap rooted at `now` by value: u receives all values <= w, v the rest.
inline void Split(int now, int w, int &u, int &v) {
    if (now == 0) {
        u = v = 0;
        return;
    }
    if (t[now].val > w) {
        v = now;
        Split(t[now].l, w, u, t[v].l);
    } else {
        u = now;
        Split(t[now].r, w, t[u].r, v);
    }
    update(now);
}
// Merge two treaps where every value in u is <= every value in v.
// The root with the smaller heap key stays on top.
inline int Merge(int u, int v) {
    if (u == 0) return v;
    if (v == 0) return u;
    if (t[v].key <= t[u].key) {
        t[v].l = Merge(u, t[v].l);
        update(v);
        return v;
    }
    t[u].r = Merge(t[u].r, v);
    update(u);
    return u;
}
// Insert one copy of val: split around it and glue a new node in between.
inline void Insert(int val) {
    int left, right;
    Split(root, val, left, right);
    const int fresh = New(val);
    root = Merge(Merge(left, fresh), right);
}

// Return the node holding the sum-th smallest value (1-based) in the treap `now`.
// Assumes 1 <= sum <= t[now].size.
inline int Kth(int now, int sum) {
    for (;;) {
        const int leftSize = t[t[now].l].size;
        if (sum <= leftSize) {
            now = t[now].l;
        } else if (sum == leftSize + 1) {
            return now;
        } else {
            sum -= leftSize + 1;
            now = t[now].r;
        }
    }
}
int main() {
    srand(time(0));
    n = read();
    while (n--) {
        opt = read(), a = read();
        if (opt == 1) {                   // insert a
            Insert(a);
        } else if (opt == 2) {            // delete one copy of a
            Split(root, a, x, z);
            Split(x, a - 1, x, y);
            y = Merge(t[y].l, t[y].r);
            root = Merge(Merge(x, y), z);
        } else if (opt == 3) {            // rank of a
            Split(root, a - 1, x, y);
            printf("%d\n", t[x].size + 1);
            root = Merge(x, y);
        } else if (opt == 4) {            // a-th smallest value
            printf("%d\n", t[Kth(root, a)].val);
        } else if (opt == 5) {            // largest value < a
            Split(root, a - 1, x, y);
            printf("%d\n", t[Kth(x, t[x].size)].val);
            root = Merge(x, y);
        } else if (opt == 6) {            // smallest value > a
            Split(root, a, x, y);
            printf("%d\n", t[Kth(y, 1)].val);
            root = Merge(x, y);
        }
    }
    return 0;
}

fhq treap(区间)

#include
#include
#include
#include
#include
#include
#include
using namespace std;
// Non-rotating (fhq) treap keyed by position (implicit key), supporting range reversal.
#define MAXN 200010
int n, m, root, cnt;   // n: sequence length; m: number of reversals; cnt: nodes allocated
struct node {
    int l, r, val, key, size, tag;  // children, element, heap priority, subtree size, pending reversal
} t[MAXN];
// Read a decimal integer from stdin. Characters before the first digit are
// skipped, and any '-' among them makes the result negative.
// At end of input it returns 0 instead of looping forever. The character is
// held in an int so that EOF can be told apart from a real byte.
inline int read() {
    int s = 0, w = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}
// Allocate a node holding `val` with a random heap priority and return its index.
// The two rand() results are multiplied as unsigned values. The former
// int * int product is signed overflow (undefined behaviour) whenever RAND_MAX is
// large, e.g. 2^31 - 1 on glibc.
inline int New(int val) {
    const unsigned mixed = static_cast<unsigned>(rand()) * static_cast<unsigned>(rand());
    t[++cnt].val = val;
    t[cnt].key = static_cast<int>(mixed);
    t[cnt].size = 1;
    return cnt;
}
inline void update(int now) {
    t[now].size = t[t[now].l].size + t[t[now].r].size + 1;
}
inline void pushdown(int now) {
    if (t[now].tag) {
        swap(t[now].l, t[now].r);
        t[t[now].l].tag ^= 1, t[t[now].r].tag ^= 1;
        t[now].tag = 0;
    }
}
// Split the sequence rooted at `now` by position: u receives the first w elements, v the rest.
void Split(int now, int w, int &u, int &v) {
    if (!now) {
        u = v = 0;
        return;
    }
    pushdown(now);
    const int leftSize = t[t[now].l].size;
    if (leftSize < w) {
        u = now;
        Split(t[now].r, w - leftSize - 1, t[now].r, v);
    } else {
        v = now;
        Split(t[now].l, w, u, t[now].l);
    }
    update(now);
}
// Concatenate sequence u followed by sequence v. The root with the smaller heap
// key stays on top, and its pending reversal is applied before descending.
int Merge(int u, int v) {
    if (u == 0) return v;
    if (v == 0) return u;
    if (t[v].key <= t[u].key) {
        pushdown(v);
        t[v].l = Merge(u, t[v].l);
        update(v);
        return v;
    }
    pushdown(u);
    t[u].r = Merge(t[u].r, v);
    update(u);
    return u;
}
void write(int now) {
    if (!now) return;
    pushdown(now);
    write(t[now].l);
    printf("%d ", t[now].val);
    write(t[now].r);
}
int main() {
    srand(time(0));
    n = read(), m = read();
    for (int i = 1; i <= n; ++i) root = Merge(root, New(i));  // initial sequence 1..n
    while (m--) {
        const int l = read(), r = read();
        int left, middle, right;
        Split(root, r, left, right);        // left = [1, r]
        Split(left, l - 1, left, middle);   // left = [1, l - 1], middle = [l, r]
        t[middle].tag ^= 1;
        root = Merge(Merge(left, middle), right);
    }
    write(root);
    return 0;
}

树链剖分

#include
using namespace std;
// Heavy-light decomposition with a lazy segment tree: path / subtree add and sum modulo p.
#define SIZE 400010
// r: root vertex; p: modulus; cnt: edge counter, later reused as the DFS timestamp.
int n, m, r, p, cnt, u, v, x, y, z, opt;
struct NODE {
    int next;   // next edge out of the same vertex
    int ver;    // endpoint of this edge
} edge[SIZE];
struct SegmentTree {
    int l, r;      // range of DFS positions covered by the node
    int sum, tag;  // range sum and pending per-position addition
    #define l(x) tree[x].l
    #define r(x) tree[x].r
    #define sum(x) tree[x].sum
    #define tag(x) tree[x].tag
} tree[SIZE];
// val: vertex weights; size: subtree sizes; flag: visited marks; fa/son/deep/top:
// parent, heavy child, depth, head of the heavy chain; dfs[v]: DFS position of v;
// rank[i]: vertex at DFS position i.
// NOTE(review): under `using namespace std`, unqualified uses of `size` and `rank`
// may be ambiguous with std::size (C++17) and std::rank (C++11); qualify them as ::size / ::rank.
int val[SIZE], head[SIZE], size[SIZE], flag[SIZE], fa[SIZE], son[SIZE], deep[SIZE], top[SIZE], dfs[SIZE], rank[SIZE];

// Append the directed edge u -> v to u's adjacency list.
void add(int u, int v) {
    ++cnt;
    edge[cnt].ver = v;
    edge[cnt].next = head[u];
    head[u] = cnt;
}

void dfs1(int p) {
    size[p] = 1;
    for (register int i = head[p]; i; i = edge[i].next) {
        int q = edge[i].ver;
        if (!flag[q]) flag[q] = 1;
        else continue;
        fa[q] = p;
        deep[q] = deep[p] + 1;
        dfs1(q);
        size[p] += size[q];
        if (size[q] > size[son[p]])
            son[p] = q;
    }
}

void dfs2(int p, int t) {
    top[p] = t;
    dfs[p] = ++cnt;
    rank[cnt] = p;
    if (!son[p]) return ;
    if (!flag[son[p]]) flag[son[p]] = 1, dfs2(son[p], t);
    for (register int i = head[p]; i; i = edge[i].next) {
        int q = edge[i].ver;
        if (!flag[q]) flag[q] = 1;
        else continue;
        if (q != son[p])
            dfs2(q, q);
    }
}

void build(int o, int l, int r) {
    int mid = (l + r) / 2;
    if (l == r) {
        l(o) = l;
        r(o) = r;
        sum(o) = val[rank[l]];
        return ;
        }
    build(o * 2, l, mid);
    build(o * 2 + 1, mid + 1, r);
    l(o) = l;
    r(o) = r;
    sum(o) = sum(o * 2) + sum(o * 2 + 1);
}

void down(int o) {
    int ll = o * 2, rr = o * 2 + 1;
    sum(ll) = (sum(ll) + tag(o) * ((r(ll) - l(ll) + 1) % p)) % p % p;
    sum(rr) = (sum(rr) + tag(o) * ((r(rr) - l(rr) + 1) % p)) % p % p;
    tag(ll) = (tag(ll) + tag(o)) % p;
    tag(rr) = (tag(rr) + tag(o)) % p;
    tag(o) = 0;
}

// Add z to every DFS position in [x, y] inside node o's range, modulo p.
// Intermediate values are long long so that z * length and tag + z cannot overflow int.
void change(int o, int x, int y, int z) {
    if (r(o) < x || y < l(o)) return;  // disjoint (the old `&&` test could never be true)
    if (x <= l(o) && r(o) <= y) {
        long long delta = (static_cast<long long>(z) % p + p) % p;
        sum(o) = (sum(o) + delta * (r(o) - l(o) + 1)) % p;
        tag(o) = (tag(o) + delta) % p;
        return;
    }
    if (tag(o)) down(o);
    int mid = (l(o) + r(o)) / 2;
    if (x <= mid) change(o * 2, x, y, z);
    if (mid < y) change(o * 2 + 1, x, y, z);
    sum(o) = (static_cast<long long>(sum(o * 2)) + sum(o * 2 + 1)) % p;
}

int query(int o, int x, int y) {
    int sum = 0;
    if (x <= l(o) && r(o) <= y) { sum = (sum + sum(o)) % p; return sum; }
    if (r(o) < x && y < l(o)) return 0;
    if (tag(o)) down(o);
    int mid = (l(o) + r(o)) / 2;
    if (x <= mid) sum = (sum + query(o * 2, x, y)) % p;
    if (mid < y) sum = (sum + query(o * 2 + 1, x, y)) % p;
    return sum;
}

// Add k to every vertex on the tree path x -- y, one heavy chain at a time.
void query_change(int x, int y, int k) {
    while (top[x] != top[y]) {
        // Always lift the endpoint whose chain head is deeper.
        if (deep[top[x]] > deep[top[y]]) {
            change(1, dfs[top[x]], dfs[x], k);
            x = fa[top[x]];
        } else {
            change(1, dfs[top[y]], dfs[y], k);
            y = fa[top[y]];
        }
    }
    // Both endpoints are now on one chain; update the segment between them.
    if (deep[x] < deep[y]) change(1, dfs[x], dfs[y], k);
    else change(1, dfs[y], dfs[x], k);
}

// Sum of the weights on the tree path x -- y, modulo p.
// The accumulator is long long: sum + query(...) can exceed INT_MAX when p is near 2^31.
int query_sum(int x, int y) {
    long long total = 0;
    while (top[x] != top[y]) {
        // Always lift the endpoint whose chain head is deeper.
        if (deep[top[x]] > deep[top[y]]) {
            total = (total + query(1, dfs[top[x]], dfs[x])) % p;
            x = fa[top[x]];
        } else {
            total = (total + query(1, dfs[top[y]], dfs[y])) % p;
            y = fa[top[y]];
        }
    }
    if (deep[x] < deep[y]) total = (total + query(1, dfs[x], dfs[y])) % p;
    else total = (total + query(1, dfs[y], dfs[x])) % p;
    return static_cast<int>(total);
}

int main() {
    scanf("%d%d%d%d", &n, &m, &r, &p);
    for (register int i = 1; i <= n; i++) {
            scanf("%d", &val[i]);
        }
    cnt = 0;
    for (register int i = 1; i <= n - 1; i++) {
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    deep[r] = 1;
    memset(flag, 0, sizeof(flag));
    flag[r] = 1;
    dfs1(r);
    cnt = 0;
    memset(flag, 0, sizeof(flag));
    flag[r] = 1;
    dfs2(r, r);
    build(1, 1, n);
    for (register int i = 1; i <= m; i++){
        scanf("%d", &opt);
        switch (opt){
            case 1 :{
                scanf("%d%d%d", &x, &y, &z);
                query_change(x, y, z);
                break;
            }
            case 2 :{
                scanf("%d%d", &x, &y);
                printf("%d\n", query_sum(x, y));
                break;
            }
            case 3 :{
                scanf("%d%d", &x, &z);
                y = dfs[x] + size[x] - 1;
                x = dfs[x];
                change(1, x, y, z);
                break;
            }
            case 4 :{
                scanf("%d", &x);
                y = dfs[x] + size[x] - 1;
                x = dfs[x];
                printf("%d\n",query(1, x, y));
                break;
            }
        }
    }
    return 0;
}

拉格朗日插值

// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include
using namespace std;
// Lagrange interpolation: value at k of the polynomial through n points, modulo 998244353.
#define ll long long
#define SIZE 2010
#define mod 998244353
// (x[i], y[i]): the n sample points; k: evaluation point; ans: result;
// s1 / s2: numerator / denominator of the current basis term.
ll n, k, x[SIZE], y[SIZE], ans, s1, s2;
// a^x modulo `mod` by binary exponentiation.
ll powmod(ll a, ll x) {
    ll result = 1;
    for (ll base = a; x != 0; x /= 2) {
        if (x & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Modular inverse by Fermat's little theorem (998244353 is prime).
ll inv(ll x) {
    const ll exponent = mod - 2;
    return powmod(x, exponent);
}
int main() {
    scanf("%lld%lld", &n, &k);
    for (int i = 1; i <= n; ++i)
        scanf("%lld%lld", &x[i], &y[i]);
    // f(k) = sum_i y_i * prod_{j != i} (k - x_j) / (x_i - x_j)
    for (int i = 1; i <= n; ++i) {
        s1 = y[i] % mod;
        s2 = 1;
        for (int j = 1; j <= n; ++j) {
            if (j == i) continue;
            s1 = s1 * (k - x[j]) % mod;
            s2 = s2 * (x[i] - x[j]) % mod;
        }
        ans += s1 * inv(s2) % mod;
        ans = (ans + mod) % mod;  // intermediate terms may be negative
    }
    printf("%lld\n", ans);
    return 0;
}

矩阵求逆

#include
#include
#include
#include
#include
using namespace std;
// Matrix inverse modulo 1e9+7 by Gauss-Jordan elimination.
typedef long long LL;
const int N=405;
const LL mod=1e9+7;
int n,m;          // matrix size n; m = 2n columns of the augmented matrix
LL f[N][N<<1];    // augmented matrix [A | I], 1-based
LL r,ret;         // scratch: r = current multiplier, ret = result of ksm()
// u^v modulo `mod` by binary exponentiation; the result is also left in the global `ret`.
LL ksm(LL u,LL v){
    for (ret = 1; v; v >>= 1, u = u * u % mod)
        if (v & 1) ret = ret * u % mod;
    return ret;
}
// Invert the n x n matrix modulo the prime `mod` by Gauss-Jordan elimination on
// [A | I]; afterwards the right half holds A^{-1}. Prints "No Solution" when some
// column has no nonzero pivot at or below the diagonal.
int main(){
    scanf("%d",&n);m=n*2;
    for(int i=1;i<=n;++i){
        for(int j=1;j<=n;j++)scanf("%lld",&f[i][j]);
        f[i][n+i]=1;  // identity block in columns n+1..2n
    }
    for(int i=1;i<=n;++i){ 
        // bring a row with a nonzero entry in column i up to row i
        for(int j=i;j<=n;j++)
        if(f[j][i]){
            for(int k=1;k<=m;k++)
            swap(f[i][k],f[j][k]);
            break;
        }
        if(!f[i][i]){puts("No Solution");return 0;}
        r=ksm(f[i][i],mod-2);  // inverse of the pivot (Fermat)
        // scale the pivot row so the pivot becomes 1
        for(int j=i;j<=m;++j)  
        f[i][j]=f[i][j]*r%mod;
        // eliminate column i from every other row
        for(int j=1;j<=n;++j) 
        if(j!=i){
            r=f[j][i];
            for(int k=i;k<=m;++k)
            f[j][k]=(f[j][k]-r*f[i][k]%mod+mod)%mod;
        }
    }
    // print the right half (the inverse), one matrix row per line
    for(int i=1;i<=n;++i,puts(""))
    for(int j=n+1;j<=m;++j)printf("%lld ",f[i][j]);
}

FFT

#include
#include
#include
#include
#include
#include
using namespace std;
// Polynomial multiplication with a recursive FFT.
// `complex` is a class template and needs its argument; the bare `complex`
// here was not a type, so every declaration using cp failed to compile.
#define cp complex<double>
#define ll long long
#define PI acos(-1.0)
#define MAXN 4000010

cp a[MAXN], b[MAXN], c[MAXN];  // a, b: input coefficients, transformed in place; c: pointwise product
int n, m, lim;                 // degrees of the two polynomials; main declares its own local lim

// Read a decimal integer from stdin. Characters before the first digit are
// skipped, and any '-' among them makes the result negative.
// At end of input it returns 0 instead of looping forever. The character is
// held in an int so that EOF can be told apart from a real byte.
inline long long read() {
    long long s = 0, w = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}

// The k-th power of the principal n-th root of unity, e^{2*pi*i*k/n}.
cp omega(int n, int k) {
    const double angle = 2 * PI * k / n;
    return cp{cos(angle), sin(angle)};
}
// Recursive radix-2 FFT of a[0..n-1] in place; n must be a power of two.
// Even-indexed coefficients go to the first half and odd-indexed ones to the
// second; both halves are transformed, then combined with butterflies using
// omega(n, i), conjugated when `inv` is true. No 1/n scaling is done here.
// The forward and inverse passes must use opposite values of `inv`.
// The static buffer is shared by all recursion levels; each level copies it
// back into a before recursing or returning.
void fft(cp *a, int n, bool inv) {
    if (n == 1) return;
    static cp buf[MAXN];
    int m = n / 2;
    for (register int i = 0; i < m; i++) {  // split into even / odd halves
        buf[i] = a[2 * i];
        buf[i + m] = a[2 * i + 1];
    }
    for (register int i = 0; i < n; i++)
        a[i] = buf[i];
    fft(a, m, inv);
    fft(a + m, m, inv);
    for (register int i = 0; i < m; i++) {  // butterflies
        cp x = omega(n, i);
        if (inv) x = conj(x);
        buf[i] = a[i] + x * a[i + m];
        buf[i + m] = a[i] - x * a[i + m];
    }
    for (register int i = 0; i < n; i++)
        a[i] = buf[i];
}
int main() {
    n = read(), m = read();
    for (register int i = 0; i <= n; i++)
        a[i] = {read(), 0};
    for (register int i = 0; i <= m; i++)
        b[i] = {read(), 0};
    int lim = 1;
    while (lim <= n + m) lim *= 2;
    for (int i = n + 1; i <= lim; i++) a[i] = {0, 0};
    for (int i = m + 1; i <= lim; i++) b[i] = {0, 0};
    fft(a, lim, true), fft(b, lim, true);
    for (register int i = 0; i <= lim; i++)
        c[i] = a[i] * b[i];
    fft(c, lim, false);
    for (register int i = 0; i <= n + m; i++)
        printf("%d ", (int)((c[i].real() / lim) + 0.5));
    return 0;
}

CDQ分治（注：本节代码实际是点分治模板，与后文“点分治”一节相同）

#include
using namespace std;
// Centroid (point) divide and conquer: for each query k, decide whether some tree
// path has total weight exactly k.
#define SIZE 100010
#define INF 100010

struct NODE {
    long long next;     // next edge out of the same vertex
    long long ver;      // endpoint
    long long weight;   // edge weight
    #define next(x) edge[x].next
    #define ver(x) edge[x].ver
    #define weight(x) edge[x].weight
} edge[SIZE];

// k: queried path length; root: current centroid; ans: best largest-part size
// in get_root(); tot: number of entries in a[].
long long n, m, u, v, w, k, cnt, root, ans, sum, tot;
// b[v]: child of the centroid whose subtree contains v; a[1..tot]: vertices of the current component.
int head[SIZE], b[SIZE], a[SIZE];
long long size[SIZE];   // subtree sizes computed by get_root()
long long d[SIZE];      // distance to the current centroid
bool flag[SIZE];        // vertices already used as centroids
// NOTE(review): under `using namespace std`, unqualified `size` and `get` may be
// ambiguous with std::size / std::get; qualify uses as ::size / ::get.
bool get;               // set once a path of length k is found

// Append the weighted directed edge u -> v to u's adjacency list.
void add(int u, int v, int w) {
    ++cnt;
    ver(cnt) = v;
    weight(cnt) = w;
    next(cnt) = head[u];
    head[u] = cnt;
}

// Find the centroid of the component (of n vertices) containing u: the vertex
// whose largest remaining part after removal is smallest. The result is stored in
// `root`, with that part size in `ans` (callers reset ans to INF first).
// Also fills ::size[] with subtree sizes for this DFS rooted at the start vertex.
// `size` is written as `::size` to avoid ambiguity with std::size (C++17).
void get_root(long long n, long long u, long long fa) {
    long long Size = 0;
    ::size[u] = 1;
    for (int i = head[u]; i; i = next(i)) {
        int v = ver(i);
        if (v == fa || flag[v]) continue;
        get_root(n, v, u);
        ::size[u] += ::size[v];
        Size = max(Size, ::size[v]);
    }
    Size = max(Size, n - ::size[u]);  // the part "above" u
    if (Size < ans) root = u, ans = Size;
}

// Fill d[] (distance from the current centroid) and b[] (the centroid's child whose
// subtree contains the vertex). Called as dfs1(root, root), so the centroid gets b[root] = root.
void dfs1(long long u, long long fa) {
    b[u] = (fa == root) ? u : b[fa];
    for (int i = head[u]; i; i = next(i)) {
        const int v = ver(i);
        if (v != fa && !flag[v]) {
            d[v] = d[u] + weight(i);
            dfs1(v, u);
        }
    }
}

// Append every vertex of the current component, without crossing flagged vertices, to a[1..tot].
void dfs2(long long u, long long fa) {
    a[++tot] = u;
    for (int i = head[u]; i; i = next(i)) {
        const int v = ver(i);
        if (v != fa && !flag[v]) dfs2(v, u);
    }
}
// Order vertices by distance to the current centroid.
bool cmp(long long x, long long y) {
    return d[y] > d[x];
}
// Decide whether some path through the centroid o has length exactly k.
// The component's vertices are sorted by distance d[] and scanned with two
// pointers. A pair counts only if its endpoints lie in different child subtrees
// of o (b[] differs; o itself has b[o] == o).
// Equal-sum pairs from the same subtree move the pointer that cannot lose a valid
// partner. The old loop always dropped a[R], which misses a pair a[L'] -- a[R]
// where a[L'] has the same distance as a[L] but lies in another subtree.
// `get` is written as `::get` to avoid ambiguity with std::get.
void calc(long long o) {
    tot = 0;
    dfs2(o, o);
    sort(a + 1, a + tot + 1, cmp);
    long long L = 1, R = tot;
    while (L < R) {
        long long s = d[a[L]] + d[a[R]];
        if (s > k) {
            R--;
        } else if (s < k) {
            L++;
        } else if (b[a[L]] != b[a[R]]) {
            ::get = true;
            return;
        } else if (d[a[R - 1]] == d[a[R]]) {
            R--;  // an equally distant vertex remains, so a[R] is no longer needed
        } else {
            L++;  // no other remaining vertex has distance d[a[R]]: a[L] has no partner left
        }
    }
}

void f(long long n, long long o) {
    root = o;
    ans = INF;
    get_root(n, root, 0);
    d[root] = 0;
    dfs1(root, root);
    calc(root);
    if (get) return;
    flag[root] = true;
    for (register int i = head[root]; i; i = next(i)) {
        int v = ver(i);
        if (flag[v]) continue;
        f(size[v], v);
        if (get) return;
    }
}
int main() {
    scanf("%d%d", &n, &m);
    for (register int i = 1; i <= n - 1; i++) {
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w);
        add(v, u, w);
    }
    for (register int i = 1; i <= m; i++) {
    	scanf("%d", &k);
    	get = false;
    	memset(flag, false, sizeof(flag));
    	f(n, 1);
    	if (get) printf("AYE\n");
    	else printf("NAY\n");
    }
    return 0;
}

整体二分

#include
using namespace std;
// Parallel binary search (整体二分) for dynamic range k-th smallest queries.
#define INF 1e9          // upper end of the value range searched by solve()
#define SIZE 300010
struct node {
    int x, y, k, p;      // 'C': x = value, y = +1/-1, k = position; 'Q': range [x, y], k-th smallest
    char c;              // 'Q' = query, 'C' = change
} q[SIZE];
int n, len, m, xx, yy, kk;   // len: number of records stored in q
// a: current array; t: Fenwick tree over positions; p: record order being processed;
// ans: answers (-1 = none); q1 / q2: partition buffers for solve().
int a[SIZE], t[SIZE], p[SIZE], ans[SIZE], q1[SIZE], q2[SIZE];
char ch;

// Lowest set bit of x; this is the step size of a Fenwick tree.
int lowbit(int x) {
    return x ^ (x & (x - 1));
}

// Fenwick tree over positions 1..n: add y at position x.
void add(int x, int y) {
    for (; x <= n; x += lowbit(x)) t[x] += y;
}

// Fenwick tree: total added at positions 1..x.
int query(int x) {
    int total = 0;
    for (; x; x -= lowbit(x)) total += t[x];
    return total;
}

// Parallel binary search over the answer range [l, r] for the records p[ll..rr],
// which are kept in input order. 'C' records with value <= mid are applied to the
// Fenwick tree and go left. Each 'Q' record counts how many values <= mid lie in
// its range: if k fits it goes left, otherwise k is reduced and it goes right.
// The Fenwick tree is restored before recursing.
void solve(int l, int r, int ll, int rr) {
    int mid = (l + r) >> 1;
    int w1 = 0, w2 = 0;
    if (ll > rr) return;
    if (l == r) {  // every query that reached here has answer l (== mid)
        for (register int i = ll; i <= rr; i++)
            if (q[p[i]].c == 'Q') ans[p[i]] = mid;
        return;
    }
    for (register int i = ll; i <= rr; i++) {
        int z = p[i];
        if (q[z].c == 'Q') {
            int ans = query(q[z].y) - query(q[z].x - 1);  // values <= mid in [x, y]
            if (q[z].k <= ans) q1[++w1] = z;
            else q[z].k -= ans, q2[++w2] = z;
        }
        else {
            if (q[z].x <= mid) add(q[z].k, q[z].y), q1[++w1] = z;
            else q2[++w2] = z;
        }
    }
    // undo this level's Fenwick updates
    for (register int i = ll; i <= rr; i++)
        if (q[p[i]].c == 'C' && q[p[i]].x <= mid) add(q[p[i]].k, -q[p[i]].y);
    // write back: left group first, then right group, each still in input order
    for (register int i = 1; i <= w1; i++)
        p[ll + i - 1] = q1[i];
    for (register int i = 1; i <= w2; i++)
        p[ll + w1 + i - 1] = q2[i];
    solve(l, mid, ll, ll + w1 - 1);
    solve(mid + 1, r, ll + w1, rr);
}

// Turn the input into records: each initial value becomes a 'C' insert; each
// "C pos val" becomes a removal of the old value (y = -1) plus an insert of the
// new one (y = +1); each "Q l r k" becomes a 'Q' record. solve() then answers all
// queries over the value range [1, INF] (assumes values lie in that range).
int main() {
    scanf("%d%d", &n, &m);
    for (register int i = 1; i <= n; i++) {
        scanf("%d", &xx);
        len++;
        q[len].c = 'C', q[len].x = xx, q[len].y = 1, q[len].k = i, q[len].p = len, a[i] = q[len].x;
    }
    for (register int i = 1; i <= m; i++) {
        ch = getchar();
        while (ch != 'Q' && ch != 'C') ch = getchar();  // skip whitespace before the opcode
        if (ch == 'Q') {
            scanf("%d%d%d", &xx, &yy, &kk);
            len++;
            q[len].c = 'Q', q[len].x = xx, q[len].y = yy, q[len].k = kk, q[len].p = len;
        }
        else{
            scanf("%d%d", &xx, &yy);
            len++;
            q[len].c = 'C', q[len].x = a[xx], q[len].y = -1, q[len].k = xx, q[len].p = len;
            ++len;
            q[len].c = 'C', q[len].x = yy, q[len].y = 1, q[len].k = xx, q[len].p = len;
            a[xx] = yy;
        }
    }
    m = len;  // from here on m is the number of records
    for (register int i = 1; i <= m; i++) p[i] = i;
    memset(ans, -1, sizeof(ans));  // -1 marks records that are not queries
    solve(1, INF, 1, m);
    for (register int i = 1; i <= m; i++)
        if (ans[i] >= 0)
            printf("%d\n", ans[i]);
    return 0;
}

扩展中国剩余定理

#include
#include
#include
#include
#include
#include
#include
using namespace std;
// Extended CRT: solve x ≡ b[i] (mod a[i]) for moduli that need not be coprime.
#define ll __int128      // 128-bit integers so intermediate products do not overflow
#define MAXN 200000
ll a[MAXN], b[MAXN];     // a[i]: modulus, b[i]: remainder (1-based)
int n;                   // number of congruences

// Read a decimal integer from stdin into a 128-bit value. Characters before the
// first digit are skipped, and any '-' among them makes the result negative.
// At end of input it returns 0 instead of looping forever. The character is
// held in an int so that EOF can be told apart from a real byte.
inline __int128 read() {
    __int128 s = 0, w = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}
// Extended Euclid: return g = gcd(a, b) and set x, y so that a*x + b*y = g.
ll exgcd(ll a, ll b, ll &x, ll &y) {
    if (!b) {
        x = 1, y = 0;
        return a;
    }
    const ll g = exgcd(b, a % b, y, x);  // b*y + (a % b)*x = g
    y -= (a / b) * x;
    return g;
}
// Merge the congruences x ≡ b[i] (mod a[i]) one at a time and return the smallest
// non-negative solution. The running solution is x ≡ B (mod A). Each step solves
// A*t ≡ b[i] - B (mod a[i]) with exgcd, moves B by t*A, and sets A = lcm(A, a[i]).
// NOTE(review): assumes the system is solvable; C is never checked for
// divisibility by gcd(A, a[i]).
ll excrt() {
    ll A, B, C, D, x = 0, y = 0, gcd;
    A = a[1], B = b[1];
    for (register int i = 2; i <= n; i++) {
        C = b[i] - B;                  // shift needed to reach the new remainder
        gcd = exgcd(A, a[i], x, y);    // A*x + a[i]*y = gcd
        x = x * C / gcd;               // a t with A*t ≡ C (mod a[i])
        D = a[i] / gcd;
        x = (x % D + D) % D;           // smallest non-negative such t
        B = B + x * A;
        A = A / gcd * a[i];            // lcm(A, a[i])
        B = B % A;
    }
    return (B % A + A) % A;
}
// Read n pairs (modulus a[i], remainder b[i]) and print the smallest non-negative solution.
int main() {
    //freopen("pig.in", "r", stdin);
    //freopen("pig.out", "w", stdout);
    n = read();
    for (int i = 1; i <= n; i++)
        a[i] = read(), b[i] = read();
    // printf has no conversion for __int128; passing it to %lld was undefined
    // behaviour. Narrow explicitly (assumes the answer fits in long long).
    printf("%lld", static_cast<long long>(excrt()));
    return 0;
}

欧拉定理

#include
#include
#include
#include
#include
using namespace std;
// Extended Euler theorem: a^b mod m for a huge decimal exponent b.
#define ll long long

ll a, b, m, p;   // base a, exponent b (reduced while reading), modulus m, p = phi(m)

// Read a decimal exponent from stdin, reduced for the extended Euler theorem:
// once the value reaches p it is kept as (value mod p) + p, because
// a^b ≡ a^(b mod phi + phi) (mod m) whenever b >= phi.
// Characters before the first digit are skipped, and any '-' among them negates
// the result. At end of input it returns 0 instead of looping forever.
inline long long read(long long p) {
    long long s = 0, w = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; isdigit(c); c = getchar()) {
        s = (s << 1) + (s << 3) + (c ^ 48);
        if (s >= p)
            s = s % p + p;
    }
    return s * w;
}

// Euler's totient of n, by trial division up to sqrt(n).
inline ll phi(ll n) {
    if (n == 1) return 1;
    ll result = n;
    const ll limit = sqrt(n);
    for (ll d = 2; d <= limit; ++d) {
        if (n % d) continue;
        result = result / d * (d - 1);
        do n /= d; while (n % d == 0);
    }
    if (n > 1) result = result / n * (n - 1);  // one prime factor above sqrt remains
    return result;
}

// a^b mod m by binary exponentiation.
// Returns long long: the old `int` return truncated results >= 2^31 and was passed
// to printf("%lld"), which is undefined behaviour. The result starts at 1 % m so
// that anything mod 1 is 0, and a is reduced first so the products stay small.
inline long long pow(long long a, long long b, long long m) {
    long long ans = 1 % m;
    a %= m;
    while (b) {
        if (b % 2) (ans *= a) %= m;
        (a *= a) %= m;
        b /= 2;
    }
    return ans;
}

// Read a and m, then the (possibly huge) exponent b, and print a^b mod m.
int main() {
    scanf("%lld%lld", &a, &m);
    p = phi(m);
    b = read(p);  // exponent already reduced by the extended Euler theorem
    // Pass a genuine long long to %lld regardless of pow's declared return type.
    printf("%lld", static_cast<long long>(pow(a, b, m)));
    return 0;
}

点分治

#include
using namespace std;
// Centroid (point) divide and conquer: for each query k, decide whether some tree
// path has total weight exactly k.
#define SIZE 100010
#define INF 100010

struct NODE {
    long long next;     // next edge out of the same vertex
    long long ver;      // endpoint
    long long weight;   // edge weight
    #define next(x) edge[x].next
    #define ver(x) edge[x].ver
    #define weight(x) edge[x].weight
} edge[SIZE];

// k: queried path length; root: current centroid; ans: best largest-part size
// in get_root(); tot: number of entries in a[].
long long n, m, u, v, w, k, cnt, root, ans, sum, tot;
// b[v]: child of the centroid whose subtree contains v; a[1..tot]: vertices of the current component.
int head[SIZE], b[SIZE], a[SIZE];
long long size[SIZE];   // subtree sizes computed by get_root()
long long d[SIZE];      // distance to the current centroid
bool flag[SIZE];        // vertices already used as centroids
// NOTE(review): under `using namespace std`, unqualified `size` and `get` may be
// ambiguous with std::size / std::get; qualify uses as ::size / ::get.
bool get;               // set once a path of length k is found

// Append the weighted directed edge u -> v to u's adjacency list.
void add(int u, int v, int w) {
    ++cnt;
    ver(cnt) = v;
    weight(cnt) = w;
    next(cnt) = head[u];
    head[u] = cnt;
}

// Find the centroid of the component (of n vertices) containing u: the vertex
// whose largest remaining part after removal is smallest. The result is stored in
// `root`, with that part size in `ans` (callers reset ans to INF first).
// Also fills ::size[] with subtree sizes for this DFS rooted at the start vertex.
// `size` is written as `::size` to avoid ambiguity with std::size (C++17).
void get_root(long long n, long long u, long long fa) {
    long long Size = 0;
    ::size[u] = 1;
    for (int i = head[u]; i; i = next(i)) {
        int v = ver(i);
        if (v == fa || flag[v]) continue;
        get_root(n, v, u);
        ::size[u] += ::size[v];
        Size = max(Size, ::size[v]);
    }
    Size = max(Size, n - ::size[u]);  // the part "above" u
    if (Size < ans) root = u, ans = Size;
}

// Fill d[] (distance from the current centroid) and b[] (the centroid's child whose
// subtree contains the vertex). Called as dfs1(root, root), so the centroid gets b[root] = root.
void dfs1(long long u, long long fa) {
    b[u] = (fa == root) ? u : b[fa];
    for (int i = head[u]; i; i = next(i)) {
        const int v = ver(i);
        if (v != fa && !flag[v]) {
            d[v] = d[u] + weight(i);
            dfs1(v, u);
        }
    }
}

// Append every vertex of the current component, without crossing flagged vertices, to a[1..tot].
void dfs2(long long u, long long fa) {
    a[++tot] = u;
    for (int i = head[u]; i; i = next(i)) {
        const int v = ver(i);
        if (v != fa && !flag[v]) dfs2(v, u);
    }
}
// Order vertices by distance to the current centroid.
bool cmp(long long x, long long y) {
    return d[y] > d[x];
}
// Decide whether some path through the centroid o has length exactly k.
// The component's vertices are sorted by distance d[] and scanned with two
// pointers. A pair counts only if its endpoints lie in different child subtrees
// of o (b[] differs; o itself has b[o] == o).
// Equal-sum pairs from the same subtree move the pointer that cannot lose a valid
// partner. The old loop always dropped a[R], which misses a pair a[L'] -- a[R]
// where a[L'] has the same distance as a[L] but lies in another subtree.
// `get` is written as `::get` to avoid ambiguity with std::get.
void calc(long long o) {
    tot = 0;
    dfs2(o, o);
    sort(a + 1, a + tot + 1, cmp);
    long long L = 1, R = tot;
    while (L < R) {
        long long s = d[a[L]] + d[a[R]];
        if (s > k) {
            R--;
        } else if (s < k) {
            L++;
        } else if (b[a[L]] != b[a[R]]) {
            ::get = true;
            return;
        } else if (d[a[R - 1]] == d[a[R]]) {
            R--;  // an equally distant vertex remains, so a[R] is no longer needed
        } else {
            L++;  // no other remaining vertex has distance d[a[R]]: a[L] has no partner left
        }
    }
}

void f(long long n, long long o) {
    root = o;
    ans = INF;
    get_root(n, root, 0);
    d[root] = 0;
    dfs1(root, root);
    calc(root);
    if (get) return;
    flag[root] = true;
    for (register int i = head[root]; i; i = next(i)) {
        int v = ver(i);
        if (flag[v]) continue;
        f(size[v], v);
        if (get) return;
    }
}
int main() {
    scanf("%d%d", &n, &m);
    for (register int i = 1; i <= n - 1; i++) {
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w);
        add(v, u, w);
    }
    for (register int i = 1; i <= m; i++) {
    	scanf("%d", &k);
    	get = false;
    	memset(flag, false, sizeof(flag));
    	f(n, 1);
    	if (get) printf("AYE\n");
    	else printf("NAY\n");
    }
    return 0;
}

自适应辛普森法

#include
#include
#include
#include
#include
using namespace std;
/*inline int read() {
    int s = 0, w = 1;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}*/
// Integrand f(x) = (c*x + d) / (a*x + b), integrated over [l, r].
double a, b, c, d, l, r;
// The integrand (c*x + d) / (a*x + b).
inline double f(double x) {
    const double numerator = c * x + d;
    const double denominator = a * x + b;
    return numerator / denominator;
}
// Simpson's rule on [l, r]: (r - l) * (f(l) + f(r) + 4 f(mid)) / 6.
inline double simpson(double l, double r) {
    const double mid = (l + r) / 2.0;
    const double weighted = f(l) + f(r) + 4.0 * f(mid);
    return (r - l) * weighted / 6.0;
}
// Adaptive Simpson: split [l, r] until the two half-interval estimates agree with
// the whole-interval estimate `ans` to within 15 * eps, then add the Richardson
// correction. The tolerance is halved at every split so the errors of all
// sub-intervals sum to at most eps. Previously each half kept the full eps, so
// the error bound grew with the number of intervals.
double asr(double l, double r, double eps, double ans) {
    const double mid = (l + r) / 2;
    const double l_ans = simpson(l, mid), r_ans = simpson(mid, r);
    const double delta = l_ans + r_ans - ans;
    if (fabs(delta) <= eps * 15) return l_ans + r_ans + delta / 15;
    return asr(l, mid, eps / 2, l_ans) + asr(mid, r, eps / 2, r_ans);
}
inline double asr(double l, double r, double eps) {
    return asr(l, r, eps, simpson(l, r));
}
int main() {
    scanf("%lf%lf%lf%lf%lf%lf", &a, &b, &c, &d, &l, &r);
    const double integral = asr(l, r, 1e-7);
    printf("%.6lf", integral);
    return 0;
}

字符串哈希

#include
#include
#include
#include
using namespace std;
// Count distinct strings by hashing each one and counting distinct hash values.
#define ll unsigned long long
ll base = 131;        // polynomial base
ll a[10010];          // hash of each input string (1-based)
char s[10010];        // input buffer
int n, ans = 1;       // n strings; ans: number of distinct hashes
int prime = 233317; 
ll mod = 212370440130137957ll;
// Polynomial rolling hash of the C string s. Arithmetic is unsigned 64-bit, so
// the product wraps modulo 2^64 before `% mod` is applied; `prime` is added
// after every character.
ll hash(char s[]) {
    ll h = 0;
    for (const char *it = s; *it; ++it)
        h = (h * base + (ll)*it) % mod + prime;
    return h;
}
// Read n strings and print how many distinct hash values they produce.
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%s", s);
        // Qualified call: an unqualified `hash` is ambiguous with std::hash under `using namespace std`.
        a[i] = ::hash(s);
    }
    sort(a + 1, a + n + 1);
    // ans starts at 1 (for a[1]); each change between neighbours adds a new distinct value.
    for (int i = 1; i < n; i++)
        if (a[i] != a[i + 1]) ans++;
    if (n <= 0) ans = 0;  // no strings, no distinct values
    printf("%d", ans);
    return 0;
}

线筛素数

#include
using namespace std;
// Sieve capacity: numbers up to SIZE - 1 can be stored.
#define SIZE 10000010
// n = sieve limit, m = number of queries, x = current query, k = number of primes found.
int n, m, x, k;
// v[i] = smallest prime factor of i (0 for i < 2 and for i > n); prime[1..k] = primes found, increasing.
int v[SIZE], prime[SIZE];

int main(){
	cin >> n >> m;
	for (register int i = 2; i <= n; i++) {
		if (v[i] == 0) {
			v[i] = i;
			prime[++k] = i; 
		}
		for (register int j = 1; j <= k; j++) {
			if (prime[j] > v[i] || prime[j] > n / i) break;
			v[prime[j] * i] = prime[j];
		}
	}
	for (register int i = 1; i <= m; ++i){
		cin >> x;
		if (v[x] == x)
			cout << "Yes" << endl;
		else
			cout << "No" << endl;
	}
	return 0;
} 

卢卡斯定理

#include
#include
#include
#include
#include
using namespace std;
#define ll long long
// Factorial table capacity: supports moduli p < MAXN.
#define MAXN 100010

// T = number of test cases; n, m = query arguments; p = modulus of the current case (assumed prime: C() uses Fermat inverses).
ll T, n, m, p;
// a[i] = i! mod p, rebuilt for each test case in main().
ll a[MAXN];

// Reads the next (optionally negative) integer from stdin, skipping any non-digit characters before it.
// Returns 0 if input ends before a digit is found, instead of spinning on EOF forever.
// `c` is an int so that it can hold EOF and is always a valid argument for isdigit().
inline long long read() {
    long long s = 0, w = 1;
    int c = getchar();
    for (; !isdigit(c); c = getchar()) {
        if (c == EOF) return 0;
        if (c == '-') w = -1;
    }
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}

// Modular exponentiation by squaring: returns y^z mod p, always in [0, p).
// Expects z >= 0 and p >= 1; intermediate products fit in long long while p < ~3e9.
inline long long pow(long long y, long long z, long long p) {
    // Normalise the base so a negative y still yields a non-negative residue.
    y = (y % p + p) % p;
    // Start from 1 % p so that p == 1 correctly yields 0.
    long long ans = 1 % p;
    while (z) {
        if (z % 2) (ans *= y) %= p;
        (y *= y) %= p;
        z /= 2;
    }
    return ans;
}

// Binomial coefficient C(n, m) mod p for 0 <= n, m < p.
// Uses the factorial table a[] and Fermat inverses (p assumed prime). Returns 0 when m > n.
ll C(ll n, ll m) {
    if (n < m) return 0;
    const ll inv_m = pow(a[m], p - 2, p);
    const ll inv_nm = pow(a[n - m], p - 2, p);
    return a[n] * inv_m % p * inv_nm % p;
}

// Lucas' theorem: C(n, m) mod p is the product over base-p digits of C(n_i, m_i) mod p.
// Walks the digits from least significant upward until m runs out.
ll lucas(ll n, ll m) {
    ll result = 1;
    for (; m; n /= p, m /= p)
        result = result * C(n % p, m % p) % p;
    return result;
}

// For each test case: reads n, m, p, rebuilds a[i] = i! mod p for i in [0, p], and prints C(n + m, m) mod p.
int main() {
    T = read();
    while (T--) {
        n = read();
        m = read();
        p = read();
        a[0] = 1;
        for (int i = 1; i <= p; ++i)
            a[i] = a[i - 1] * i % p;
        printf("%lld\n", lucas(n + m, m));
    }
    return 0;
}

BSGS

#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define ll long long
// p = modulus, b = base, n = target residue; m = block size ceil(sqrt(p)); now = running power; t = b^m mod p; ans = result.
ll p, b, n, m, now, t, ans;
// Baby-step table: value n * b^j mod p -> j (later j overwrite earlier ones, so each value keeps its largest j).
map <ll, int> mp;
// Reads the next (optionally negative) integer from stdin, skipping any non-digit characters before it.
// Returns 0 if input ends before a digit is found, instead of spinning on EOF forever.
// `c` is an int so that it can hold EOF and is always a valid argument for isdigit().
inline long long read() {
    long long s = 0, w = 1;
    int c = getchar();
    for (; !isdigit(c); c = getchar()) {
        if (c == EOF) return 0;
        if (c == '-') w = -1;
    }
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}
// Returns b^x mod p for the global base b and modulus p, by binary exponentiation.
ll Pow(ll x){
    ll result = 1;
    ll power = b;
    for (; x; x >>= 1) {
        if (x & 1) result = result * power % p;
        power = power * power % p;
    }
    return result;
}
int main() {
    p = read(), b = read(), n = read();
    if (b % p == 0) {
        printf("no solution");
        return 0;
    }
    mp.clear();
    m = ceil(sqrt(p));
    now = n % p;
    mp[now] = 0;
    for (register int i = 1; i <= m; i++) {
        now = (now * b) % p;
        mp[now] = i;
    }
    t = Pow(m);
    now = 1;
    for (register int i = 1; i <= m; i++) {
        now = (now * t) % p;
        if (mp[now]) {
            ans = i * m - mp[now];
            printf("%lld\n", (ans % p + p) % p);
            return 0;
        }
    }
    printf("no solution\n");
    return 0;
    return 0;
}

莫比乌斯反演

#include
using namespace std;
#define ll long long
// Sieve range: mu and its prefix sums are precomputed for indices in [1, N).
#define N 100009
// T = number of queries; pri[1..cnt] = primes found by the sieve; mu[i] = Möbius function of i.
int T, pri[N], mu[N], cnt;
// sum[i] = mu[1] + ... + mu[i].
ll sum[N];
// flag[i] = true when the sieve has marked i as composite.
bool flag[N];
// Linear sieve computing the Möbius function mu[1..N-1] and its prefix sums sum[].
void Cirno() {
    mu[1] = 1;
    for (int i = 2; i < N; ++i) {
        if (!flag[i]) {
            pri[++cnt] = i;
            mu[i] = -1;
        }
        for (int j = 1; j <= cnt && i * pri[j] < N; ++j) {
            const int composite = i * pri[j];
            flag[composite] = true;
            if (i % pri[j] == 0) {
                // pri[j]^2 divides composite, so mu vanishes; stopping here marks each number only once.
                mu[composite] = 0;
                break;
            }
            mu[composite] = -mu[i];
        }
    }
    for (int i = 1; i < N; ++i)
        sum[i] = sum[i - 1] + mu[i];
}
void solve(int n, int m, int k) {
    ll ans = 0;
    n /= k, m /= k;
    int lim = min(n, m);
    for (register int i = 1; i <= lim;) {
        ll j = min(n / (n / i), m / (m / i));
        ans += 1 * (sum[j] - sum[i - 1]) * (n / i) * (m / i);
        i = j + 1;
    }
    printf("%lld\n", ans);
}
// Precomputes mu, then answers T queries of the form "n m k" via solve().
int main() {
    Cirno();
    scanf("%d", &T);
    for (int query = 0; query < T; ++query) {
        int n, m, k;
        scanf("%d%d%d", &n, &m, &k);
        solve(n, m, k);
    }
    return 0;
}

可持久化平衡树

// luogu-judger-enable-o2
// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include
#include
#include
#include
#include
#include
#include
using namespace std;
// Capacity for the number of versions (one per operation).
#define MAXN 500005
// n = operation count; opt, a = current operation and its argument; cnt = nodes allocated so far;
// x, y, z = scratch subtree roots; i = index of the version being built; tim = version it starts from.
int n, opt, a, cnt, x, y, z, i, tim;
// root[v] = treap root of version v (0 = empty tree).
int root[MAXN];
// Node pool for the persistent treap; index 0 is never allocated and acts as the empty node (size 0).
// l, r = children; val = stored value (BST order); key = random priority; size = subtree size.
struct node {
    int l, r, val, key, size;
} t[MAXN * 50];
// Reads the next (optionally negative) integer from stdin, skipping any non-digit characters before it.
// Returns 0 if input ends before a digit is found, instead of spinning on EOF forever.
// `c` is an int so that it can hold EOF and is always a valid argument for isdigit().
inline int read() {
    int s = 0, w = 1;
    int c = getchar();
    for (; !isdigit(c); c = getchar()) {
        if (c == EOF) return 0;
        if (c == '-') w = -1;
    }
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}
// Allocates a fresh leaf node holding val with a random priority and returns its index.
inline int New(int val) {
    ++cnt;
    t[cnt].val = val;
    // rand() * rand() in int overflows (undefined behaviour) when RAND_MAX is 2^31 - 1.
    // Multiplying as unsigned wraps instead. With RAND_MAX = 32767 the product fits in int and is unchanged.
    t[cnt].key = (int)((unsigned)rand() * (unsigned)rand());
    t[cnt].size = 1;
    return cnt;
}
inline void update(int now) {
    t[now].size = t[t[now].l].size + t[t[now].r].size + 1;
}
// Persistent split by value: u receives the nodes with val <= w, v receives the rest.
// Every node on the search path is copied, so the tree rooted at `now` is left untouched.
inline void Split(int now, int w, int &u, int &v) {
    if (now == 0) {
        u = v = 0;
        return;
    }
    if (t[now].val > w) {
        v = ++cnt;
        t[v] = t[now];
        Split(t[v].l, w, u, t[v].l);
        update(v);
    } else {
        u = ++cnt;
        t[u] = t[now];
        Split(t[u].r, w, t[u].r, v);
        update(u);
    }
}
// Merges treaps u and v, where every value in u is <= every value in v.
// The root with the smaller key stays on top. Returns the new root.
// NOTE(review): nodes are modified in place, not copied. This is only safe for persistence if the
// right spine of u and the left spine of v were freshly copied by Split; confirm at each call site.
inline int Merge(int u, int v) {
    if (u == 0) return v;
    if (v == 0) return u;
    if (t[u].key >= t[v].key) {
        t[v].l = Merge(u, t[v].l);
        update(v);
        return v;
    }
    t[u].r = Merge(t[u].r, v);
    update(u);
    return u;
}
// Inserts val into version root[i], where i is the global index of the version being built.
inline void Insert(int val) {
    int lower, upper;
    Split(root[i], val, lower, upper);
    const int fresh = New(val);
    root[i] = Merge(Merge(lower, fresh), upper);
}

// Returns the index of the sum-th smallest node (1-based) in the treap rooted at now.
// Returns 0 (the empty node) when sum is outside [1, size]. Without the `now != 0`
// guard, the loop never terminates for sum <= 0, e.g. Kth(0, 0) on an empty tree.
inline int Kth(int now, int sum) {
    while (now) {
        const int left_size = t[t[now].l].size;
        if (sum <= left_size) now = t[now].l;
        else if (sum == left_size + 1) return now;
        else sum -= left_size + 1, now = t[now].r;
    }
    return 0;
}
// Processes n versioned operations. Each line is "tim opt a": version i starts as a copy of version tim, then
// 1 = insert a, 2 = delete one copy of a, 3 = rank of a (count of values < a, plus 1), 4 = a-th smallest,
// 5 = predecessor (largest value < a), 6 = successor (smallest value > a).
int main() {
    srand(time(0));
    n = read();
    while (n--) {
        i++;
        tim = read();
        root[i] = root[tim];
        opt = read(), a = read();
        switch (opt) {
            case 1 : {
                Insert(a);
                break;
            }
            case 2 : {
                // Isolate the nodes equal to a in y, then drop y's root (one copy of a).
                // NOTE(review): Merge works in place; t[y].l / t[y].r may contain nodes shared with
                // older versions when a occurs several times. Confirm this cannot alter earlier versions.
                Split(root[i], a, x, z);
                Split(x, a - 1, x, y);
                y = Merge(t[y].l, t[y].r);
                root[i] = Merge(Merge(x, y), z);
                break;
            }
            case 3 : {
                Split(root[i], a - 1, x, y);
                printf("%d\n", t[x].size + 1);
                root[i] = Merge(x, y);
                break;
            }
            case 4 : {
                printf("%d\n", t[Kth(root[i], a)].val);
                break;
            }
            case 5 : {
                Split(root[i], a - 1, x, y);
                // With no value below a, x is empty and Kth(x, 0) would never return.
                // The sentinel matches the Luogu P3835 statement; NOTE(review): confirm against the judge.
                if (t[x].size == 0) printf("%d\n", -2147483647);
                else printf("%d\n", t[Kth(x, t[x].size)].val);
                root[i] = Merge(x, y);
                break;
            }
            case 6 : {
                Split(root[i], a, x, y);
                // With no value above a, y is empty and the old code printed the empty node's val (0).
                // The sentinel matches the Luogu P3835 statement; NOTE(review): confirm against the judge.
                if (t[y].size == 0) printf("%d\n", 2147483647);
                else printf("%d\n", t[Kth(y, 1)].val);
                root[i] = Merge(x, y);
                break;
            }
        }
    }
    return 0;
}

可持久化文艺平衡树

#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define ll long long
// Capacity for the number of versions (one per operation).
#define MAXN 200050
// Node pool for the persistent sequence treap (ordered by position, split by subtree size).
// l, r = children; val = element value; key = random priority; sum = sum of val over the subtree;
// tag = pending "reverse this subtree" flag; size = subtree size. Index 0 is the empty node.
struct node {
	ll l, r;
	ll val, key, sum;
	ll tag, size;
} t[MAXN * 60];
// T = operation count; tim = version the current operation starts from; opt = operation type;
// i = index of the version being built; cnt = nodes allocated; a, b = arguments (XOR-decoded with lastans);
// lastans = last printed answer.
ll T, tim, opt, i, cnt, a, b, lastans;
// root[v] = root of version v.
ll root[MAXN];
// Reads the next (optionally negative) integer from stdin, skipping any non-digit characters before it.
// Returns 0 if input ends before a digit is found, instead of spinning on EOF forever.
// `c` is an int so that it can hold EOF and is always a valid argument for isdigit().
inline long long read() {
	long long s = 0, w = 1;
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return 0;
		if (c == '-') w = -1;
	}
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}
// Allocates a fresh leaf node holding val (subtree sum = val, size = 1) with a random priority.
ll New(ll val) {
	++cnt;
	t[cnt].val = val;
	t[cnt].sum = val;
	// Multiply in 64 bits: rand() * rand() as int overflows (undefined behaviour) when RAND_MAX is 2^31 - 1.
	t[cnt].key = (ll)rand() * rand();
	t[cnt].size = 1;
	return cnt;
}
// Recomputes the subtree size and value sum of node `now` from its children.
void update(ll now) {
	const ll lson = t[now].l, rson = t[now].r;
	t[now].size = 1 + t[lson].size + t[rson].size;
	t[now].sum = t[lson].sum + t[rson].sum + t[now].val;
}
// Pushes a pending reversal from `now` down to its children.
// The children are copied first, so versions that share the old child nodes are unaffected.
void pushdown(ll now) {
	if (t[now].tag == 0) return;
	ll &lson = t[now].l, &rson = t[now].r;
	if (lson) {
		const ll copy = ++cnt;
		t[copy] = t[lson];
		lson = copy;
	}
	if (rson) {
		const ll copy = ++cnt;
		t[copy] = t[rson];
		rson = copy;
	}
	swap(lson, rson);
	if (lson) t[lson].tag ^= 1;
	if (rson) t[rson].tag ^= 1;
	t[now].tag = 0;
}
// Persistent split by position: u receives the first w elements, v receives the rest.
// Pending reversals are pushed down first, and nodes on the path are copied.
void Split(ll now, ll w, ll &u, ll &v) {
	if (!now) {
		u = v = 0;
		return;
	}
	pushdown(now);
	const ll left_size = t[t[now].l].size;
	if (left_size < w) {
		u = ++cnt;
		t[u] = t[now];
		Split(t[u].r, w - left_size - 1, t[u].r, v);
		update(u);
	} else {
		v = ++cnt;
		t[v] = t[now];
		Split(t[v].l, w, u, t[v].l);
		update(v);
	}
}
// Concatenates sequence u followed by sequence v. The root with the smaller key stays on top.
// Pending reversals on both roots are pushed down before relinking. Returns the new root.
ll Merge(ll u, ll v) {
	if (!u) return v;
	if (!v) return u;
	pushdown(u);
	pushdown(v);
	if (t[u].key >= t[v].key) {
		t[v].l = Merge(u, t[v].l);
		update(v);
		return v;
	}
	t[u].r = Merge(t[u].r, v);
	update(u);
	return u;
}
// Inserts value b after the first a elements of version tim and stores the result as version i.
// Returns the new root. The function is declared to return ll, and falling off its end was undefined behaviour.
ll Insert(ll a, ll b) {
	ll x = 0, y = 0;
	Split(root[tim], a, x, y);
	root[i] = Merge(Merge(x, New(b)), y);
	return root[i];
}
// Removes the a-th element (1-based) of version tim and stores the result as version i.
// Returns the new root. The function is declared to return ll, and falling off its end was undefined behaviour.
ll Delete(ll a) {
	ll x = 0, y = 0, z = 0;
	Split(root[tim], a, x, z);
	Split(x, a - 1, x, y);
	root[i] = Merge(x, z);
	return root[i];
}
int main() {
	T = read();
	while (T--) {
		i++;
		tim = read();
		opt = read();
		if (opt == 1) {
			a = read(), b = read();
			a ^= lastans, b ^= lastans;
			Insert(a, b);
		}
		if (opt == 2) {
			a = read();
			a ^= lastans;
			Delete(a);
		}
		if (opt == 3) {
			ll x = 0, y = 0, z = 0;
			a = read(), b = read();
			a ^= lastans, b ^= lastans;
			Split(root[tim], b, x, z);
            Split(x, a - 1, x, y);
            t[y].tag ^= 1;
            root[i] = Merge(Merge(x, y), z);
		}
		if (opt == 4) {
			ll x = 0, y = 0, z = 0;
			a = read(), b = read();
            a ^= lastans,b ^= lastans;
            Split(root[tim], b, x, z);
            Split(x, a - 1, x, y);
            printf("%lld\n", lastans = t[y].sum);
            root[i] = Merge(Merge(x, y), z);
		}
	}
	return 0;
} 

Tarjan（缩点）

#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define MAXN 100010
// Edge pool in forward-star form: nxt = next edge leaving the same vertex, ver = target vertex.
struct rec {
	int nxt, ver;
} t[MAXN];
// n, m = vertex/edge counts; cnt = edges stored; tot = number of SCCs; deep = DFS timestamp; Ans = final answer.
int n, m, cnt, tot, deep, Ans;
// head = first edge per vertex; val = vertex weight; Val = total weight per SCC; dfn/low = Tarjan timestamps;
// color = SCC id per vertex; ans = best path weight ending at each SCC in the condensed graph;
// u/v = endpoints of the original edges; in = in-degree in the condensed graph.
int head[MAXN], val[MAXN], Val[MAXN], dfn[MAXN], low[MAXN], color[MAXN], ans[MAXN], u[MAXN], v[MAXN], in[MAXN];
// vis[x] = x is currently on the Tarjan stack S.
bool vis[MAXN];
// S = Tarjan's vertex stack; q = work queue for the topological pass.
stack <int> S;
queue <int> q;
// Reads the next (optionally negative) integer from stdin, skipping any non-digit characters before it.
// Returns 0 if input ends before a digit is found, instead of spinning on EOF forever.
// `c` is an int so that it can hold EOF and is always a valid argument for isdigit().
inline int read() {
	int s = 0, w = 1;
	int c = getchar();
	for (; !isdigit(c); c = getchar()) {
		if (c == EOF) return 0;
		if (c == '-') w = -1;
	}
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}
// Appends the directed edge u -> v to the forward-star lists.
inline void add(int u, int v) {
	++cnt;
	t[cnt].ver = v;
	t[cnt].nxt = head[u];
	head[u] = cnt;
}
// Tarjan's strongly-connected-components DFS from vertex u.
// When u is a component root, its component is popped from S, numbered tot,
// recorded in color[], and its vertex weights are summed into Val[tot].
void tarjan(int u) {
	dfn[u] = low[u] = ++deep;
	S.push(u);
	vis[u] = true;
	for (int e = head[u]; e; e = t[e].nxt) {
		const int to = t[e].ver;
		if (dfn[to] == 0) {
			tarjan(to);
			if (low[to] < low[u]) low[u] = low[to];
		} else if (vis[to] && dfn[to] < low[u]) {
			low[u] = dfn[to];
		}
	}
	if (low[u] != dfn[u]) return;
	++tot;
	int top;
	do {
		top = S.top();
		S.pop();
		vis[top] = false;
		Val[tot] += val[top];
		color[top] = tot;
	} while (top != u);
}
inline void topsort() {
	while (!q.empty()) {
		int u = q.front();
		q.pop();
		for (register int i = head[u]; i; i = t[i].nxt) {
			int v = t[i].ver;
			ans[v] = max(ans[v], ans[u] + Val[v]);
			if (!(--in[v])) q.push(v); 
		}
	}
}
int main() {
	n = read(), m = read();
	for (register int i = 1; i <= n; i++)
		val[i] = read();
	for (register int i = 1; i <= m; i++) {
		u[i] = read(), v[i] = read();
		add(u[i], v[i]);
	}
	for (register int i = 1; i <= n; i++)
		if (!dfn[i])
			tarjan(i);
	memset(head, 0, sizeof(head));
	cnt = 0;
	memset(t, 0, sizeof(t));
	for (register int i = 1; i <= m; i++)
		if (color[u[i]] != color[v[i]]) {
			add(color[u[i]], color[v[i]]);
			in[color[v[i]]]++;
		}
	for (register int i = 1; i <= tot; i++)
		if (!in[i]) q.push(i), ans[i] = Val[i];
	topsort();
	for (register int i = 1; i <= tot; i++)
		Ans = max(Ans, ans[i]);
	printf("%d", Ans);
	return 0;
}

你可能感兴趣的:(板子)