ACM模板

动态规划

背包

01背包

//物品数量为n 背包容量为v
//第i个物品的价值为val[i] 体积为vol[i]
//最终结果为dp[v]
// 0/1 knapsack state: n items and a knapsack of capacity v.
// Item i has value val[i] and volume vol[i]; after solve(), dp[v] is the answer.
const int maxn = 1010;
int dp[maxn];   // dp[c]: best total value using total volume at most c
int vol[maxn];  // vol[i]: volume of item i
int val[maxn];  // val[i]: value of item i
int n, v;       // number of items, knapsack capacity

// Fills dp[] with the one-dimensional 0/1 knapsack recurrence.
void solve() {
    memset(dp, 0, sizeof(dp));
    for(int item = 0; item < n; ++item) {
        const int need = vol[item];
        const int gain = val[item];
        // Capacities go from high to low so dp[cap - need] still describes the
        // state before this item was considered: each item is used at most once.
        for(int cap = v; cap >= need; --cap) {
            const int take = dp[cap - need] + gain;
            dp[cap] = max(dp[cap], take);
        }
    }
}

完全背包

// Complete (unbounded) knapsack state: every item may be taken any number of times.
const int maxn = 1010;
int dp[maxn];   // dp[c]: best total value using total volume at most c
int vol[maxn];  // vol[i]: volume of item i
int val[maxn];  // val[i]: value of item i
int n, v;       // number of items, knapsack capacity

// Fills dp[] with the one-dimensional complete-knapsack recurrence; dp[v] is the answer.
void solve() {
    memset(dp, 0, sizeof(dp));
    for(int item = 0; item < n; ++item) {
        const int need = vol[item];
        const int gain = val[item];
        // Ascending capacities: dp[cap - need] may already contain this item,
        // which is exactly what allows unlimited copies.
        for(int cap = need; cap <= v; ++cap) {
            dp[cap] = max(dp[cap], dp[cap - need] + gain);
        }
    }
}

多重背包

//恰好装满 hdu2844
#include
#include
#include
#include
#include
using namespace std;

const int maxn = 110;
const int maxm = 100010;
int val[maxn];  // val[i]: face value of coin type i
int num[maxn];  // num[i]: number of coins of type i (consumed by solve())
int dp[maxm];   // dp[s] == s when amount s can be paid exactly, -1 when unreachable
int n, m;       // number of coin types, largest amount of interest

// Counts how many amounts in 1..m can be paid exactly and prints that count.
// Bounded multiplicities are handled by binary splitting; num[] is modified.
void solve() {
    // One 0/1-knapsack pass for a single bundle worth `weight`.
    auto zeroOnePass = [](int weight) {
        for(int s = m; s >= weight; --s) {
            if(dp[s - weight] != -1) {
                dp[s] = max(dp[s], dp[s - weight] + weight);
            }
        }
    };

    memset(dp, -1, sizeof(dp));
    dp[0] = 0;
    for(int i = 0; i < n; ++i) {
        const int coin = val[i];
        if(coin * num[i] >= m) {
            // Enough copies to reach every amount up to m: treat the coin as unbounded.
            for(int s = coin; s <= m; ++s) {
                if(dp[s - coin] != -1) {
                    dp[s] = max(dp[s], dp[s - coin] + coin);
                }
            }
            continue;
        }
        // Binary splitting: bundles of 1, 2, 4, ... coins, then whatever remains.
        int bundle = 1;
        while(num[i] > bundle) {
            zeroOnePass(bundle * coin);
            num[i] -= bundle;
            bundle <<= 1;
        }
        zeroOnePass(num[i] * coin);
    }

    int ans = 0;
    for(int s = 1; s <= m; ++s) {
        ans += (dp[s] == s) ? 1 : 0;
    }
    printf("%d\n", ans);
}

// Reads test cases "n m", then n coin values and n coin counts, until "0 0"
// or end of input, and solves each one.
int main() {
    // Compare against 2: at end of input scanf returns EOF (-1), which is
    // non-zero, so the original condition kept looping forever.
    while(scanf("%d%d", &n, &m) == 2 && (n || m)) {
        for(int i = 0; i < n; i++) {
            scanf("%d", val + i);
        }
        for(int i = 0; i < n; i++) {
            scanf("%d", num + i);
        }
        solve();
    }
    return 0;
}

多重背包判断可行性

//每种有若干件的物品能否填满给定容量的背包
//O(NV)
// hdu2844
#include
#include
#include
#include
#include
using namespace std;

const int maxn = 101;
const int maxm = 100001;
int val[maxn];  // 1-based: val[i] is the size of item type i
int num[maxn];  // 1-based: num[i] is how many copies of type i exist
int dp[maxm];   // dp[s]: copies of the current type still unused after reaching s, or -1
int n, m;

// O(n * m) multiple-knapsack feasibility: prints how many totals in 1..m are reachable.
void solve() {
    dp[0] = 0;
    for(int s = 1; s <= m; ++s) {
        dp[s] = -1;
    }

    for(int i = 1; i <= n; ++i) {
        // Every total reachable without type i starts with all num[i] copies unused.
        for(int s = 0; s <= m; ++s) {
            if(dp[s] >= 0) {
                dp[s] = num[i];
            } else {
                dp[s] = -1;
            }
        }
        // Extending total s by one more copy leaves one copy fewer.
        const int w = val[i];
        for(int s = 0; s + w <= m; ++s) {
            if(dp[s] > 0 && dp[s] - 1 > dp[s + w]) {
                dp[s + w] = dp[s] - 1;
            }
        }
    }

    int reachable = 0;
    for(int s = 1; s <= m; ++s) {
        if(dp[s] != -1) {
            ++reachable;
        }
    }
    printf("%d\n", reachable);
}

// Reads test cases "n m", then n item sizes and n counts (1-based arrays),
// until "0 0" or end of input.
int main() {
    // "== 2" so that EOF (scanf returns -1, which is non-zero) ends the loop.
    while(scanf("%d%d", &n, &m) == 2 && (n || m)) {
        for(int i = 1; i <= n; i++) {
            scanf("%d", val + i);
        }
        for(int i = 1; i <= n; i++) {
            scanf("%d", num + i);
        }
        solve();
    }
    return 0;
}

数论

组合数求模

卢卡斯定理

//FZU 2020
//要求p为素数
//适合使用的场景:
//1 <= m <= n <= 1e18 , p <= 1e5
typedef long long LL;

// Returns n^r mod `mod` by binary exponentiation (r >= 0).
LL quickPowMod(LL n, LL r, LL mod) {
    LL res = 1;
    n %= mod;
    while(r) {
        if(r & 1) {
            res = (res * n) % mod;
        }
        n = (n * n) % mod;
        r >>= 1;
    }
    return res;
}

// Modular inverse via Fermat's little theorem: requires `mod` prime and
// n not divisible by mod.
LL inverse(LL n, LL mod) {
    return quickPowMod(n, mod - 2, mod);
}

// C(n, m) mod `mod` for 0 <= n < mod with `mod` prime (this is how lucas() calls it).
// Returns 0 when m < 0 or m > n.
// Numerator and denominator are accumulated separately so that only one
// modular inversion is needed (O(m + log mod)) instead of one per factor.
LL C(LL n, LL m, LL mod) {
    if(m < 0 || m > n) {
        return 0;
    }
    LL numer = 1;
    LL denom = 1;
    for(LL i = 1; i <= m; i++) {
        numer = numer * ((n - m + i) % mod) % mod;
        denom = denom * (i % mod) % mod;  // i <= m < mod, so never 0 mod `mod`
    }
    return numer * inverse(denom, mod) % mod;
}

// Lucas' theorem: C(n, m) = C(n mod p, m mod p) * C(n / p, m / p)  (mod p), p prime.
LL lucas(LL n, LL m, LL mod) {
    if(m == 0) {
        return 1;
    }
    return C(n % mod, m % mod, mod) * lucas(n / mod, m / mod, mod) % mod;
}

int main() {
    int t;
    LL n, m, mod;
    cin >> t;
    while(t--) {
        cin >> n >> m >> mod;
        cout << lucas(n, m, mod) << endl;
    }
    return 0;
}

扩展卢卡斯定理

计算 $C_n^m \bmod p$,其中 $1 \le m \le n \le 10^{18}$,$2 \le p \le 10^{6}$,且 $p$ 不保证为质数。

//洛谷4720模板
//Gym - 100633J
#include
#include
#include
#include
#include
#include
#include
using namespace std;
typedef long long LL;

// Square-and-multiply: returns n^r mod p for r >= 0.
LL quickPowerMod(LL n, LL r, LL p) {
    LL base = n % p;
    LL acc = 1;
    for(; r; r >>= 1) {
        if(r & 1) {
            acc = acc * base % p;
        }
        base = base * base % p;
    }
    return acc;
}

// Extended Euclid: returns g = gcd(a, b) and sets x, y with a*x + b*y = g.
LL exgcd(LL a, LL b, LL &x, LL &y) {
    if(b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    // Solve for (b, a mod b) with x and y swapped, then correct y.
    const LL g = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return g;
}

// Modular inverse of a modulo m via the extended Euclidean algorithm,
// normalised into [0, m); returns -1 when gcd(a, m) != 1 (no inverse).
LL inverse(LL a, LL m) {
    LL x, y;
    const LL g = exgcd(a, m, x, y);
    return g == 1 ? (x % m + m) % m : -1;
}

// Chinese Remainder Theorem for pairwise-coprime moduli m[i]:
// returns the x in [0, M) with x = b[i] (mod m[i]) for every i,
// where M is the product of all m[i] (M must fit in a long long).
// The element type was lost from the original listing (`vector &b`); it is LL.
LL CRT(vector<LL> &b, vector<LL> &m) {
    LL M = 1;
    for(LL mod : m) {
        M *= mod;
    }
    LL ans = 0;
    for(size_t i = 0; i < m.size(); i++) {
        LL Mi = M / m[i];
        // Reduce b[i] and the inverse modulo m[i] before multiplying: the
        // original b[i] * Mi * inv can exceed 64 bits, whereas t < m[i]
        // guarantees t * Mi < M.
        LL r = (b[i] % m[i] + m[i]) % m[i];
        LL t = r * inverse(Mi, m[i]) % m[i];
        ans = (ans + t * Mi) % M;
    }
    return (ans + M) % M;
}

//计算n!中因子p的个数
// Legendre's formula: the exponent of prime p in n! is sum over k >= 1 of floor(n / p^k).
LL calculateExp(LL n, LL p) {
    LL total = 0;
    for(LL q = n / p; q != 0; q /= p) {
        total += q;
    }
    return total;
}

//计算n!中除去所有因子p之后模mod(p^x)的值
// Returns (n! with every factor p removed) mod `mod`, where mod = p^x.
// n! = (product of the p-free numbers <= n) * p^(n/p) * (n/p)!, and the p-free
// product repeats with period `mod`, so each level costs one period plus a tail.
LL facModP(LL n, LL mod, LL p) {
    LL acc = 1;
    while(n != 0) {
        LL part = 1;
        if(n / mod != 0) {
            // Product of the p-free numbers in one full period [2, mod].
            for(LL i = 2; i <= mod; i++) {
                if(i % p != 0) {
                    part = part * i % mod;
                }
            }
            part = quickPowerMod(part, n / mod, mod);
        }
        // Leftover tail [2, n mod mod].
        const LL tail = n % mod;
        for(LL i = 2; i <= tail; i++) {
            if(i % p != 0) {
                part = part * i % mod;
            }
        }
        acc = acc * part % mod;
        n /= p;  // the multiples of p contribute (n / p)! once p is stripped
    }
    return acc;
}

//计算C(n, m) % p^x 的值
// C(n, m) mod p^x (mod = p^x, p prime): the p-free factorial parts are
// invertible modulo p^x, so divide by them and multiply back p^cnt, where cnt
// is the exponent of p in C(n, m).
LL CmodP(LL n, LL m, LL mod, LL p) {
    const LL cnt = calculateExp(n, p) - calculateExp(m, p) - calculateExp(n - m, p);
    const LL top = facModP(n, mod, p);
    const LL invM = inverse(facModP(m, mod, p), mod);
    const LL invNM = inverse(facModP(n - m, mod, p), mod);
    const LL res = top * invM % mod * invNM % mod;
    return res * quickPowerMod(p, cnt, mod) % mod;
}

//计算C(n, m) % mod(合数)的值
// C(n, m) mod `mod` for any modulus >= 2 (extended Lucas): factor mod into prime
// powers p^k, solve each with CmodP, and recombine the residues with CRT.
// Returns 0 when m < 0 or m > n; previously those inputs reached
// calculateExp / facModP with a negative argument and produced garbage.
LL CmodX(LL n, LL m, LL mod) {
    if(m < 0 || m > n) {
        return 0;
    }
    vector<LL> b, md;  // residues and matching prime-power moduli (type lost in the original listing)
    for(LL i = 2; i * i <= mod; i++) {
        if(mod % i == 0) {
            LL pk = 1;
            while(mod % i == 0) {
                pk *= i;
                mod /= i;
            }
            md.push_back(pk);
            b.push_back(CmodP(n, m, pk, i));
        }
    }
    if(mod > 1) {
        // The remaining cofactor is a prime that occurs exactly once.
        md.push_back(mod);
        b.push_back(CmodP(n, m, mod, mod));
    }
    return CRT(b, md);
}

// Reads "n m p" triples until end of input and prints C(n, m) mod p.
int main() {
    LL n, m, p;
    for(;;) {
        if(!(cin >> n >> m >> p)) {
            break;
        }
        cout << CmodX(n, m, p) << endl;
    }
    return 0;
}

模线性方程组

中国剩余定理

typedef long long LL;
const int maxn = 1000;
LL b[maxn];  // remainders
LL m[maxn];  // pairwise-coprime moduli

// Extended Euclid: returns gcd(a, b) and sets x, y with a*x + b*y = gcd(a, b).
LL extendGcd(LL a, LL b, LL &x, LL &y) {
    if(b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    LL res = extendGcd(b, a % b, x, y);
    LL temp = x;
    x = y;
    y = temp - a/b*y;
    return res;
}

// CRT for pairwise-coprime moduli m[0..n-1]: returns the x in [0, M) with
// x = b[i] (mod m[i]) for all i, where M = m[0] * ... * m[n-1] must fit in LL.
LL CRT(LL *b, LL *m, int n) {
    LL M = 1;
    for(int i = 0; i < n; i++) {
        M *= m[i];
    }
    LL ans = 0;
    for(int i = 0; i < n; i++) {
        LL Mi = M / m[i];
        LL x, y;
        extendGcd(Mi, m[i], x, y);  // Mi * x = 1 (mod m[i])
        // Reduce both factors modulo m[i] first: the original b[i] * Mi * x
        // overflows 64 bits once M is large. Here t < m[i], so t * Mi < M.
        LL r = (b[i] % m[i] + m[i]) % m[i];
        LL inv = (x % m[i] + m[i]) % m[i];
        LL t = r * inv % m[i];
        ans = (ans + t * Mi) % M;
    }
    return ans;
}

扩展中国剩余定理

//模板验证POJ2891

typedef long long LL;
const int maxn = 1000;
LL b[maxn];  // remainders b[i]
LL m[maxn];  // moduli m[i] (need not be pairwise coprime)

// Extended Euclid: returns gcd(a, b) and sets x, y with a*x + b*y = gcd(a, b).
LL exGcd(LL a, LL b, LL &x, LL &y) {
    if(b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    LL res = exGcd(b, a % b, x, y);
    LL temp = x;
    x = y;
    y = temp - a/b*y;
    return res;
}

LL gcd(LL a, LL b) {
    return b == 0 ? a : gcd(b, a % b);
}

// Inverse of a modulo m, normalised into [0, m); -1 when gcd(a, m) != 1.
LL inverse(LL a, LL m) {
    LL x1, x2;
    LL d = exGcd(a, m, x1, x2);
    if(d != 1) {
        return -1;
    }
    return (x1 % m + m) % m;
}

// (a * b) mod md without 64-bit overflow (double-and-add); needs 0 < md < 2^62.
static LL mulMod(LL a, LL b, LL md) {
    a = (a % md + md) % md;
    b = (b % md + md) % md;
    LL res = 0;
    while(b > 0) {
        if(b & 1) {
            res = (res + a) % md;
        }
        a = (a + a) % md;
        b >>= 1;
    }
    return res;
}

// Merges x = b1 (mod m1) and x = b2 (mod m2) into x = b3 (mod m3), m3 = lcm(m1, m2).
// Returns false when the pair is inconsistent (gcd does not divide b2 - b1).
bool merge(LL m1, LL b1, LL m2, LL b2, LL &m3, LL &b3) {
    LL d = gcd(m1, m2);
    LL diff = b2 - b1;
    if(diff % d) {
        return false;
    }
    // Write x = b1 + m1 * k; then (m1/d) * k = diff/d (mod m2/d).
    LL q1 = m1 / d;
    LL q2 = m2 / d;
    LL rhs = diff / d % q2;
    if(rhs < 0) {
        rhs += q2;
    }
    // mulMod: rhs * inverse can overflow 64 bits when q2 is large.
    LL k = mulMod(rhs, inverse(q1, q2), q2);  // 0 <= k < q2
    m3 = q1 * m2;                             // lcm(m1, m2)
    // b3 = (b1 + m1 * k) mod m3 without overflow; m1 * k < m3 always.
    LL base = b1 % m3;
    if(base < 0) {
        base += m3;
    }
    LL add = m1 * k;
    b3 = base >= m3 - add ? base - (m3 - add) : base + add;
    return true;
}

// Solves x = b[i] (mod m[i]) for i = 0..n-1 by merging congruences left to right.
// Returns the smallest non-negative solution, or -1 if the system is inconsistent.
// b[] and m[] are overwritten with the running merged congruence.
LL CRT(LL *b, LL *m, int n) {
    if(n <= 0) {
        return 0;  // empty system: every integer works; previously read b[-1]
    }
    LL tempm, tempb;
    for(int i = 0; i < n - 1; i++) {
        if(!merge(m[i], b[i], m[i + 1], b[i + 1], tempm, tempb)) {
            return -1;
        }
        m[i + 1] = tempm;
        b[i + 1] = tempb;
    }
    // Normalise without computing b + m, which overflows when m is near 2^63.
    LL r = b[n - 1] % m[n - 1];
    return r < 0 ? r + m[n - 1] : r;
}

int main() {
    int n;
    while(cin >> n) {
        for(int i = 0; i < n; i++) {
            cin >> m[i] >> b[i];
        }
        cout << CRT(b, m, n) << endl;
    }
    return 0;
}

欧拉筛

欧拉筛筛选素数

const int maxn = 100000;
bool isPrim[maxn];   // isPrim[x]: whether x is prime, for 0 <= x < maxn
int primes[maxn];    // primes[0..countPrim-1]: the primes below maxn, ascending
int countPrim;

// Linear (Euler) sieve: each composite is crossed out exactly once, by its
// smallest prime factor, so the whole table is built in O(maxn).
void sieve() {
    // memset with 1 writes the valid bool value `true`; the original -1 wrote
    // 0xFF bytes, which are not a valid bool representation.
    memset(isPrim, 1, sizeof(isPrim));
    isPrim[0] = isPrim[1] = false;
    countPrim = 0;
    for(int i = 2; i < maxn; i++) {
        if(isPrim[i]) {
            primes[countPrim++] = i;
        }
        for(int j = 0; j < countPrim; j++) {
            const long long composite = 1LL * i * primes[j];  // 64-bit: bound check cannot overflow
            if(composite >= maxn) {
                break;
            }
            isPrim[composite] = false;
            // Once primes[j] divides i, i * (larger prime) has the smaller
            // prime factor primes[j] and will be crossed out by a later i.
            if(i % primes[j] == 0) {
                break;
            }
        }
    }
}

欧拉筛计算欧拉函数

const int maxn = 100000;
bool isPrim[maxn];   // isPrim[x]: whether x is prime
int primes[maxn];    // primes found so far, ascending
int phi[maxn];       // phi[x]: Euler's totient of x, for 1 <= x < maxn
int countPrim;

// Linear sieve that also computes Euler's totient for every x < maxn.
void sieve() {
    // memset with 1 yields valid `true` bools (the original used -1).
    memset(isPrim, 1, sizeof(isPrim));
    isPrim[0] = isPrim[1] = false;  // the original left isPrim[1] == true
    memset(phi, 0, sizeof(phi));
    phi[1] = 1;
    countPrim = 0;
    for(int i = 2; i < maxn; i++) {
        if(isPrim[i]) {
            primes[countPrim++] = i;
            phi[i] = i - 1;
        }
        for(int j = 0; j < countPrim; j++) {
            // long long instead of the undeclared LL of the original snippet.
            const long long composite = 1LL * i * primes[j];
            if(composite >= maxn) {
                break;
            }
            isPrim[composite] = false;
            if(i % primes[j] == 0) {
                // p divides i: phi(i * p) = phi(i) * p.
                phi[composite] = phi[i] * primes[j];
                break;
            }
            // p does not divide i: phi is multiplicative, phi(i * p) = phi(i) * phi(p).
            phi[composite] = phi[i] * phi[primes[j]];
        }
    }
}

欧拉筛计算莫比乌斯函数

const int maxn = 100000;
bool isPrim[maxn];   // isPrim[x]: whether x is prime
int primes[maxn];    // primes found so far, ascending
int mu[maxn];        // mu[x]: Moebius function of x, for 1 <= x < maxn
int countPrim;       // was declared as `countPrm` while every use spells `countPrim`

// Linear sieve that also computes the Moebius function for every x < maxn.
void sieve() {
    // memset with 1 yields valid `true` bools (the original used -1).
    memset(isPrim, 1, sizeof(isPrim));
    memset(mu, 0, sizeof(mu));
    memset(primes, 0, sizeof(primes));
    countPrim = 0;
    isPrim[0] = isPrim[1] = false;
    mu[1] = 1;
    for(int i = 2; i < maxn; i++) {
        if(isPrim[i]) {
            primes[countPrim++] = i;
            mu[i] = -1;
        }
        for(int j = 0; j < countPrim; j++) {
            const long long composite = 1LL * i * primes[j];
            if(composite >= maxn) {
                break;
            }
            isPrim[composite] = false;
            if(i % primes[j] == 0) {
                mu[composite] = 0;       // p^2 divides i * p
                break;
            }
            mu[composite] = -mu[i];      // one more distinct prime factor
        }
    }
}

Gauss-Jordan消元

const double eps = 1e-8;
const int maxn = 110;
typedef double Matrix[maxn][maxn];

// Gauss-Jordan elimination with partial pivoting on the augmented matrix
// A[0..n-1][0..n] (column n is the right-hand side). Afterwards the first n
// columns are diagonal for every usable pivot, so x_i = A[i][n] / A[i][i].
// Columns whose largest candidate pivot is below eps are left untouched.
void guassJordan(Matrix A, int n) {
    for(int col = 0; col < n; ++col) {
        // Choose the row at or below col with the largest |A[row][col]|.
        int pivot = col;
        for(int row = col + 1; row < n; ++row) {
            if(fabs(A[row][col]) > fabs(A[pivot][col])) {
                pivot = row;
            }
        }
        if(fabs(A[pivot][col]) < eps) {
            continue;
        }
        if(pivot != col) {
            for(int c = 0; c <= n; ++c) {
                swap(A[pivot][c], A[col][c]);
            }
        }
        // Clear column col in every other row.
        for(int row = 0; row < n; ++row) {
            if(row == col) {
                continue;
            }
            const double factor = A[row][col] / A[col][col];
            for(int c = col; c <= n; ++c) {
                A[row][c] -= factor * A[col][c];
            }
        }
    }
}

FFT NTT FWT

FFT

//模板验证 HDU1402 模拟大数乘法
#include
#include
#include
#include
#include
#include
using namespace std;

const int maxm = 17;
const int maxn = 1 << maxm;      // largest supported transform length
// The template argument was lost from the original listing; `complex buffer1[maxn]`
// does not compile because class template argument deduction cannot apply to arrays.
complex<double> buffer1[maxn];   // first operand (holds the product after the inverse FFT)
complex<double> buffer2[maxn];   // second operand
complex<double> omega[maxn];     // forward twiddle factors, filled by initOmega
complex<double> conjOmega[maxn]; // conjugated twiddles for the inverse transform
char number1[maxn];              // input digits, most significant first
char number2[maxn];
int sum[maxn];                   // product digits, least significant first

// Fills omega[i] = (cos(2*pi*i/n), sin(2*pi*i/n)) and conjOmega[i] = conj(omega[i])
// for 0 <= i < n. The explicit <double> replaces `complex(...)`, which only
// compiled through C++17 class template argument deduction.
void initOmega(int n) {
    const double pi = acos(-1.0);
    for(int i = 0; i < n; i++) {
        const double theta = 2 * pi * i / n;
        omega[i] = complex<double>(cos(theta), sin(theta));
        conjOmega[i] = conj(omega[i]);
    }
}

// Bit-reversal permutation (not an inverse transform, despite the name):
// swaps buffer[i] with buffer[rev(i)], where rev reverses the log2(n) low bits
// of i. n must be a power of two. The parameter type was lost from the
// original listing (`complex*` does not compile).
void inverse(int n, complex<double>* buffer) {
    for(int i = 0, j = 0; i < n; i++) {
        if(i < j) {
            swap(buffer[i], buffer[j]);
        }
        // Advance j as a bit-reversed counter: add 1 starting from the top bit.
        for(int k = n >> 1; (j ^= k) < k; k >>= 1) {
            continue;
        }
    }
}

void fft(int n, complex* buffer, complex* omega) {
    inverse(n, buffer);
    for(int i = 2; i <= n; i <<= 1) {
        int m = i >> 1;
        int step = n / i;
        for(int j = 0; j < n; j += i) {
            for(int k = 0; k < m; k++) {
                complex temp = buffer[j + k];
                buffer[j + k] = buffer[j + k] + omega[step * k] * buffer[j + k + m];
                buffer[j + k + m] = temp - omega[step * k] * buffer[j + k + m];
            }
        }
    }
}

// HDU1402: reads pairs of non-negative decimal integers until end of input and
// prints their product, computed as an FFT convolution of the digit sequences.
int main() {
    int len1, len2, len;
    while(scanf("%s%s", number1, number2) == 2) {
        len1 = strlen(number1);
        len2 = strlen(number2);
        // Power of two >= 2 * max(len1, len2) >= len1 + len2: no cyclic wrap-around.
        for(len = 1; len < len1 * 2 || len < len2 * 2; len <<= 1) {
            continue;
        }
        // Digits least significant first, zero-padded up to len.
        for(int i = 0; i < len; i++) {
            buffer1[i] = i < len1 ? number1[len1 - i - 1] - '0' : 0;
            buffer2[i] = i < len2 ? number2[len2 - i - 1] - '0' : 0;
        }
        initOmega(len);
        fft(len, buffer1, omega);
        fft(len, buffer2, omega);
        for(int i = 0; i < len; i++) {
            buffer1[i] *= buffer2[i];
        }
        fft(len, buffer1, conjOmega);  // inverse transform, still scaled by len
        for(int i = 0; i < len; i++) {
            sum[i] = (int)(buffer1[i].real() / len + 0.5);  // round to nearest integer
        }
        // Propagate carries only up to index len - 1. The product has at most
        // len1 + len2 <= len digits, so nothing carries out of the last slot.
        // The original loop also wrote sum[len], past the end when len == maxn.
        for(int i = 0; i + 1 < len; i++) {
            sum[i + 1] += sum[i] / 10;
            sum[i] %= 10;
        }
        // Skip leading zeros but keep at least one digit.
        for(len = len1 + len2 - 1; len > 0 && sum[len] <= 0; len--) {
            continue;
        }
        for(int i = len; i >= 0; i--) {
            printf("%d", sum[i]);
        }
        putchar('\n');
    }
    return 0;
}

NTT

//模板验证 HDU1402 模拟大数乘法
#include
#include
#include
#include
#include
using namespace std;
typedef long long LL;

// NTT modulus: 1004535809 = 479 * 2^21 + 1, so power-of-two lengths up to 2^21 divide MOD - 1.
const LL MOD = 1004535809;
// Generator intended as the primitive root of MOD.
const LL gOfMod = 3;
const LL maxm = 17;
const LL maxn = 1 << maxm;   // largest transform length used here (2^17)
LL g[maxn];                  // g[i]: i-th power of the chosen n-th root of unity (see initG)
LL invG[maxn];               // invG[i]: modular inverse of g[i]
LL buffer1[maxn];            // first operand, then the product
LL buffer2[maxn];            // second operand
char number1[maxn];          // input digits, most significant first
char number2[maxn];
LL sum[maxn];                // product digits, least significant first

// Square-and-multiply: returns n^r mod `mod` for r >= 0.
LL quickPowerMod(LL n, LL r, LL mod) {
    LL base = n % mod;
    LL result = 1;
    for(; r != 0; r >>= 1) {
        if(r & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return result;
}

// Precomputes g[i] = w^i and invG[i] = w^(-i) (mod MOD) for 0 <= i < n, where
// w = gOfMod^((MOD - 1) / n). Assuming gOfMod is a primitive root of MOD, w is a
// primitive n-th root of unity; n must be a power of two dividing MOD - 1.
void initG(LL n) {
    g[0] = 1;
    g[1] = quickPowerMod(gOfMod, (MOD - 1) / n, MOD);  // named generator instead of a literal 3
    invG[0] = 1;
    invG[1] = quickPowerMod(g[1], MOD - 2, MOD);         // Fermat inverse (MOD is prime)
    for(LL i = 2; i < n; i++) {
        g[i] = g[i - 1] * g[1] % MOD;
        invG[i] = invG[i - 1] * invG[1] % MOD;
    }
}

// Reorders buffer[0..n-1] into bit-reversed index order (n a power of two),
// the input order expected by the iterative NTT.
void inverseBuffer(LL n, LL* buffer) {
    LL rev = 0;
    for(LL i = 0; i < n; ++i) {
        if(rev < i) {
            swap(buffer[i], buffer[rev]);
        }
        // Increment rev as a bit-reversed counter.
        int bit = n >> 1;
        while((rev ^= bit) < bit) {
            bit >>= 1;
        }
    }
}

// In-place iterative number-theoretic transform of length n over Z_MOD.
// Pass g for the forward transform and invG for the inverse one; the inverse
// is left unscaled (the caller multiplies by n^(-1)).
void fft(LL n, LL* buffer, LL* g) {
    inverseBuffer(n, buffer);
    for(LL len = 2; len <= n; len <<= 1) {
        const LL half = len / 2;
        const LL stride = n / len;
        for(int start = 0; start < n; start += len) {
            for(int k = 0; k < half; k++) {
                LL &lo = buffer[start + k];
                LL &hi = buffer[start + k + half];
                const LL even = lo;
                const LL odd = hi;
                lo = (even + g[k * stride] * odd % MOD) % MOD;
                hi = ((even - g[k * stride] * odd) % MOD + MOD) % MOD;
            }
        }
    }
}

// HDU1402 via NTT: reads pairs of non-negative integers until end of input and
// prints their product. The convolution is exact as long as every coefficient
// (at most 81 * min(len1, len2)) stays below MOD.
int main() {
    LL len1, len2, len;
    while(scanf("%s%s", number1, number2) == 2) {
        len1 = strlen(number1);
        len2 = strlen(number2);
        // Power of two >= 2 * max(len1, len2): no cyclic wrap-around.
        for(len = 1; len < len1 * 2 || len < len2 * 2; len <<= 1) {
            continue;
        }
        // Digits least significant first, zero-padded up to len.
        for(int i = 0; i < len1; i++) {
            buffer1[i] = number1[len1 - i - 1] - '0';
        }
        for(int i = len1; i < len; i++) {
            buffer1[i] = 0;
        }
        for(int i = 0; i < len2; i++) {
            buffer2[i] = number2[len2 - i - 1] - '0';
        }
        for(int i = len2; i < len; i++) {
            buffer2[i] = 0;
        }
        initG(len);
        fft(len, buffer1, g);
        fft(len, buffer2, g);
        for(int i = 0; i < len; i++) {
            buffer1[i] = buffer1[i] * buffer2[i] % MOD;
        }
        fft(len, buffer1, invG);
        LL invLen = quickPowerMod(len, MOD - 2, MOD);  // scale the inverse transform by 1/len
        for(int i = 0; i < len; i++) {
            sum[i] = buffer1[i] * invLen % MOD;
        }
        // Carry only up to index len - 1: the product has at most len1 + len2 <= len
        // digits. The original loop also wrote sum[len], past the end when len == maxn.
        for(int i = 0; i + 1 < len; i++) {
            sum[i + 1] += sum[i] / 10;
            sum[i] %= 10;
        }
        // Skip leading zeros but keep at least one digit.
        for(len = len1 + len2 - 1; len > 0 && sum[len] <= 0; len--) {
            continue;
        }
        // Each entry is now a single digit, so print it as an int. This avoids
        // "%I64d", which only the Microsoft C runtime understands.
        for(; len >= 0; len--) {
            printf("%d", (int)sum[len]);
        }
        putchar('\n');
    }
    return 0;
}

数据结构

线段树

线段树单点更新

非递归建树

/*
模板验证 HDU1166
求区间和,题目给的区间描述是左闭右闭的,代码中全部使用左闭右开的区间 [l, r)
所有下标从0开始
非递归建树建立的是一颗满二叉树
*/
#include
#include
#include
#include

// Command keywords of the input; main() compares each read token against them with strcmp.
#define Add "Add"
#define Sub "Sub"
#define End "End"
#define Query "Query"

using namespace std;

const int maxn = 50050;
const int inf = 0x7fffffff;
int date[maxn * 4];  // node sums of a full binary tree; children of k are 2k+1 and 2k+2
int leaves;          // number of leaves: smallest power of two >= n
int n;               // number of elements
char order[6];       // command buffer

// Sizes the full binary tree for n elements and clears every node sum.
void init() {
    for(leaves = 1; leaves < n; leaves *= 2) {
    }
    memset(date, 0, sizeof(date));
}

// Adds val to element `index` (0-based) and to each of its ancestors.
void update(int index, int val) {
    for(int node = index + leaves - 1; ; node = (node - 1) / 2) {  // start at the leaf
        date[node] += val;
        if(node <= 0) {
            break;
        }
    }
}

// Sum of the half-open range [a, b); node k covers [l, r).
// Call as query(a, b, 0, 0, leaves).
int query(int a, int b, int k, int l, int r) {
    if(b <= l || r <= a) {
        return 0;                  // disjoint
    }
    if(a <= l && r <= b) {
        return date[k];            // fully covered
    }
    const int mid = (l + r) / 2;
    return query(a, b, 2 * k + 1, l, mid) + query(a, b, 2 * k + 2, mid, r);
}

// HDU1166 driver: t test cases, each with n initial values followed by
// "Add i j", "Sub i j" and "Query l r" commands terminated by "End".
// Positions are 1-based; queries become the half-open range [l - 1, r).
int main() {
    int t, temp, index, l, r;
    scanf("%d", &t);
    for(int cs = 1; cs <= t; cs++) {
        scanf("%d", &n);
        init();
        for(int i = 0; i < n; i++) {
            scanf("%d", &temp);
            update(i, temp);
        }
        printf("Case %d:\n", cs);
        // "%5s" keeps the token inside order[6]. "== 1" stops at end of input:
        // scanf returns EOF (-1, non-zero), so the original loop never ended there.
        while(scanf("%5s", order) == 1) {
            if(strcmp(order, End) == 0) {
                break;
            }
            else if(strcmp(order, Add) == 0) {
                scanf("%d%d", &index, &temp);
                update(index - 1, temp);
            }
            else if(strcmp(order, Sub) == 0) {
                scanf("%d%d", &index, &temp);
                update(index - 1, -temp);
            }
            else if(strcmp(order, Query) == 0) {
                scanf("%d%d", &l, &r);
                printf("%d\n", query(l - 1, r, 0, 0, leaves));
            }
        }
    }
    return 0;
}

递归建树

/*
模板验证 HDU1166
求区间和,题目给的区间描述是左闭右闭的,代码中全部使用左闭右开的区间 [l, r)
所有下标从0开始
*/
#include
#include
#include
#include

// Command keywords of the input; main() compares each read token against them with strcmp.
#define Add "Add"
#define Sub "Sub"
#define End "End"
#define Query "Query"

using namespace std;

const int maxn = 50050;
const int inf = 0x7fffffff;
int date[maxn * 4];     // node sums; children of v are 2v+1 and 2v+2
int sourceDate[maxn];   // input values, 0-based
int n;
char order[6];          // command buffer

// Builds the subtree rooted at v for the half-open range [l, r).
void build(int v, int l, int r) {
    if(r - l != 1) {
        const int mid = (l + r) / 2;
        build(2 * v + 1, l, mid);
        build(2 * v + 2, mid, r);
        date[v] = date[2 * v + 1] + date[2 * v + 2];
        return;
    }
    date[v] = sourceDate[l];
}

// Adds val to element `index`; node v covers [l, r).
void update(int index, int val, int v, int l, int r) {
    if(r - l == 1) {
        date[v] += val;
        return;
    }
    const int mid = (l + r) / 2;
    if(index < mid) {
        update(index, val, 2 * v + 1, l, mid);
    } else {
        update(index, val, 2 * v + 2, mid, r);
    }
    date[v] = date[2 * v + 1] + date[2 * v + 2];
}

// Sum of [a, b); node v covers [l, r).
int query(int a, int b, int v, int l, int r) {
    if(b <= l || r <= a) {
        return 0;
    }
    if(a <= l && r <= b) {
        return date[v];
    }
    const int mid = (l + r) / 2;
    return query(a, b, 2 * v + 1, l, mid) + query(a, b, 2 * v + 2, mid, r);
}

// HDU1166 driver (recursive tree): t test cases, each with n values followed by
// "Add i j" / "Sub i j" / "Query l r" commands up to "End". Positions are 1-based;
// queries are translated to the half-open range [l - 1, r).
int main() {
    int t, index, val, l, r;
    scanf("%d", &t);
    for(int cs = 1; cs <= t; ++cs) {
        scanf("%d", &n);
        for(int i = 0; i < n; ++i) {
            scanf("%d", &sourceDate[i]);
        }
        build(0, 0, n);
        printf("Case %d:\n", cs);
        while(scanf("%s", order) == 1 && strcmp(order, End) != 0) {
            if(strcmp(order, Add) == 0) {
                scanf("%d%d", &index, &val);
                update(index - 1, val, 0, 0, n);
            } else if(strcmp(order, Sub) == 0) {
                scanf("%d%d", &index, &val);
                update(index - 1, -val, 0, 0, n);
            } else if(strcmp(order, Query) == 0) {
                scanf("%d%d", &l, &r);
                printf("%d\n", query(l - 1, r, 0, 0, n));
            }
        }
    }
    return 0;
}

线段树区间更新

/*
模板验证POJ3468
计算区间和 
*/
#include
#include
#include
#include
using namespace std;
typedef long long LL;

const int maxn = 100010;
LL sourceDate[maxn];  // input values, 0-based
LL date[maxn << 2];   // date[v]: sum over node v's range; children of v are 2v+1, 2v+2
LL mark[maxn << 2];   // lazy tag: amount still to be added to every element below v

// Builds node v over the half-open range [l, r) and clears its lazy tag.
void build(int v, int l, int r) {
    mark[v] = 0;
    if(r - l == 1) {
        date[v] = sourceDate[l];
    }
    else {
        build(2 * v + 1, l, (l + r) / 2);
        build(2 * v + 2, (l + r) / 2, r);
        date[v] = date[2 * v + 1] + date[2 * v + 2];
    }
}

// Hands node v's pending addition (v covers [l, r)) down to both children.
void pushDown(int v, int l, int r) {
    if(mark[v] != 0) {
        int chl = 2 * v + 1;
        int chr = 2 * v + 2;
        int mid = (l + r) / 2;
        mark[chl] += mark[v];
        mark[chr] += mark[v];
        date[chl] += (mid - l) * mark[v];
        date[chr] += (r - mid) * mark[v];
        mark[v] = 0;
    }
}

// Adds val to every element of [a, b); node v covers [l, r).
void update(int a, int b, int val, int v, int l, int r) {
    if(a >= r || b <= l) {
        return;
    }
    if(a <= l && b >= r) {
        mark[v] += val;
        // Widen to 64 bits before multiplying: val * (r - l) in int overflows
        // as soon as |val| * (r - l) exceeds 2^31 - 1.
        date[v] += static_cast<LL>(val) * (r - l);
        return;
    }
    pushDown(v, l, r);
    const int mid = (l + r) / 2;
    update(a, b, val, 2 * v + 1, l, mid);
    update(a, b, val, 2 * v + 2, mid, r);
    date[v] = date[2 * v + 1] + date[2 * v + 2];
}

// Sum of [a, b); node v covers [l, r).
LL query(int a, int b, int v, int l, int r) {
    if(a >= r || b <= l) {
        return 0;
    }
    if(a <= l && b >= r) {
        return date[v];
    }
    pushDown(v, l, r);
    const int mid = (l + r) / 2;
    return query(a, b, 2 * v + 1, l, mid) + query(a, b, 2 * v + 2, mid, r);
}

// POJ3468 driver: n values, then m commands "Q a b" (sum of [a, b]) or
// "C a b c" (add c to every element of [a, b]); 1-based closed ranges
// become the half-open [a - 1, b).
int main() {
    int n, m;
    char order;
    int a, b, v;
    while(scanf("%d%d", &n, &m) == 2) {
        for(int i = 0; i < n; ++i) {
            scanf("%lld", &sourceDate[i]);
        }
        build(0, 0, n);
        for(int q = 0; q < m; ++q) {
            scanf("\n%c", &order);  // the leading "\n" skips whitespace before the command letter
            if(order != 'Q') {
                scanf("%d%d%d", &a, &b, &v);
                update(a - 1, b, v, 0, 0, n);
                continue;
            }
            scanf("%d%d", &a, &b);
            printf("%lld\n", query(a - 1, b, 0, 0, n));
        }
    }
    return 0;
}

线段树+扫描线

HDU1542
//线段树 + 扫描线 + 离散化
//注意这里的区间是闭区间
#include
#include
#include
#include
using namespace std;
// One horizontal rectangle edge processed by the sweep line.
struct Edge {
    double xl;   // left x of the edge
    double xr;   // right x of the edge
    double h;    // y coordinate of the edge (sweep position)
    int flag;    // +1 for the edge at y1 (coverage starts), -1 for the edge at y2 (ends)
    // Orders edges by height. Declared const so const Edge objects can be
    // compared too; std::sort implementations may compare through const references.
    bool operator<(const Edge &e) const {
        return h < e.h;
    }
};

const int maxn = 110;
double len[maxn << 3];    // len[v]: covered x-length inside node v's range
int cover[maxn << 3];     // cover[v]: number of edges currently covering node v completely
Edge edge[maxn << 1];     // two horizontal edges per rectangle
double point[maxn << 1];  // sorted, de-duplicated x coordinates; leaf i is [point[i], point[i+1]]

// Clears the coverage tree.
void build() {
    memset(len, 0, sizeof(len));
    memset(cover, 0, sizeof(cover));
}

// Adds val (+1 / -1) to the coverage of elementary segments [a, b); node v
// covers segments [l, r), i.e. x from point[l] to point[r]. Refreshes len[v].
void update(int a, int b, int val, int v, int l, int r) {
    if(a >= r || b <= l) {
        return;
    }
    if(a <= l && b >= r) {
        cover[v] += val;
    }
    else {
        const int mid = (l + r) / 2;
        update(a, b, val, 2 * v + 1, l, mid);
        update(a, b, val, 2 * v + 2, mid, r);
    }
    if(cover[v]) {
        len[v] = point[r] - point[l];               // the whole node is covered
    }
    else if(r - l > 1) {
        len[v] = len[2 * v + 1] + len[2 * v + 2];   // whatever the children cover
    }
    else {
        len[v] = 0;                                 // uncovered leaf
    }
}

// HDU1542 driver: for each batch of n rectangles (ended by n = 0 or end of input),
// sweep upward over the sorted horizontal edges and sum covered width times the
// height gap to the next edge, giving the area of the union.
int main() {
    int n;
    double x1, x2, y1, y2;
    int cnt;
    int cs = 0;
    double ans;
    int a, b;
    // "== 1": at end of input scanf returns EOF (non-zero), which the original
    // treated as success, looping forever on the stale n.
    while(scanf("%d", &n) == 1 && n) {
        for(int i = 0; i < n; i++) {
            scanf("%lf%lf%lf%lf", &x1, &y1, &x2, &y2);
            // The edge at y1 opens coverage (+1); the edge at y2 closes it (-1).
            edge[2 * i].xl = edge[2 * i + 1].xl = x1;
            edge[2 * i].xr = edge[2 * i + 1].xr = x2;
            edge[2 * i].flag = 1;
            edge[2 * i + 1].flag = -1;
            edge[2 * i].h = y1;
            edge[2 * i + 1].h = y2;

            point[2 * i] = x1;
            point[2 * i + 1] = x2;
        }
        sort(edge, edge + 2 * n);
        sort(point, point + 2 * n);
        // De-duplicate: point[0..cnt] are the distinct x coordinates.
        cnt = 0;
        for(int i = 1; i < 2 * n; i++) {
            if(point[i] != point[cnt]) {
                point[++cnt] = point[i];
            }
        }
        build();
        ans = 0.0;
        for(int i = 0; i < 2 * n - 1; i++) {
            a = lower_bound(point, point + cnt + 1, edge[i].xl) - point;
            b = lower_bound(point, point + cnt + 1, edge[i].xr) - point;
            update(a, b, edge[i].flag, 0, 0, cnt);
            ans += len[0] * (edge[i + 1].h - edge[i].h);
        }
        printf("Test case #%d\nTotal explored area: %.2f\n\n", ++cs, ans);
    }
    return 0;
}

可持久化线段树 无更新

//模板验证 HDU2665
#include
#include
#include
#include
using namespace std;

// One node of the persistent ("chairman") segment tree over value ranks.
struct Node {
    int ls;   // index of the left child in date[]
    int rs;   // index of the right child in date[]
    int val;  // number of inserted values whose rank falls in this node's range
};

const int maxn = 100010;
Node date[maxn * 20];  // node pool shared by all versions
int roots[maxn];       // roots[i]: version containing the first i input values
int sourceDate[maxn];  // input sequence, 0-based
int sortDate[maxn];    // sorted, de-duplicated values (rank -> value)
int cntNode;           // next free slot in date[]
int sz;                // number of distinct values

// Builds an all-zero tree over ranks [l, r); v receives the root index.
void build(int &v, int l, int r) {
    v = cntNode++;
    date[v].val = 0;
    if(r - l != 1) {
        const int mid = (l + r) / 2;
        build(date[v].ls, l, mid);
        build(date[v].rs, mid, r);
    }
}

// Creates version `now` = version `pre` plus one occurrence of rank `val`.
// Only the root-to-leaf path is copied; all other nodes are shared.
void update(int pre, int &now, int l, int r, int val) {
    now = cntNode++;
    date[now] = date[pre];
    date[now].val++;
    if(r - l == 1) {
        return;
    }
    const int mid = (l + r) / 2;
    if(val < mid) {
        update(date[pre].ls, date[now].ls, l, mid, val);
    } else {
        update(date[pre].rs, date[now].rs, mid, r, val);
    }
}

// Rank of the k-th smallest (1-based) value inserted after version lv and up to
// version rv, found by descending both versions and subtracting their counts.
int query(int k, int lv, int rv, int l, int r) {
    while(r - l != 1) {
        const int mid = (l + r) / 2;
        const int leftCount = date[date[rv].ls].val - date[date[lv].ls].val;
        if(k <= leftCount) {
            lv = date[lv].ls;
            rv = date[rv].ls;
            r = mid;
        } else {
            k -= leftCount;
            lv = date[lv].rs;
            rv = date[rv].rs;
            l = mid;
        }
    }
    return l;
}

// HDU2665 driver: t test cases, each with n values and m queries "l r k" asking
// for the k-th smallest value among positions l..r (1-based, inclusive).
int main() {
    int t, n, m, l, r, k;
    scanf("%d", &t);
    while(t--) {
        scanf("%d%d", &n, &m);
        for(int i = 0; i < n; ++i) {
            scanf("%d", &sourceDate[i]);
            sortDate[i] = sourceDate[i];
        }
        // Discretise the values into ranks 0..sz-1.
        sort(sortDate, sortDate + n);
        sz = unique(sortDate, sortDate + n) - sortDate;
        cntNode = 0;
        build(roots[0], 0, sz);
        // roots[i + 1] = roots[i] plus the i-th value.
        for(int i = 0; i < n; ++i) {
            const int idx = lower_bound(sortDate, sortDate + sz, sourceDate[i]) - sortDate;
            update(roots[i], roots[i + 1], 0, sz, idx);
        }
        for(int q = 0; q < m; ++q) {
            scanf("%d%d%d", &l, &r, &k);
            // Versions l - 1 and r bracket exactly the positions l..r.
            const int idx = query(k, roots[l - 1], roots[r], 0, sz);
            printf("%d\n", sortDate[idx]);
        }
    }
    return 0;
}

可持久化线段树 单点更新

//ZOJ2112
#include
#include
#include
#include
#include
using namespace std;

// Node of the persistent segment trees: both the static versions and the
// Fenwick-managed delta trees live in the same pool.
struct Node {
    int ls;   // left child index in date[]
    int rs;   // right child index in date[]
    int val;  // count of values (net of deletions) whose rank lies in this node's range
};

// One buffered command: 'Q' reads "l r k" (k-th smallest among positions l..r);
// any other command reads "l r" and assigns sourceDate[l - 1] = r.
struct Operation {
    char order;
    int l;
    int r;
    int k;
};

const int maxn = 50050;
const int maxm = 10010;
Node date[(maxn + maxm) * 32];  // node pool for every tree
// NOTE(review): each change inserts roughly 2 * log2(n) * log2(sz) nodes;
// confirm this pool is large enough for the worst case of the target problem.
Operation op[maxm];             // buffered commands
int roots[maxn];                // static chairman-tree versions of the initial sequence
int bits[maxn];                 // Fenwick tree of delta-tree roots (modifications)
int sourceDate[maxn + maxm];    // current sequence, 0-based
int sortDate[maxn + maxm];      // discretisation table: initial values plus every assigned value
vector<int> li, ri;             // BIT roots visited by query(); element type was lost in the original listing
char order[3];                  // temporary buffer for the command letter 'Q' / 'C'
int cntNode;                    // number of nodes used across all trees
int sz;                         // number of distinct values in sortDate
int n, m;

// Builds an all-zero tree over ranks [l, r); v receives the root index.
void build(int &v, int l, int r) {
    v = cntNode++;
    date[v].val = 0;
    if(r - l == 1) {
        return;
    }
    const int mid = (l + r) / 2;
    build(date[v].ls, l, mid);
    build(date[v].rs, mid, r);
}

// Creates a new node `now` copied from `pre` (both covering ranks [l, r)) and
// adds val to the count of rank k along the root-to-leaf path.
// val = 1 inserts an occurrence of k; val = -1 removes one.
void update(int pre, int &now, int l, int r, int k, int val) {
    now = cntNode++;
    date[now] = date[pre];
    date[now].val += val;
    if(r - l == 1) {
        return;
    }
    const int mid = (l + r) / 2;
    if(k < mid) {
        update(date[pre].ls, date[now].ls, l, mid, k, val);
    } else {
        update(date[pre].rs, date[now].rs, mid, r, k, val);
    }
}

// Records a modification in the Fenwick tree of delta trees: each BIT slot that
// covers position `index` (1-based) gets a new version with val added at the rank
// of the current sourceDate[index - 1]. Use val = -1 before overwriting the value
// and val = 1 after.
void add(int index, int val) {
    const int rank = lower_bound(sortDate, sortDate + sz, sourceDate[index - 1]) - sortDate;
    for(int i = index; i <= n; i += i & -i) {
        update(bits[i], bits[i], 0, sz, rank, val);
    }
}

int query(int sv, int ev, int l, int r, int k) {
// Finds the rank of the k-th smallest value (k is 1-based) among positions
// s + 1..e of the current sequence. For that interval, main() passes:
//   sv: root of static version s (roots[s])
//   ev: root of static version e (roots[e])
//   sv and ev both cover the rank range [l, r).
// Before the call, main() fills ri with the BIT roots for prefix e and li with
// those for prefix s. This function advances every entry of ri / li in place,
// in lockstep with sv / ev, so the vectors are consumed by the call.
    if(r - l == 1) {
        return l;
    }
    int res = 0;

// Fenwick-style sum of the modification deltas: prefix e minus prefix s.
    for(int i : ri) {
        res += date[date[i].ls].val;
    }
    for(int i : li) {
        res -= date[date[i].ls].val;
    }

    // res = number of values in the interval whose rank lies in the left half.
    res += date[date[ev].ls].val - date[date[sv].ls].val;
    if(k <= res) {
        // Answer is in the left half: move every tracked root to its left child.
        for(auto it = ri.begin(); it != ri.end(); it++) {
            *it = date[*it].ls;
        }
        for(auto it = li.begin(); it != li.end(); it++) {
            *it = date[*it].ls;
        }
        return query(date[sv].ls, date[ev].ls, l, (l + r) / 2, k);
    }
    else {
        // Answer is in the right half: skip the res smaller values.
        for(auto it = ri.begin(); it != ri.end(); it++) {
            *it = date[*it].rs;
        }
        for(auto it = li.begin(); it != li.end(); it++) {
            *it = date[*it].rs;
        }
        return query(date[sv].rs, date[ev].rs, (l + r) / 2, r, k - res);
    }
}

// Driver for dynamic k-th smallest with point assignments. For each of t test cases:
// read n values and m operations, discretize every value that can ever appear,
// build the static prefix trees roots[0..n], then process the operations in order.
int main() {
    int t;
    int k, ans, temp;
    scanf("%d", &t);
    while(t--) {
        scanf("%d%d", &n, &m);
        sz = 0;
        for(int i = 0; i < n; i++) {
            scanf("%d", sourceDate + i);
            sortDate[sz++] = sourceDate[i];
        }
        // Operations are read up front so that values introduced by assignments
        // are discretized together with the original ones.
        for(int i = 0; i < m; i++) {
            scanf("%s", order);
            op[i].order = order[0];
            if(order[0] == 'Q') {
                // Query: op.l..op.r is the range (1-based, inclusive), op.k the rank wanted.
                scanf("%d%d%d", &op[i].l, &op[i].r, &op[i].k);
            }
            else {
                // Assignment: op.l is the position (1-based), op.r the new value.
                scanf("%d%d", &op[i].l, &op[i].r);
                sortDate[sz++] = op[i].r;
            }
        }
        sort(sortDate, sortDate + sz);
        sz = unique(sortDate, sortDate + sz) - sortDate;
        cntNode = 0;
        // roots[0] is the base tree from build() (presumably all counts zero);
        // roots[i + 1] is roots[i] with the i-th original value inserted.
        build(roots[0], 0, sz);
        for(int i = 0; i < n; i++) {
            k = lower_bound(sortDate, sortDate + sz, sourceDate[i]) - sortDate;
            update(roots[i], roots[i + 1], 0, sz, k, 1);
        }
        for(int i = 0; i <= n; i++) {
            // Note the <=: the Fenwick array holds n + 1 segment trees (indices 0..n),
            // each starting as the base tree roots[0].
            bits[i] = roots[0];
        }
        for(int i = 0; i < m; i++) {
            if(op[i].order == 'Q') {
                li.clear();
                ri.clear();

                // The two loops below collect the Fenwick-tree roots whose
                // modification counts query() adds (prefix op.r) and
                // subtracts (prefix op.l - 1).
                temp = op[i].r;
                while(temp) {
                    ri.push_back(bits[temp]);
                    temp -= temp & -temp;
                }

                temp = op[i].l - 1;
                while(temp) {
                    li.push_back(bits[temp]);
                    temp -= temp & -temp;
                }
                ans = query(roots[op[i].l - 1], roots[op[i].r], 0, sz, op[i].k);
                printf("%d\n", sortDate[ans]);
            }
            else {
                // Remove the old value, overwrite it, then insert the new one.
                add(op[i].l, -1);
                sourceDate[op[i].l - 1] = op[i].r;
                add(op[i].l, 1);
            }
        }
    }
    return 0;
}

树状数组

一维树状数组

一维树状数组单点更新

#include
#include
#include
#include
using namespace std;
typedef long long LL;

const int maxn = 500050;
LL bit[maxn];   // Fenwick tree over positions 1 .. maxn - 1 (index 0 unused)

// Point update a[k] += val.
// Position 0 is not part of the tree: k & -k is 0 there, so the update loop would
// never advance. Non-positive k is therefore ignored.
// val is LL so increments beyond the int range are not truncated; int arguments
// still convert implicitly.
void add(int k, LL val) {
    if(k <= 0) {
        return;
    }
    while(k < maxn) {
        bit[k] += val;
        k += k & -k;
    }
}

// Prefix sum a[1] + ... + a[k]. Returns 0 for k <= 0; k beyond the tree is clamped
// to the last position instead of reading past the array.
LL sum(int k) {
    if(k >= maxn) {
        k = maxn - 1;
    }
    LL res = 0;
    while(k > 0) {
        res += bit[k];
        k -= k & -k;
    }
    return res;
}

int main() {
    // No driver for this snippet; add()/sum() are meant to be called from problem code.
    return 0;
}

一维树状数组区间更新

//模板验证 HDU1556
#include
#include
#include
#include
using namespace std;

const int maxn = 100010;
// Range add / range sum with two Fenwick trees over positions 1 .. n:
// prefix(x) = sum(bit0, x) + x * sum(bit1, x).
int bit0[maxn], bit1[maxn];
int n;

// bit[k] += val, propagated to every covering node with index <= n.
void pointAdd(int *bit, int k, int val) {
    for(; k <= n; k += k & -k) {
        bit[k] += val;
    }
}

// Ordinary Fenwick prefix sum of bit over positions 1 .. k.
int sum(int *bit, int k) {
    int total = 0;
    for(; k > 0; k -= k & -k) {
        total += bit[k];
    }
    return total;
}

// Adds val to every a[i] with l <= i <= r.
void rangeAdd(int l, int r, int val) {
    pointAdd(bit1, l, val);
    pointAdd(bit1, r + 1, -val);
    pointAdd(bit0, l, -val * (l - 1));
    pointAdd(bit0, r + 1, val * r);
}

// Returns a[l] + ... + a[r].
int rangeSum(int l, int r) {
    auto prefix = [](int x) { return sum(bit0, x) + x * sum(bit1, x); };
    return prefix(r) - prefix(l - 1);
}

// HDU1556 driver: each case gives n ranges [l, r]; prints, for every position 1..n,
// how many ranges cover it. Input ends with n == 0.
int main() {
    int l, r;
    // Compare against 1: at end of input scanf returns EOF, a negative and hence
    // truthy value, so `scanf(...) && n` kept looping forever on input without
    // the terminating 0.
    while(scanf("%d", &n) == 1 && n) {
        memset(bit0, 0, sizeof(bit0));
        memset(bit1, 0, sizeof(bit1));
        for(int i = 0; i < n; i++) {
            scanf("%d%d", &l, &r);
            rangeAdd(l, r, 1);
        }
        for(int i = 1; i < n; i++) {
            printf("%d ", rangeSum(i, i));
        }
        printf("%d\n", rangeSum(n, n));
    }
    return 0;
}

二维树状数组

二维树状数组单点更新

const int maxn = 1000;
int bit[maxn][maxn];    // 2-D Fenwick tree; valid coordinates are 1 .. maxn - 1

// Point update a[x][y] += val.
// The loops stop at maxn - 1: bit has rows/columns 0 .. maxn - 1, so the previous
// `<= maxn` bounds wrote one row/column past the array.
// Coordinates <= 0 are ignored (k & -k == 0 would never advance).
void add(int x, int y, int val) {
    if(x <= 0 || y <= 0) {
        return;
    }
    for(; x < maxn; x += x & -x) {
        for(int j = y; j < maxn; j += j & -j) {
            bit[x][j] += val;
        }
    }
}

// Sum of a[i][j] over 1 <= i <= x, 1 <= j <= y (0 if either bound is <= 0).
// Bounds past the tree are clamped to maxn - 1 instead of reading out of range.
int sum(int x, int y) {
    if(x >= maxn) {
        x = maxn - 1;
    }
    if(y >= maxn) {
        y = maxn - 1;
    }
    int res = 0;
    for(; x > 0; x -= x & -x) {
        for(int j = y; j > 0; j -= j & -j) {
            res += bit[x][j];
        }
    }
    return res;
}

字符串

KMP

#include
#include
#include
#include
using namespace std;

// Buffers for the KMP snippet: text, pattern, and the pattern's failure table.
const int maxn = 400001;
char source[maxn], pattern[maxn];
int next[maxn];
/*
int violentMatch(char *s, char *p) {
    int slen = strlen(s);
    int plen = strlen(p);
    int si = 0;
    int pi = 0;
    while(si < slen && pi < plen) {
        if(s[si] == p[pi]) {
            si++;
            pi++;
        }
        else {
            si = si - pi + 1;
            pi = 0;
        }
    }
    if(pi == plen) {
        return si - pi;
    }
    else {
        return -1;
    }
}
*/
// Builds the optimized KMP failure table of p into nx[0 .. strlen(p) - 1].
// nx[0] = -1; for j > 0, nx[j] is where matching resumes in p after a mismatch at
// p[j], skipping fall-backs that would compare against the same character again.
// Fix: the body used to ignore `nx` and write the global `next` instead. It now
// fills the array it is given; callers passing `next` see the same result. This
// also removes the unqualified use of `next`, which is ambiguous with std::next
// under `using namespace std`.
void setNext(char *p, int *nx) {
    int plen = strlen(p);
    nx[0] = -1;
    int k = -1;
    int j = 0;
    while(j < plen - 1) {
        if(k == -1 || p[j] == p[k]) {
            k++;
            j++;
            // If p[j] == p[k], falling back to k would fail again; inherit nx[k].
            nx[j] = (p[j] != p[k]) ? k : nx[k];
        }
        else {
            k = nx[k];
        }
    }
}
//判断p是否在s中,使用的next数组是p的next数组
//返回p在s中第一次出现的位置(0开始) 若s中没有p,则返回-1
int KMP(char *s, char *p) {
    int slen = strlen(s);
    int plen = strlen(p);
    int si = 0;
    int pi = 0;
    while(si <= slen - (plen - pi) && pi < plen) {
        if(pi == -1 || s[si] == p[pi]) {
            si++;
            pi++;
        }
        else {
            pi = next[pi];
        }
    }
    if(pi == plen) {
        return si - pi;
    }
    else {
        return -1;
    }
}

int main()
{
    // No driver: call setNext(pattern, next), then KMP(source, pattern).
    return 0;
}

扩展KMP

#include
#include
#include
#include
using namespace std;    // was the typo `sdt`, which is not a namespace and does not compile

const int maxn = 1000;
char source[maxn];
char pattern[maxn];
int next[maxn];     // Z-array of the pattern: next[i] = LCP(pattern, pattern + i)
int extend[maxn];   // extend[i] = LCP(source + i, pattern), filled by exKMP

// Z-function of str: nx[i] = length of the longest common prefix of str and str + i,
// with nx[0] = strlen(str). nx[1] is always written, so nx needs room for at least
// max(strlen(str), 2) entries. Tracing "abcabc" by hand shows how the window is reused.
// Fix: the scan for nx[1] was bounded by an undeclared `n`; it must use the string length.
void setNext(char *str, int *nx) {
    int len = strlen(str);
    nx[0] = len;
    for(nx[1] = 0; nx[1] < len - 1 && str[nx[1]] == str[nx[1] + 1]; nx[1]++) {
        continue;
    }
    // p0: start of the match reaching furthest right; p: last index covered by it.
    int p0 = 1, p = nx[1];
    for(int i = 2; i < len; i++) {
        if(nx[i - p0] + i - 1 < p) {
            // The mirrored answer lies strictly inside the known window: reuse it.
            nx[i] = nx[i - p0];
        }
        else {
            // Start from what is already known to match and extend character by character.
            nx[i] = max(0, p - i + 1);
            while(i + nx[i] < len && str[nx[i]] == str[i + nx[i]]) {
                nx[i]++;
            }
            p0 = i;
            p = i + nx[i] - 1;
        }
    }
}

void exKMP(char *sc, char *pat) {
    setNext(pat, next);
    int sLen = strlen(sc);
    int pLen = strlen(pat);
    for(extend[0] = 0; extend[0] < sLen && extend[0] < pLen && sc[extend[0]] == pat[extend[0]]; extend[0]++) {
        continue;
    }
    int p0 = 0;
    int p = extend[0] - 1;
    for(int i = 1; i < sLen; i++) {
        if(next[i - p0] + i - 1 < p) {
            extend[i] = next[i - p0];
        }
        else {
            int j = p - i + 1;
            if(j < 0) {
                j = 0;
            }
            while(j < pLen && i + j < sLen && sc[i + j] == pat[j]) {
                j++;
            }
            p0 = i;
            p = i + j - 1;
        }
    }
}

Manacher

#include
#include
#include
#include
using namespace std;

const int maxn = 1000;
char source[maxn];          // input string (NUL-terminated)
char tempStr[maxn << 1];    // source with '#' around every character: "#s0#s1#...#"
int maxLen[maxn << 1];      // maxLen[i]: distance from tempStr[i] to the right end of the longest
                            // palindrome centred there, counting tempStr[i] itself

// Manacher's algorithm: returns the length of the longest palindromic substring of source.
// Rebuilds tempStr and maxLen as a side effect.
int manacher() {
    const int srcLen = strlen(source);
    int len = 0;
    for (int i = 0; i < srcLen; i++) {
        tempStr[len++] = '#';
        tempStr[len++] = source[i];
    }
    tempStr[len++] = '#';
    tempStr[len] = '\0';
    memset(maxLen, 0, sizeof(maxLen));
    maxLen[0] = 1;
    int p0 = 0;     // centre of the palindrome that reaches furthest right
    int p = 0;      // rightmost index covered by that palindrome
    int best = 0;
    for (int i = 1; i < len; i++) {
        int radius = 1;
        if (i < p) {
            const int mirrored = maxLen[2 * p0 - i];
            if (i + mirrored - 1 < p) {
                // The mirror image lies strictly inside the known palindrome: copy it.
                maxLen[i] = mirrored;
                best = max(best, maxLen[i] - 1);
                continue;
            }
            radius = p - i + 1;     // known to reach at least p; extend from there
        }
        while (i - radius >= 0 && i + radius < len && tempStr[i - radius] == tempStr[i + radius]) {
            radius++;
        }
        maxLen[i] = radius;
        if (i + radius > p0 + maxLen[p0]) {
            p0 = i;
            p = i + radius - 1;
        }
        best = max(best, radius - 1);
    }
    return best;
}

字典树

// Trie over lowercase letters 'a'..'z'. Node 0 is the root; trie[u][c] == 0 means "no child".
// Fixes: alphNum had no type, `root` and `tot` were undeclared, and `tire` was a typo.
const int alphNum = 26;
const int maxn = 1000;
int trie[maxn][alphNum];
int tot = 0;    // index of the most recently allocated node (the root is node 0)

// Inserts s (lowercase letters only) into the trie.
// NOTE(review): there is no capacity check; at most maxn - 1 nodes can be allocated.
void insert(const char *s) {
    int root = 0;
    for(int i = 0; s[i]; i++) {
        int id = s[i] - 'a';
        if(trie[root][id] == 0) {
            trie[root][id] = ++tot;
        }
        root = trie[root][id];
    }
}

// Returns true if s is a path from the root, i.e. a prefix of some inserted string.
// No end-of-word marks are stored, so this is a prefix test, not an exact-word test.
bool find(const char *s) {
    int root = 0;
    for(int i = 0; s[i]; i++) {
        int id = s[i] - 'a';
        if(trie[root][id] == 0) {
            return false;
        }
        root = trie[root][id];
    }
    return true;
}

AC自动机

//板子验证 HDU2222
#include
#include
#include
#include
#include
using namespace std;
// Trie node of the Aho-Corasick automaton over the letters 'a'..'z'.
class Node {
public:
    Node *next[26];     // children; NULL = no edge
    Node *fail;         // failure link (NULL only for the root)
    int cnt;    // keywords ending here (int, not bool: HDU2222 may repeat a keyword); -1 once counted by find()
    Node() {
        for(int i = 0; i < 26; i++) {
            next[i] = NULL;
        }
        fail = NULL;
        cnt = 0;
    }
};

// Aho-Corasick automaton (checked against HDU2222).
// Usage: insert() every keyword, call buildFail() once, then find(text).
// The automaton owns every node reachable from root.
class ACauto {
public:
    Node *root;

    ACauto() {
        root = new Node;
    }

    // Owns raw nodes: a copy would delete them twice, so copying is disabled.
    ACauto(const ACauto &) = delete;
    ACauto &operator=(const ACauto &) = delete;

    ~ACauto() {
        clear();
        delete root;
    }

    // Deletes every node except root and detaches root's children.
    void clear() {
        Node* temp;
        queue<Node*> q;     // element type restored: a bare `queue q;` does not compile
        q.push(root);
        while(!q.empty()) {
            temp = q.front();
            q.pop();
            for(int i = 0; i < 26; i++) {
                if(temp->next[i]) {
                    q.push(temp->next[i]);
                }
            }
            if(temp != root) {
                delete temp;
            }
        }
        for(int i = 0; i < 26; i++) {
            root->next[i] = NULL;
        }
    }

    // Adds keyword s (lowercase letters only) to the trie.
    void insert(const char * s) {
        Node *p = root;
        int id;
        for(int i = 0; s[i]; i++) {
            id = s[i] - 'a';
            if(p->next[id] == NULL) {
                p->next[id] = new Node;
            }
            p = p->next[id];
        }
        p->cnt++;
    }

    // BFS over the trie. A child's fail link is found by walking the parent's fail
    // chain until some node has an edge on the same letter; otherwise it is root.
    void buildFail() {
        Node *father;
        Node *fatherFail;
        queue<Node*> q;
        q.push(root);
        while(!q.empty()) {
            father = q.front();
            q.pop();
            for(int i = 0; i < 26; i++) {
                if(father->next[i]) {
                    fatherFail = father->fail;
                    while(fatherFail) {
                        if(fatherFail->next[i]) {
                            father->next[i]->fail = fatherFail->next[i];
                            break;
                        }
                        fatherFail = fatherFail->fail;
                    }
                    if(fatherFail == NULL) {
                        father->next[i]->fail = root;
                    }
                    q.push(father->next[i]);
                }
            }
        }
    }

    // Returns how many keywords (counting repeated insertions) occur in s.
    // Each node's count is consumed (set to -1), so every keyword is counted at most
    // once and a second call on the same automaton reports only keywords not yet counted.
    int find(const char *s) {
        int res = 0;
        Node *p = root;
        Node *temp;
        int id;
        for(int i = 0; s[i]; i++) {
            id = s[i] - 'a';
            while(p->next[id] == NULL && p != root) {
                p = p->fail;
            }
            p = p->next[id];
            if(p == NULL) {
                p = root;
            }
            // Collect every keyword that ends here via the fail chain; stop at nodes already consumed.
            temp = p;
            while(temp != root && temp->cnt != -1) {
                res += temp->cnt;
                temp->cnt = -1;
                temp = temp->fail;
            }
        }
        return res;
    }
};

const int maxKeyLen = 55;
const int maxTextLen = 1000010;
char key[maxKeyLen];
char str[maxTextLen];
ACauto ac;

// HDU2222 driver: for each of t cases read n keywords and one text,
// then print how many of the keywords appear in the text.
int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        ac.clear();
        int n;
        scanf("%d", &n);
        for (int read = 0; read < n; ++read) {
            scanf("%s", key);
            ac.insert(key);
        }
        ac.buildFail();
        scanf("%s", str);
        const int found = ac.find(str);
        printf("%d\n", found);
    }
    return 0;
}

AC自动机+矩阵快速幂

/*
模板测试 POJ2778
在建立失效指针的同时,统计安全节点的数量
以此减小矩阵的规模
测试数据
比如当输入的字符串为
ATCG
T
的时候,安全节点只有两个,为root和A
此时cntNode = 2
*/
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
typedef long long LL;

// Trie node of the DNA Aho-Corasick automaton (alphabet A, C, T, G).
class Node {
public:
    Node *next[4];
    Node *fail;
    bool isTail;    // a forbidden pattern ends here, or at a node reachable through fail
    int tag;        // row/column of this safe node in the transition matrix; -1 = not a safe node
    Node() {
        memset(next, 0, sizeof(next));
        fail = NULL;
        isTail = false;
        tag = -1;
    }
};

// Aho-Corasick automaton for counting DNA strings that avoid every inserted pattern
// (POJ2778). Only safe nodes (not isTail) receive a matrix index `tag`, which keeps
// the transition matrix small.
// Fixes: `map trans;` and the bare `queue que;` declarations had lost their template
// arguments and did not compile.
class ACauto {
private:
    Node *root;
    map<char, int> trans;   // letter -> child slot

public:
    int cntNode;    // number of safe nodes = matrix dimension (root included)
    ACauto() {
        root = new Node();
        root->tag = 0;
        cntNode = 1;
        trans['A'] = 0;
        trans['C'] = 1;
        trans['T'] = 2;
        trans['G'] = 3;
    }

    // Owns raw nodes: a copy would delete them twice, so copying is disabled.
    ACauto(const ACauto &) = delete;
    ACauto &operator=(const ACauto &) = delete;

    ~ACauto() {
        clear();
        delete root;
    }

    // Deletes every node except root and resets the automaton to its initial state.
    void clear() {
        queue<Node*> que;
        Node *temp;
        que.push(root);
        while(!que.empty()) {
            temp = que.front();
            que.pop();
            for(int i = 0; i < 4; i++) {
                if(temp->next[i] != NULL) {
                    que.push(temp->next[i]);
                }
            }
            if(temp != root) {
                delete temp;
            }
        }
        for(int i = 0; i < 4; i++) {
            root->next[i] = NULL;
        }
        root->isTail = false;   // an empty pattern from a previous case must not leak into the next
        cntNode = 1;
    }

    // Inserts forbidden pattern s. If a shorter pattern already ends on the path,
    // the longer one adds nothing and insertion stops.
    void insert(const char *s) {
        Node *p = root;
        int id;
        for(int i = 0; s[i]; i++) {
            id = trans[s[i]];
            if(p->next[id] == NULL) {
                if(p->isTail) {
                    return ;
                }
                p->next[id] = new Node();
            }
            p = p->next[id];
        }
        p->isTail = true;
    }

    // BFS over the safe part of the trie: sets fail links, marks as tails the nodes
    // whose fail target is a tail, and numbers the remaining safe nodes.
    void buildFail() {
        Node *father;
        Node *fatherFail;
        queue<Node*> que;
        que.push(root);
        while(!que.empty()) {
            father = que.front();
            que.pop();
            for(int i = 0; i < 4; i++) {
                if(father->next[i] != NULL && !(father->next[i]->isTail)) {
                    fatherFail = father->fail;
                    while(fatherFail != NULL) {
                        if(fatherFail->next[i]) {
                            father->next[i]->fail = fatherFail->next[i];
                            break;
                        }
                        fatherFail = fatherFail->fail;
                    }
                    if(fatherFail == NULL) {
                        father->next[i]->fail = root;
                    }
                    if(father->next[i]->fail->isTail) {
                        father->next[i]->isTail = true;
                    }
                    else {
                        father->next[i]->tag = cntNode++;
                        que.push(father->next[i]);
                    }
                }
            }
        }
    }

    // Fills mat[0..cntNode-1][0..cntNode-1] with the number of letters leading from
    // each safe node to each safe node (transitions into tails are dropped).
    void buildMatrix(LL mat[][110]) {
        for(int i = 0; i < cntNode; i++) {
            for(int j = 0; j < cntNode; j++) {
                mat[i][j] = 0;
            }
        }
        queue<Node*> que;
        Node *father, *temp;
        que.push(root);
        while(!que.empty()) {
            father = que.front();
            que.pop();
            for(int i = 0; i < 4; i++) {
                if(father->next[i] != NULL && !(father->next[i]->isTail)) {
                    que.push(father->next[i]);
                }
                // Follow fail links until some node has an edge on letter i.
                temp = father;
                while(temp != NULL && temp->next[i] == NULL) {
                   temp = temp->fail;
                }
                if(temp == NULL) {
                    mat[father->tag][0]++;
                }
                else if(temp->next[i]->tag >= 0){
                    mat[father->tag][temp->next[i]->tag]++;
                }
            }
        }
    }
};

const LL MOD = 100000;
char str[20];
LL matrix1[110][110];
LL matrix2[110][110];
LL matrix3[110][110];
ACauto ac;

void mul(LL matrix1[110][110], LL matrix2[110][110], int len) {
    for(int i = 0; i < len; i++) {
        for(int j = 0; j < len; j++) {
            matrix3[i][j] = 0;
        }
    }
    for(int i = 0; i < len; i++) {
        for(int j = 0; j >= 1;
    }
    for(int i = 0; i < len; i++) {
        for(int j = 0; j < len; j++) {
            matrix1[i][j] = matrix2[i][j];
        }
    }
}


// POJ2778 driver: for each case, m forbidden patterns and a length n; prints how many
// DNA strings of length n contain none of the patterns, modulo MOD.
int main() {
    int m, n;
    while(scanf("%d%d", &m, &n) == 2) {
        ac.clear();
        for(int read = 0; read < m; read++) {
            scanf("%s", str);
            ac.insert(str);
        }
        ac.buildFail();
        ac.buildMatrix(matrix1);
        power(matrix1, n, ac.cntNode);
        // Number of length-n walks that start at the root and stay on safe nodes.
        LL ans = 0;
        for(int col = 0; col < ac.cntNode; col++) {
            ans = (ans + matrix1[0][col]) % MOD;
        }
        printf("%lld\n", ans);
    }
    return 0;
}

后缀数组

//时间复杂度大约为(4~6)O(n * floor(log n))
#include
#include
#include
//#define DEBUG
using namespace std;

const int maxn = 1000;
char str[maxn];             // string whose suffix array is built
int n;                      // its length
int suffixArray[maxn];      // suffixArray[i]: start of the suffix ranked i (0-based);
                            // while building, equal substrings keep the earlier start first
int Rank[maxn];             // Rank[i]: rank of the suffix starting at i (inverse of suffixArray)
int height[maxn];           // height[r]: LCP of the suffixes ranked r and r - 1 (height[0] = 0)
int cnt[maxn];              // counting-sort buckets (indexed by key / rank)
int tempArray1[maxn];       // scratch buffers that alternately hold ranks
int tempArray2[maxn];       // and second-key orders during buildSa

#ifdef DEBUG
// Debug helper: prints a[0 .. n-1] separated by spaces, then a newline.
void show(int *a) {
    for(const int *it = a; it < a + n; ++it) {
        cout << *it << " ";
    }
    cout << endl;
}
#endif // DEBUG

// Builds suffixArray for str[0 .. n-1] by prefix doubling with counting (radix) sorts.
// m: alphabet size; every (int)str[i] must lie in [0, m).
// Ranks are 1-based while being built; all counting sorts are stable.
// Returns early as soon as all ranks are distinct.
void buildSa(int m) {
    int countKind;                      //number of distinct keys (ranks) seen so far
    int *_rank = tempArray1;            //_rank[i]: rank of the length-k substring starting at str[i]
                                        //(the bottom row of each figure on p. 220 of Liu Rujia's "Training Guide")
                                        //equal substrings get equal ranks

/* *************************************************************************************************** */
    memset(cnt, 0, sizeof(int) * m);
    for(int i = 0; i < n; i++) {        //cnt temporarily counts how often each character occurs
        cnt[(int)str[i]]++;
    }
    for(int i = 1; i < m; i++) {
        cnt[i] += cnt[i - 1];
    }
    for(int i = n - 1; i >= 0; i--) {   //place the suffix starting at i at position --cnt[str[i]]
        //1 take the character at position i
        //2 --cnt[str[i]] is the rank of the length-1 substring starting at i
        //3 so the substring with that rank starts at index i
        suffixArray[--cnt[(int)str[i]]] = i;
    }
    countKind = 0;
    for(int i = 0, preKey = -1; i < n; i++) {
        if(str[suffixArray[i]] == preKey) {
            _rank[suffixArray[i]] = countKind;
        }
        else {
            preKey = str[suffixArray[i]];
            _rank[suffixArray[i]] = ++countKind;
        }
    }
#ifdef DEBUG
    show(_rank);
    show(suffixArray);
#endif // DEBUG
    if(countKind == n) {
        return ;
    }
//Up to here everything only ranked the substrings of length 1.
//From now on _rank supplies the first and second keys:
//the _rank just computed is the first key of the pairs in the next iteration.
/* *************************************************************************************************** */
    for(int k = 1; k <= n; k <<= 1) {
        //every substring now has length 2 * k
        int *firstKey = _rank;                                                  //ranks from the previous round are this round's first keys
        int *secondRankIndex = _rank == tempArray1 ? tempArray2 : tempArray1;   //secondRankIndex[i]: index of the pair whose second key ranks i-th
        // Suffixes starting at n-k .. n-1 have an empty second key, so they come first.
        for(int i = 0; i < k; i++) {
            secondRankIndex[i] = n - k + i;
        }
        // The rest follow in order of their second key, i.e. the rank of the suffix k positions later.
        for(int i = 0, j = k; i < n; i++) {
            if(suffixArray[i] >= k) {
                secondRankIndex[j++] = suffixArray[i] - k;
            }
        }
        memset(cnt, 0, sizeof(int) * (countKind + 1));
        for(int i = 0; i < n; i++) {
            cnt[firstKey[i]]++;
        }
        for(int i = 1; i <= countKind; i++) {
            cnt[i] += cnt[i - 1];
        }
        for(int i = n - 1; i >= 0; i--) {
            //1 take the pair whose second key ranks i-th: index = secondRankIndex[i]
            //2 its first key is key = firstKey[index]
            //3 its final position is --cnt[key]
            suffixArray[--cnt[firstKey[secondRankIndex[i]]]] = secondRankIndex[i];
        }
        //At this point _rank still holds the previous round's result:
        //_rank[i] == firstKey[i] is the first key of pair i,
        //and for i >= k, firstKey[i] is also the second key of pair i - k
        _rank = secondRankIndex;    //secondRankIndex is no longer needed, so its storage is reused for the new ranks;
                                    //firstKey still points at the ranks computed in the previous round
        int preFirstKey = -1, preSecondKey = -1;
        int curFirstKey, curSecondKey;
        countKind = 0;
        for(int i = 0; i < n; i++) {
            curFirstKey = firstKey[suffixArray[i]];
            curSecondKey = suffixArray[i] >= n - k ? 0 : firstKey[suffixArray[i] + k];
            if(curFirstKey == preFirstKey && curSecondKey == preSecondKey) {
                _rank[suffixArray[i]] = countKind;
            }
            else {
                preFirstKey = curFirstKey;
                preSecondKey = curSecondKey;
                _rank[suffixArray[i]] = ++countKind;
            }
        }
#ifdef DEBUG
        show(_rank);
        show(suffixArray);
#endif // DEBUG
        if(countKind >= n) {
            return ;
        }
    }
}

// Kasai's algorithm. Fills Rank as the inverse of suffixArray and sets
// height[r] = LCP(suffix suffixArray[r], suffix suffixArray[r - 1]), with height[0] = 0.
// Let h(i) = height[Rank[i]]. Then h(i) >= h(i - 1) - 1: if suffix k directly precedes
// suffix i - 1 and they share h(i - 1) > 1 characters, then suffix k + 1 precedes
// suffix i and shares h(i - 1) - 1 characters with it. So each scan may resume from
// the previous LCP minus one.
void getHeight() {
    for (int r = 0; r < n; r++) {
        Rank[suffixArray[r]] = r;
    }
    int lcp = 0;
    for (int i = 0; i < n; i++) {
        if (lcp > 0) {
            lcp--;
        }
        const int r = Rank[i];
        if (r == 0) {
            height[0] = 0;
            lcp = 0;
            continue;
        }
        const int j = suffixArray[r - 1];   // suffix ranked just before suffix i
        while (i + lcp < n && j + lcp < n && str[i + lcp] == str[j + lcp]) {
            lcp++;
        }
        height[r] = lcp;
    }
}

// Demo driver: builds the suffix array of every whitespace-separated word on stdin
// (characters are expected to be below 'z' + 1).
int main() {
    while (cin >> str) {
        n = strlen(str);
        const int alphabet = 'z' + 1;
        buildSa(alphabet);
    }
    return 0;
}

回文树(回文自动机)

指针版(性能较数组版低)

//HDU3948
//249ms	22.9MB	2544B	
//计算不同回文子串的数量
//注意使用时字符串下标从1开始
//输入时需要scanf("%s", str + 1);
#include
#include
#include
#include
#include
using namespace std;

// Node of the pointer-based palindromic tree (palindrome automaton) over 'a'..'z'.
class Node {
public:
    enum{SIZE = 26};
    int len;            // length of the palindrome this node stands for (-1 for the odd root)
    int failLen;        // set to the parent's failLen + 1 on creation; PalindromeTree never reads it
    Node* next[SIZE];   // next[c]: the palindrome c + (this) + c
    Node* fail;         // longest proper palindromic suffix; a node with len -1 points to itself

    Node(int _len = 0, int _failLen = -1, Node *f = NULL) {
        len = _len;
        failLen = _failLen;
        for(int i = 0; i < SIZE; i++) {
            next[i] = NULL;
        }
        fail = len == -1 ? this : f;
    }
};

// Palindromic tree. build(s) reads the string from s[1] onward; s[0] must be a
// sentinel different from every letter (e.g. '\0'). Afterwards countNode is the
// number of distinct palindromic substrings.
// Fix: the BFS queue in clear() had lost its template argument (`queue que;`).
class PalindromeTree{
public:
    Node* odd;      // root of odd-length palindromes (len -1)
    Node* even;     // root of even-length palindromes (len 0), fail -> odd
    int countNode;

    PalindromeTree() {
        odd = new Node(-1);
        even = new Node(0, -1, odd);
        countNode = 0;
    }

    // Owns raw nodes: a copy would delete them twice, so copying is disabled.
    PalindromeTree(const PalindromeTree &) = delete;
    PalindromeTree &operator=(const PalindromeTree &) = delete;

    void build(const char *s) {
        clear();
        Node *p = odd;
        for(int i = 1; s[i]; i++) {
            // Find the longest palindromic suffix that can be wrapped by s[i] on both sides.
            while(s[i - p->len - 1] != s[i]) {
                p = p->fail;
            }
            int id = s[i] - 'a';
            if(p->next[id] != NULL) {
                p = p->next[id];
            }
            else {
                countNode++;
                Node *temp = p;
                p->next[id] = new Node(p->len + 2, p->failLen + 1);
                p = p->next[id];
                if(temp->len == -1) {
                    // Single-character palindrome: its longest proper palindromic suffix is empty.
                    p->fail = even;
                }
                else {
                    temp = temp->fail;
                    while(s[i - temp->len - 1] != s[i]) {
                        temp = temp->fail;
                    }
                    p->fail = temp->next[id];
                }
            }
        }
    }

    // Deletes every node except the two roots and resets the count.
    void clear() {
        queue<Node*> que;
        que.push(odd);
        que.push(even);
        Node *temp;
        while(!que.empty()) {
            temp = que.front();
            que.pop();
            for(int i = 0; i < Node::SIZE; i++) {
                if(temp->next[i] != NULL) {
                    que.push(temp->next[i]);
                }
            }
            if(temp == odd || temp == even) {
                continue;
            }
            delete temp;
        }
        for(int i = 0; i < Node::SIZE; i++) {
            odd->next[i] = NULL;
            even->next[i] = NULL;
        }
        countNode = 0;
    }

    ~PalindromeTree() {
        clear();
        delete odd;
        delete even;
    }
};

const int maxn = 110000;
char str[maxn];     // str[0] stays '\0' (global) and acts as the sentinel; input goes to str + 1
PalindromeTree pt;

// HDU3948 driver: prints the number of distinct palindromic substrings for each case.
int main() {
    int t;
    scanf("%d", &t);
    int cs = 0;
    while (t--) {
        scanf("%s", str + 1);
        pt.build(str);
        ++cs;
        printf("Case #%d: %d\n", cs, pt.countNode);
        pt.clear();
    }
    return 0;
}

数组版

//板子验证同上
//78ms	11.3MB	1586B	
#include
#include
#include
#include
#include
#include
using namespace std;

// Node of the array-based palindromic tree; children and fail links are indices into tree[].
struct Node {
    enum{SIZE = 26};
    int len;                // palindrome length (-1 for node 0, the odd root)
    int next[Node::SIZE];   // next[c]: index of c + (this) + c, or -1
    int fail;               // index of the longest proper palindromic suffix

    void init(int _len = 0, int _fail = -1) {
        len = _len;
        for (int &child : next) {
            child = -1;
        }
        fail = (len == -1) ? 0 : _fail;
    }
};

const int maxn = 100010;
Node tree[maxn];
char str[maxn];
int countNode;

// Resets the tree to its two roots: node 0 (len -1) and node 1 (len 0), both failing to 0.
void clear() {
    tree[0].init(-1, 0);
    tree[1].init(0, 0);
    countNode = 2;
}

// Follows fail links from node v until s[i] + palindrome(v) + s[i] fits inside s[..i].
static int extendableSuffix(const char *s, int v, int i) {
    while (s[i - tree[v].len - 1] != s[i]) {
        v = tree[v].fail;
    }
    return v;
}

// Builds the tree for s[1..]; s[0] must be a sentinel different from every letter.
// Afterwards countNode - 2 is the number of distinct palindromic substrings.
void build(const char *s) {
    clear();
    int last = 0;
    for (int i = 1; s[i]; i++) {
        const int c = s[i] - 'a';
        const int parent = extendableSuffix(s, last, i);
        if (tree[parent].next[c] != -1) {
            last = tree[parent].next[c];
            continue;
        }
        const int created = countNode++;
        tree[parent].next[c] = created;
        tree[created].init(tree[parent].len + 2);
        if (parent == 0) {
            tree[created].fail = 1;     // single character: fail to the empty palindrome
        } else {
            tree[created].fail = tree[extendableSuffix(s, tree[parent].fail, i)].next[c];
        }
        last = created;
    }
}

// HDU3948 driver (same input and output format as the pointer version).
int main() {
    int t;
    scanf("%d", &t);
    for (int cs = 1; t--; cs++) {
        scanf("%s", str + 1);
        build(str);
        printf("Case #%d: %d\n", cs, countNode - 2);
    }
    return 0;
}

后缀自动机

1、两个串的最长公共子串
2、统计出现次数最多的长度为n的串一共出现了多少次,使用基数排序
3、多个串的最长公共子串,同样使用基数排序

两个串的最长公共子串

//两个串的最长公共子串
//https://www.spoj.com/problems/LCS/en/
#include
#include
#include
#include
using namespace std;

const int SIZE = 26;                // alphabet size
const int START = int('a');         // first letter of the alphabet
const int maxn = 260000;            // maximum input length (with slack)
char str1[maxn];                    // string the automaton is built from
char str2[maxn];                    // string matched against the automaton

// One suffix-automaton state.
struct State {
    int len;            // length of the longest string in this state
    int next[SIZE];     // transitions, -1 = none
    int link;           // suffix link, -1 for the initial state
};
State st[maxn * 2];     // sized for twice the input length
int sz;                 // number of states in use
int last;               // state reached by the whole string added so far

// Resets the automaton to the single initial state 0 (empty string, no transitions).
void initSa() {
    for (int i = SIZE - 1; i >= 0; --i) {
        st[0].next[i] = -1;
    }
    st[0].link = -1;
    st[0].len = 0;
    sz = 1;
    last = 0;
}

// Extends the suffix automaton with character c (expected in START .. START + SIZE - 1).
// Online construction: create state `cur` for the whole current string, add
// transitions on c along the suffix-link path from `last`, then set cur's link,
// cloning q when the existing transition p -c-> q skips lengths (len[p] + 1 != len[q]).
void addChar(char c) {
    int cur = sz++;
    st[cur].len = st[last].len + 1;
    for(int i = 0; i < SIZE; i++) {
        st[cur].next[i] = -1;
    }
    int id = c - START;
    int p = last;
    // Every suffix state without a transition on c gets one to cur.
    for(; p != -1 && st[p].next[id] == -1; p = st[p].link) {
        st[p].next[id] = cur;
    }
    if(p == -1) {
        // Not even the initial state had a transition on c: link cur to the initial state.
        st[cur].link = 0;
    }
    else {
        int q = st[p].next[id];
        if(st[p].len + 1 == st[q].len) {
            st[cur].link = q;
        }
        else {
            // Split q: the clone gets length len[p] + 1, q's link and q's transitions.
            int clone = sz++;
            st[clone].len = st[p].len + 1;
            st[clone].link = st[q].link;
            for(int i = 0; i < SIZE; i++) {
                st[clone].next[i] = st[q].next[i];
            }
            // Redirect transitions on c that pointed to q along the remaining suffix path.
            for(; p != -1 && st[p].next[id] == q; p = st[p].link) {
                st[p].next[id] = clone;
            }
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}

// Length of the longest common substring of str1 and str2: builds the automaton of
// str1, then walks str2 through it while tracking the length of the current match.
int lcs() {
    initSa();
    for (const char *ch = str1; *ch; ++ch) {
        addChar(*ch);
    }
    int state = 0;
    int matched = 0;
    int best = 0;
    for (const char *ch = str2; *ch; ++ch) {
        const int id = *ch - START;
        // Shorten the match along suffix links until a transition on id exists (or the root is reached).
        while (state != 0 && st[state].next[id] == -1) {
            state = st[state].link;
            matched = st[state].len;
        }
        const int to = st[state].next[id];
        if (to != -1) {
            state = to;
            ++matched;
        }
        if (matched > best) {
            best = matched;
        }
    }
    return best;
}

// SPOJ LCS driver: reads pairs of strings and prints the length of their longest common substring.
int main() {
    for (;;) {
        if (scanf("%s%s", str1, str2) != 2) {
            break;
        }
        printf("%d\n", lcs());
    }
    return 0;
}

多个串的最长公共子串

//LCS2 - Longest Common Substring II 
//https://www.spoj.com/problems/LCS2/en/
#include
#include
#include
#include
#include
using namespace std;

const int maxn = 110000;
const int SIZE = 26;

// Suffix-automaton state extended with the match length for the string being matched.
struct State {
    int len;            // longest string length in the state; match() lowers it to the running answer
    int matcheLen;      // longest match recorded in this state for the current string
    int link;           // suffix link, -1 for the initial state
    int next[SIZE];     // transitions, -1 = none
};
State st[maxn * 2];
int sz;                 // number of states in use
int last;               // state reached by the whole first string
char str[maxn];
int bucket[maxn];       // counting-sort buckets indexed by len
int Rank[maxn * 2];     // states ordered by decreasing len (filled by radixSort)

// Resets the automaton to the single initial state 0 with no matches recorded.
void initSa() {
    for (int i = SIZE - 1; i >= 0; --i) {
        st[0].next[i] = -1;
    }
    st[0].matcheLen = 0;
    st[0].link = -1;
    st[0].len = 0;
    sz = 1;
    last = 0;
}

// Extends the suffix automaton with letter index c (0 .. SIZE - 1).
// Standard online construction; new states start with matcheLen = 0.
void addChar(int c) {
    int cur = sz++;
    st[cur].len = st[last].len + 1;
    st[cur].matcheLen = 0;
    for(int i = 0; i < SIZE; i++) {
        st[cur].next[i] = -1;
    }
    int p = last;
    // Every suffix state without a transition on c gets one to cur.
    for(; p != -1 && st[p].next[c] == -1; p = st[p].link) {
        st[p].next[c] = cur;
    }
    if(p == -1) {
        st[cur].link = 0;
    }
    else {
        int q = st[p].next[c];
        if(st[p].len + 1 == st[q].len) {
            st[cur].link = q;
        }
        else {
            // Split q: the clone gets length len[p] + 1, q's link and q's transitions.
            int clone = sz++;
            st[clone].len = st[p].len + 1;
            st[clone].link = st[q].link;
            st[clone].matcheLen = 0;
            for(int i = 0; i < SIZE; i++) {
                st[clone].next[i] = st[q].next[i];
            }
            st[q].link = st[cur].link = clone;
            // Redirect transitions on c that pointed to q along the remaining suffix path.
            for(; p != -1 && st[p].next[c] == q; p = st[p].link) {
                st[p].next[c] = clone;
            }
        }
    }
    last = cur;
}

// Matches s (lowercase letters) against the automaton and folds the result into
// st[].len: each state's len becomes min(its current len, the match length recorded
// for it while scanning s). After all strings are matched, the largest remaining
// len is the length of the longest common substring.
// Requires Rank (from radixSort) to list the states by decreasing original len.
// Resets every matcheLen to 0 before returning.
void match(const char* s) {
    int cur = 0;
    int curLen = 0;
    int id;
    for(int i = 0; s[i]; i++) {
        id = s[i] - 'a';
        // Fall back along suffix links until cur has a transition on id; the whole
        // of each state fallen back to counts as matched.
        while(cur && st[cur].next[id] == -1) {
            cur = st[cur].link;
            st[cur].matcheLen = curLen = st[cur].len;
        }
        if(st[cur].next[id] != -1) {
            curLen++;
            cur = st[cur].next[id];
            st[cur].matcheLen = max(st[cur].matcheLen, curLen);
        }
    }
    // Longer states first: a state with any match means its suffix-link parent is
    // matched in full; then lower each state's len to this string's match.
    for(int i = 0; i < sz; i++) {
        if(st[Rank[i]].matcheLen) {
            st[st[Rank[i]].link].matcheLen = st[st[Rank[i]].link].len;
        }
        st[Rank[i]].len = min(st[Rank[i]].len, st[Rank[i]].matcheLen);
        st[Rank[i]].matcheLen = 0;
    }
}

void radixSort(int len) {
    for(int i = 0; i < sz; i++) {
        ++bucket[st[i].len];
    }
    for(int i = len - 1; i >= 0; i--) {
        bucket[i] += bucket[i + 1];
    }
    for(int i = 0; i < sz; i++) {
        Rank[--bucket[st[i].len]] = i;
    }
}

// SPOJ LCS2 driver: the first string builds the automaton, each further string is
// matched against it, and the longest substring common to all of them is printed.
int main() {
    //freopen("wa.txt", "r", stdin);
    //freopen("out.txt", "w", stdout);
    initSa();
    scanf("%s", str);
    int len = 0;
    while (str[len]) {
        addChar(str[len] - 'a');
        len++;
    }
    radixSort(len);
    while (scanf("%s", str) == 1) {
        match(str);
    }
    int ans = 0;
    for (int i = 0; i < sz; i++) {
        if (st[i].len > ans) {
            ans = st[i].len;
        }
    }
    printf("%d\n", ans);
    return 0;
}

出现次数查询

//Substrings SPOJ - NSUBSTR 
//https://www.spoj.com/problems/NSUBSTR/en/
//使用基数排序
#include
#include
#include
#include
using namespace std;

const int maxn = 260000;
const int SIZE = 26;
// Suffix-automaton states.
struct {
    int len;            // longest string length in the state
    int next[SIZE];     // transitions, -1 = none
    int link;           // suffix link, -1 for the initial state
    int cnt;            // number of occurrences once solve() has propagated the counts
} st[maxn * 2];
int bucket[maxn];       // counting-sort buckets indexed by len
int Rank[maxn * 2];     // states ordered by decreasing len
int ans[maxn];          // ans[l]: largest cnt among states whose len is l
char str[maxn];
int sz, last;           // states in use / state reached by the whole string

// Resets the automaton to the single initial state 0 (empty string, count 0).
void initSa() {
    for (int i = SIZE - 1; i >= 0; --i) {
        st[0].next[i] = -1;
    }
    st[0].link = -1;
    st[0].len = 0;
    st[0].cnt = 0;
    sz = 1;
    last = 0;
}

// Extends the suffix automaton with letter index c (0 .. SIZE - 1).
// The state created for the new whole string starts with cnt = 1 (one new end
// position); clones start with cnt = 0, and solve() later sums counts up the suffix links.
void addChar(int c) {
    int cur = sz++;
    st[cur].len = st[last].len + 1;
    st[cur].cnt = 1;
    for(int i = 0; i < SIZE; i++) {
        st[cur].next[i] = -1;
    }
    int p = last;
    // Every suffix state without a transition on c gets one to cur.
    for(; p != -1 && st[p].next[c] == -1; p = st[p].link) {
        st[p].next[c] = cur;
    }
    if(p == -1) {
        st[cur].link = 0;
    }
    else {
        int q = st[p].next[c];
        if(st[p].len + 1 == st[q].len) {
            st[cur].link = q;
        }
        else {
            // Split q: the clone gets length len[p] + 1, q's link and q's transitions.
            int clone = sz++;
            st[clone].len = st[p].len + 1;
            st[clone].cnt = 0;
            st[clone].link = st[q].link;
            for(int i = 0; i < SIZE; i++) {
                st[clone].next[i] = st[q].next[i];
            }
            st[q].link = st[cur].link = clone;
            // Redirect transitions on c that pointed to q along the remaining suffix path.
            for(; p != -1 && st[p].next[c] == q; p = st[p].link) {
                st[p].next[c] = clone;
            }
        }
    }
    last = cur;
}

void radixSort(int len) {
    memset(bucket, 0, sizeof(bucket));
    for(int i = 0; i < sz; i++) {
        bucket[st[i].len]++;
    }
    for(int i = len - 1; i >= 0; i--) {
        bucket[i] += bucket[i + 1];
    }
    for(int i = 0; i < sz; i++) {
        Rank[--bucket[st[i].len]] = i;
    }
}

// SPOJ NSUBSTR: for every length l = 1 .. |str|, prints the largest number of
// occurrences of any substring of length l.
// Fix: the count propagation also visited the initial state. It is the only state
// with len 0, so it comes last in Rank, and its link is -1, so the loop executed
// st[-1].cnt += ..., writing before the start of st. States without a link are now skipped.
void solve() {
    initSa();
    int len;
    for(len = 0; str[len]; len++) {
        addChar(str[len] - 'a');
    }
    radixSort(len);
    // Longer states first, so each state's count is complete before it is added to its link.
    for(int i = 0; i < sz; i++) {
        const int v = Rank[i];
        if(st[v].link != -1) {
            st[st[v].link].cnt += st[v].cnt;
        }
    }
    memset(ans, 0, sizeof(ans));
    // A state whose len is l contains a length-l substring occurring cnt times.
    for(int i = 0; i < sz; i++) {
        ans[st[i].len] = max(ans[st[i].len], st[i].cnt);
    }
    for(int i = 1; i <= len; i++) {
        printf("%d\n", ans[i]);
    }
}

// Reads strings until end of input and prints the per-length answers for each.
int main() {
    for (;;) {
        if (scanf("%s", str) != 1) {
            break;
        }
        solve();
    }
    return 0;
}

你可能感兴趣的:(ACM竞赛,模板)