BJOI2019 Day1 简要题解

T1 奥术神杖

二分答案,在 AC 自动机上做 DP。时间复杂度 $\Theta\left(ns|\Sigma|\log\frac{\log v}{\epsilon}\right)$。

#include 
#include 
#include 
#include 

#include 
#include 
#include 
#include 
#include 

// Debug helper: printf-style output to stderr.
#define LOG(FMT...) fprintf(stderr, FMT)

using namespace std;

typedef long long ll;

// N: capacity of every array below; A: alphabet size (the digits '0'..'9').
const int N = 1510, A = 10;
// Width at which the binary search in main() stops; chosen in main() from n.
double EPS;

// n: length of the template s; m: number of patterns;
// k: number of automaton nodes in use (node 1 is the root, 0 means "no node").
int n, m, k;
// trie[u][c]: child / transition of node u on digit c (completed by buildac());
// fail[u]: failure link of u; pos[i]: node at which pattern i ends.
int trie[N][A], fail[N], pos[N];
// fch[u]: nodes whose failure link is u (children of u in the failure tree).
vector<int> fch[N];
// x[i]: log of pattern i's value; v[u]: weight collected when the automaton is in
// node u, recomputed for every binary-search guess.
double x[N], v[N];
// dp[i][u]: best total weight after placing i characters and ending in node u.
double dp[N][N];
// sol[i][u]: (previous node, digit placed) that produced dp[i][u]; filled in main().
pair<int, int> sol[N][N];
// s: template read in main() (1-based, '.' marks a free position);
// t: pattern input buffer (1-based), reused for the reconstructed answer.
char s[N], t[N];

void buildac() {
    queue<int> q;
    q.push(1);
    fail[1] = 1;
    while (!q.empty()) {
        int p = q.front();
        q.pop();
        if (p == 1) {
            for (int i = 0; i < A; ++i)
                if (trie[p][i]) {
                    fail[trie[p][i]] = p;
                    q.push(trie[p][i]);
                } else
                    trie[p][i] = p;
        } else {
            for (int i = 0; i < A; ++i)
                if (trie[p][i]) {
                    fail[trie[p][i]] = trie[fail[p]][i];
                    q.push(trie[p][i]);
                } else
                    trie[p][i] = trie[fail[p]][i];
        }
    }
    for (int i = 2; i <= k; ++i)
        fch[fail[i]].push_back(i);
}

void dfs(int u) {
    for (int i = 0; i < fch[u].size(); ++i) {
        v[fch[u][i]] += v[u];
        dfs(fch[u][i]);
    }
}

// Feasibility value for the binary search.
// Every pattern i contributes (::x[i] - x) at its end node; the weights are
// propagated along the failure tree and a DP over (characters placed, automaton
// node) maximises the collected total. Returns the best total over all end nodes.
double pred(double x) {
    memset(v, 0, sizeof(v));
    for (int pat = 1; pat <= m; ++pat)
        v[pos[pat]] += ::x[pat] - x;
    dfs(1);

    for (int i = 0; i <= n; ++i)
        fill(dp[i] + 1, dp[i] + k + 1, -1e9);
    dp[0][1] = 0;

    for (int len = 0; len < n; ++len) {
        // A '.' may become any digit; a fixed character allows exactly one.
        const char want = s[len + 1];
        const int lo = (want == '.') ? 0 : want - '0';
        const int hi = (want == '.') ? A - 1 : lo;

        double* next = dp[len + 1];
        for (int state = 1; state <= k; ++state) {
            const double here = dp[len][state];
            for (int digit = lo; digit <= hi; ++digit) {
                double& slot = next[trie[state][digit]];
                if (slot < here)
                    slot = here;
            }
        }
        // The weight of a node is collected on arrival.
        for (int state = 1; state <= k; ++state)
            next[state] += v[state];
    }
    return *max_element(dp[n] + 1, dp[n] + k + 1);
}

// Reads the template and the patterns, builds the automaton, binary-searches the
// best achievable average of log-values with pred(), then reruns the DP at the
// found threshold while recording back-pointers and prints one optimal string.
int main() {
    scanf("%d%d%s", &n, &m, s + 1);
    ++k;  // node 1 is the trie root; index 0 means "no child"
    // Coarser stopping width for larger n.
    // NOTE(review): presumably a time-limit trade-off — confirm.
    if (n <= 501)
        EPS = 1e-6;
    else
        EPS = 3e-5;
    for (int i = 1; i <= m; ++i) {
        int val;
        scanf("%s%d", t + 1, &val);
        x[i] = log(val);  // products become sums in log space
        int p = 1;
        for (char* c = t + 1; *c; ++c) {
            int cc = *c - '0';
            if (trie[p][cc])
                p = trie[p][cc];
            else
                p = trie[p][cc] = ++k;
        }
        pos[i] = p;
    }
    buildac();
    // NOTE(review): the upper bound 22 presumably exceeds the largest possible
    // log-value — confirm against the input limits.
    double l = 0, r = 22;
    while (r - l > EPS) {
        double mid = (l + r) / 2;
        if (pred(mid) > 0)
            l = mid;
        else
            r = mid;
    }

    // Same DP as pred(l), additionally remembering how each state was reached.
    memset(v, 0, sizeof(v));
    for (int i = 1; i <= m; ++i)
        v[pos[i]] += x[i] - l;
    dfs(1);
    for (int i = 0; i <= n; ++i)
        fill(dp[i] + 1, dp[i] + k + 1, -1e9);
    dp[0][1] = 0;
    for (int i = 0; i < n; ++i) {
        for (int j = 1; j <= k; ++j)
            if (s[i + 1] == '.') {
                for (int c = 0; c < A; ++c) {
                    if (dp[i + 1][trie[j][c]] < dp[i][j]) {
                        dp[i + 1][trie[j][c]] = dp[i][j];
                        sol[i + 1][trie[j][c]] = make_pair(j, c);
                    }
                }
            } else {
                int c = s[i + 1] - '0';
                if (dp[i + 1][trie[j][c]] < dp[i][j]) {
                    dp[i + 1][trie[j][c]] = dp[i][j];
                    sol[i + 1][trie[j][c]] = make_pair(j, c);
                }
            }
        for (int j = 1; j <= k; ++j)
            dp[i + 1][j] += v[j];
    }

    // Walk the back-pointers from the best final state to recover the digits.
    int cur = max_element(dp[n] + 1, dp[n] + k + 1) - dp[n];
    for (int i = n; i; --i) {
        t[i] = '0' + sol[i][cur].second;
        cur = sol[i][cur].first;
    }
    // t was used as the pattern input buffer: terminate the answer explicitly so
    // leftovers of a pattern longer than n can never be printed after it.
    t[n + 1] = '\0';
    puts(t + 1);

    return 0;
}

T2 勘破神机

用通项公式,扩域计算。时间复杂度 $\Theta(k^2 \log r)$。

#include 
#include 
#include 
#include 

#include 
#include 
#include 
#include 

// Debug helper: printf-style output to stderr.
#define LOG(FMT...) fprintf(stderr, FMT)

using namespace std;

typedef long long ll;

// P: modulus of all arithmetic; K: capacity of the precomputed tables.
const int P = 998244353, K = 510;

// Brings a value from [0, 2P) back into [0, P) with one conditional subtraction.
int norm(int x) {
    if (x < P)
        return x;
    return x - P;
}

// x^k mod P by square-and-multiply (k is expected to be non-negative).
int mpow(int x, int k) {
    int acc = 1;
    for (; k; k >>= 1) {
        if (k & 1)
            acc = acc * (ll)x % P;
        x = x * (ll)x % P;
    }
    return acc;
}

// DEFZ(W) defines the struct ZW: numbers of the form a + b*sqrt(W) with both
// coefficients stored modulo P.
//   - the constructor is implicit, so a plain int n converts to (n, 0);
//   - +, unary/binary -, * are the usual field-extension operations
//     (the product uses sqrt(W)^2 = W);
//   - inv() multiplies the conjugate (a, -b) by the modular inverse of the norm
//     a^2 - W*b^2, computed with mpow(., P - 2);
//   - operator/ is multiplication by inv(); operator== compares both coefficients.
// Instantiated below for W = 3 (Z3) and W = 5 (Z5).
#define DEFZ(W) \
struct Z##W {\
    int a, b;\
    \
    Z##W(int a = 0, int b = 0) : a(a), b(b) {}\
    \
    Z##W operator+(const Z##W& rhs) const { return Z##W(norm(a + rhs.a), norm(b + rhs.b)); }\
    Z##W operator-() const { return Z##W(norm(P - a), norm(P - b)); }\
    Z##W operator-(const Z##W& rhs) const { return *this + -rhs; }\
    Z##W operator*(const Z##W& rhs) const { return Z##W((a * (ll)rhs.a + b * (ll)rhs.b * W) % P, (a * (ll)rhs.b + b * (ll)rhs.a) % P); }\
    Z##W inv() const { return Z##W(a, norm(P - b)) * mpow(norm((a * (ll)a - b * (ll)b * W) % P + P), P - 2); }\
    Z##W operator/(const Z##W& rhs) const { return *this * rhs.inv(); }\
    \
    bool operator==(const Z##W& rhs) const { return a == rhs.a && b == rhs.b; }\
};

DEFZ(3)
DEFZ(5)

// DEFPOW(W) defines the overload mpow(ZW x, ll k): x^k in the extension ring by
// square-and-multiply, starting from the multiplicative identity (1, 0).
#define DEFPOW(W) \
Z##W mpow(Z##W x, ll k) {\
    Z##W acc(1, 0);\
    for (; k; k >>= 1, x = x * x)\
        if (k & 1)\
            acc = acc * x;\
    return acc;\
}

DEFPOW(5)
DEFPOW(3)

int fac[K], ifac[K];
int fall[K][K];

int binom(int n, int m) { return fac[n] * (ll)ifac[m] % P * ifac[n - m] % P; }

void prepare(int k) {
    fac[0] = 1;
    for (int i = 1; i <= k; ++i)
        fac[i] = fac[i - 1] * (ll)i % P;
    ifac[1] = 1;
    for (int i = 2; i <= k; ++i)
        ifac[i] = -(P / i) * (ll)ifac[P % i] % P + P;
    ifac[0] = 1;
    for (int i = 1; i <= k; ++i)
        ifac[i] = ifac[i - 1] * (ll)ifac[i] % P;
    
    fall[0][0] = 1;
    for (int i = 1; i <= k; ++i) {
        int v = norm(P - (i - 1));
        fall[i][0] = fall[i - 1][0] * (ll)v % P;
        for (int j = 1; j <= i; ++j)
            fall[i][j] = (fall[i - 1][j - 1] + fall[i - 1][j] * (ll)v) % P;
    }
}

// Returns sum over n in [l, r] of ((a^(n+1) - b^(n+1)) / sqrt(5))^k mod P, with
// a = (1 + sqrt(5)) / 2 and b = (1 - sqrt(5)) / 2 represented in Z5.
// The k-th power is expanded binomially; each term a^j * b^(k-j) yields a
// geometric series over n that is summed in closed form.
int g2(ll l, ll r, int k) {
    const int half = mpow(2, P - 2);
    const Z5 a(half, half), b(half, P - half);
    Z5 total = 0;
    for (int j = 0; j <= k; ++j) {
        const Z5 ratio = mpow(a, j) * mpow(b, k - j);
        Z5 series;
        if (ratio == 1)
            series = Z5((r - l + 1) % P);  // ratio 1: every term of the series is 1
        else
            series = (mpow(ratio, r + 2) - mpow(ratio, l + 1)) / (ratio - 1);
        const Z5 term = series * binom(k, j);
        // Sign (-1)^(k-j) from expanding (a^(n+1) - b^(n+1))^k.
        total = ((k - j) & 1) ? total - term : total + term;
    }
    return (total / mpow(Z5(0, 1), k)).a;
}

// Returns sum over n in [l, r] of (x * a^n + y * b^n)^k mod P, where
// a = 2 + sqrt(3), b = 2 - sqrt(3), x = 1/2 + sqrt(3)/6 and y = 1/2 - sqrt(3)/6
// are represented in Z3. Uses the same binomial expansion and closed-form
// geometric series as g2.
int g3(ll l, ll r, int k) {
    const int inv2 = mpow(2, P - 2);
    const int inv6 = mpow(6, P - 2);
    const Z3 a(2, 1), b(2, P - 1);
    const Z3 x(inv2, inv6), y(inv2, P - inv6);
    Z3 total = 0;
    for (int j = 0; j <= k; ++j) {
        const Z3 ratio = mpow(a, j) * mpow(b, k - j);
        Z3 series;
        if (ratio == 1)
            series = Z3((r - l + 1) % P);  // ratio 1: every term of the series is 1
        else
            series = (mpow(ratio, r + 1) - mpow(ratio, l)) / (ratio - 1);
        total = total + series * mpow(x, j) * mpow(y, k - j) * binom(k, j);
    }
    return total.a;
}

int main() {
    prepare(505);
    int t, m;
    scanf("%d%d", &t, &m);
    while (t--) {
        ll l, r;
        int k;
        scanf("%lld%lld%d", &l, &r, &k);
        int in = mpow((r - l + 1) % P, P - 2);
        if (m == 3) {
            r /= 2;
            l = (l + 1) / 2;
            if (l > r) {
                puts("0");
                continue;
            }
        }
        int ans = 0;
        for (int i = 0; i <= k; ++i)
            ans = (ans + fall[k][i] * (ll)(m == 2 ? g2(l, r, i) : g3(l, r, i))) % P;
        ans = ans * (ll)ifac[k] % P;
        ans = ans * (ll)in % P;
        printf("%d\n", ans);
    }
    
    return 0;
}

T3 送别

平衡树维护若干个环路。时间复杂度 $\Theta((q+nm)\log nm)$。

#include 
#include 
#include 
#include 
#include 

#include 
#include 
#include 
#include 

// Debug helper: printf-style output to stderr.
#define LOG(FMT...) fprintf(stderr, FMT)

using namespace std;

// Shorthand for 64-bit signed integers.
typedef long long ll;

// Splay-tree node. One tree stores one sequence of nodes in in-order; the `cyc`
// flag, kept on the tree root, marks the sequence as closed into a cycle.
struct Node {
    int v;      // size of the subtree rooted here (maintained by upd())
    bool cyc;   // cycle flag; rot() hands it over to whichever node moves up
    Node* prt;  // parent, NULL for a tree root
    union {     // the two children, addressable by name or by index
        struct {
            Node *ls, *rs;
        };
        Node* ch[2];
    };

    // Every field is initialised so that a node is a valid single-element tree
    // regardless of its storage duration (previously only `v` was set, which was
    // safe only for zero-initialised static objects).
    Node() : v(1), cyc(false), prt(NULL) {
        ch[0] = NULL;
        ch[1] = NULL;
    }

    // True when this node is the right child of its parent (parent must exist).
    bool rel() const { return this == prt->rs; }

    // Recomputes the subtree size from the children.
    void upd() {
        v = 1;
        for (int i = 0; i < 2; ++i)
            if (ch[i])
                v += ch[i]->v;
    }

    // Rotates this node above its parent (parent must exist), moving the
    // parent's cycle flag onto this node.
    void rot() {
        cyc = prt->cyc;
        prt->cyc = false;
        bool f = rel();
        Node* p = prt;
        prt = p->prt;
        if (prt)
            prt->ch[p->rel()] = this;
        p->ch[f] = ch[!f];
        if (ch[!f])
            ch[!f]->prt = p;
        ch[!f] = p;
        p->prt = this;
        p->upd();  // p is now below this node, so refresh it first
        upd();
    }

    // Splays this node upwards until its parent is `goal` (NULL = tree root).
    void spl(Node* goal = NULL) {
        while (prt != goal) {
            if (prt->prt == goal)
                return rot();
            if (rel() == prt->rel()) {
                // zig-zig: rotate the parent first
                prt->rot();
                rot();
            } else {
                // zig-zag: rotate this node twice
                rot();
                rot();
            }
        }
    }
};

// Makes y the successor of x.
// If both nodes already lie in the same tree the sequence is closed into a cycle
// by raising the cycle flag; otherwise y's tree is hung below x as its right
// subtree. Callers are expected to pass x as the last and y as the first element
// of their sequences.
void link(Node* x, Node* y) {
    x->spl();
    // With x splayed to the top, y is in the same tree iff climbing from y ends at x.
    Node* top = y;
    while (top->prt)
        top = top->prt;
    const bool sameTree = (top == x);
    y->spl();
    if (!sameTree) {
        x->rs = y;
        y->prt = x;
        x->upd();
        return;
    }
    y->cyc = true;
}

// Removes the connection between x and its successor.
// An open sequence is split right after x. A cycle is opened instead: the part
// that followed x is moved in front, so x becomes the last element of one
// open sequence.
void cut(Node* x) {
    x->spl();
    Node* after = x->rs;
    if (after) {
        after->prt = NULL;
        x->rs = NULL;
        x->upd();
    }
    if (!x->cyc)
        return;
    x->cyc = false;
    if (!after)
        return;
    // Re-attach x's part behind the last element of the detached part.
    while (after->rs)
        after = after->rs;
    after->spl();
    after->rs = x;
    x->prt = after;
    after->upd();
}

// Mirror image of cut(): removes the connection between x and its predecessor.
// An open sequence is split right before x. A cycle is opened so that x becomes
// the first element of one open sequence.
void rcut(Node* x) {
    x->spl();
    Node* before = x->ls;
    if (before) {
        before->prt = NULL;
        x->ls = NULL;
        x->upd();
    }
    if (!x->cyc)
        return;
    x->cyc = false;
    if (!before)
        return;
    // Re-attach x's part in front of the first element of the detached part.
    while (before->ls)
        before = before->ls;
    before->spl();
    before->ls = x;
    x->prt = before;
    before->upd();
}

// 0-based position of x inside its sequence; splays x to the root as a side effect.
int ind(Node* x) {
    x->spl();
    Node* left = x->ls;
    if (!left)
        return 0;
    return left->v;
}

// Number of forward steps from x to y inside their common sequence, wrapping
// around the end of the sequence when y is stored before x; -1 when the two
// nodes belong to different trees.
int tour(Node* x, Node* y) {
    x->spl();
    Node* top = y;
    while (top->prt)
        top = top->prt;
    const bool together = (top == x);
    y->spl();
    if (!together)
        return -1;
    const int dy = ind(y);
    const int dx = ind(x);  // leaves x at the root, so x->v is the sequence length
    if (dy < dx)
        return x->v + dy - dx;
    return dy - dx;
}

// Cyclic predecessor of x within its tree: the rightmost node of the left
// subtree, or, when x is the first element, the last element of the sequence.
// Returns x itself for a single-node tree. The result is splayed to the root.
Node* pred(Node* x) {
    x->spl();
    Node* cand = x->ls ? x->ls : x->rs;
    if (!cand)
        return x;
    while (cand->rs)
        cand = cand->rs;
    cand->spl();
    return cand;
}

// Cyclic successor of x within its tree: the leftmost node of the right
// subtree, or, when x is the last element, the first element of the sequence.
// Returns x itself for a single-node tree. The result is splayed to the root.
Node* succ(Node* x) {
    x->spl();
    Node* cand = x->rs ? x->rs : x->ls;
    if (!cand)
        return x;
    while (cand->ls)
        cand = cand->ls;
    cand->spl();
    return cand;
}

// Capacity of the grid arrays in each dimension.
const int N = 510;

// Grid dimensions, read in main().
int n, m;
// row[x][y]: the edge between (x, y) and (x, y + 1) is present;
// col[x][y]: the edge between (x, y) and (x + 1, y) is present.
bool col[N][N], row[N][N];
// One splay node per directed edge side: o[x][y][d] for d in 0..3. Even d refers
// to the col edge at (x, y), odd d to the row edge; d and d + 2 are the two
// orientations of the same edge.
Node o[N][N][4];

// Field accessors for Pos (see the typedef below).
#define X first.first
#define Y first.second
#define D second

// A directed edge side: ((x, y), d).
typedef pair<pair<int, int>, int> Pos;

// Builds a Pos from its three components.
Pos bpos(int x, int y, int d) { return make_pair(make_pair(x, y), d); }

// Translates an edge given by its two endpoints plus a side selector d into the
// internal Pos: the edge is anchored at its smaller endpoint, a row edge
// (x1 == x2) gets direction 1 + 2 * (1 - d), a col edge gets 2 * (1 - d).
// NOTE(review): d is presumably 0 or 1 — confirm against the input format.
Pos gpos(int x1, int y1, int x2, int y2, int d) {
    const bool isRowEdge = (x1 == x2);
    if (!isRowEdge)
        return bpos(min(x1, x2), y1, (1 - d) << 1);
    return bpos(x1, min(y1, y2), 1 + (1 - d) * 2);
}

// Successor of the directed edge side p: at the vertex where p ends, the present
// edges are tried in a fixed priority order that depends on p's direction; if
// none of the other three exists, the walk turns around onto the opposite side
// of the same edge.
Pos go(const Pos& p) {
    const int x = p.first.first;
    const int y = p.first.second;
    switch (p.second) {
    case 0:
        if (row[x][y])
            return bpos(x, y, 1);
        if (x > 0 && col[x - 1][y])
            return bpos(x - 1, y, 0);
        if (y > 0 && row[x][y - 1])
            return bpos(x, y - 1, 3);
        return bpos(x, y, 2);
    case 1:
        if (col[x][y + 1])
            return bpos(x, y + 1, 2);
        if (row[x][y + 1])
            return bpos(x, y + 1, 1);
        if (x > 0 && col[x - 1][y + 1])
            return bpos(x - 1, y + 1, 0);
        return bpos(x, y, 3);
    case 2:
        if (y > 0 && row[x + 1][y - 1])
            return bpos(x + 1, y - 1, 3);
        if (col[x + 1][y])
            return bpos(x + 1, y, 2);
        if (row[x + 1][y])
            return bpos(x + 1, y, 1);
        return bpos(x, y, 0);
    default:
        if (x > 0 && col[x - 1][y])
            return bpos(x - 1, y, 0);
        if (y > 0 && row[x][y - 1])
            return bpos(x, y - 1, 3);
        if (col[x][y])
            return bpos(x, y, 2);
        return bpos(x, y, 1);
    }
}

// Inverse of go(): searches the 3x3 neighbourhood of p's anchor cell for a
// present directed edge side q with go(q) == p and returns the first one found.
// For an edge that is currently present such a q is expected to exist; if the
// search nevertheless fails, p itself is returned so the function always yields
// a defined value (previously control fell off the end of the function, which
// is undefined behaviour).
Pos rgo(const Pos& p) {
    const int px = p.first.first;
    const int py = p.first.second;
    for (int i = -1; i <= 1; ++i)
        for (int j = -1; j <= 1; ++j) {
            const int cx = px + i;
            const int cy = py + j;
            if (cx < 0 || cy < 0)
                continue;
            for (int d = 0; d < 4; ++d) {
                // Odd directions live on row edges, even ones on col edges.
                if (!((d & 1) ? row : col)[cx][cy])
                    continue;
                const Pos cand = bpos(cx, cy, d);
                if (go(cand) == p)
                    return cand;
            }
        }
    return p;  // not reachable for a consistent edge set; see above
}

// Splay node that represents the directed edge side p.
Node* geto(const Pos& p) {
    const int x = p.first.first;
    const int y = p.first.second;
    const int dir = p.second;
    return &o[x][y][dir];
}

// Inserts the edge anchored at (x, y): a row edge when c == 1, a col edge
// otherwise. Its two directed sides p1 (direction c) and p2 (direction c + 2)
// are spliced into the boundary walks kept in the splay trees.
void ins(int x, int y, int c) {
    // Mark the edge first so go() / rgo() below already see it.
    if (c == 1)
        row[x][y] = true;
    else
        col[x][y] = true;
    Pos p1 = bpos(x, y, c), p2 = bpos(x, y, c + 2);
    Node *o1 = geto(p1), *o2 = geto(p2);
    // Phase 1: at each end of the new edge that touches other edges, open the
    // existing walk so the new sides can be placed in between.
    // NOTE(review): the rcut/cut pair presumably targets the same old link, the
    // second call then being a no-op — confirm.
    if (go(p1) != p2) {
        rcut(geto(go(p1)));
        cut(geto(rgo(p2)));
    }
    if (go(p2) != p1) {
        rcut(geto(go(p2)));
        cut(geto(rgo(p1)));
    }
    // Phase 2: connect each side to its successor and the predecessor of the
    // opposite side to that side; an end with no other edge is a U-turn
    // straight onto the opposite side.
    if (go(p1) != p2) {
        link(o1, geto(go(p1)));
        link(geto(rgo(p2)), o2);
    } else
        link(o1, o2);
    if (go(p2) != p1) {
        link(o2, geto(go(p2)));
        link(geto(rgo(p1)), o1);
    } else
        link(o2, o1);
}

// Removes the edge anchored at (x, y): a row edge when c == 1, a col edge
// otherwise.
void rmv(int x, int y, int c) {
    Pos p1 = bpos(x, y, c), p2 = bpos(x, y, c + 2);
    Node *o1 = geto(p1), *o2 = geto(p2);
    // Detach both directed sides from their successor and predecessor.
    cut(o1);
    rcut(o1);
    cut(o2);
    rcut(o2);
    // Bridge each gap: the side that used to lead into p2 (resp. p1) now
    // continues with the side that used to follow p1 (resp. p2). This must run
    // while the edge is still marked present, because go() / rgo() read row/col.
    if (go(p1) != p2)
        link(geto(rgo(p2)), geto(go(p1)));
    if (go(p2) != p1)
        link(geto(rgo(p1)), geto(go(p2)));
    // Only now erase the edge from the grid.
    if (c == 1)
        row[x][y] = false;
    else
        col[x][y] = false;
}

int main() {
    int q;
    scanf("%d%d%d", &n, &m, &q);
    for (int i = 0; i < m; ++i) {
        ins(0, i, 1);
        ins(n, i, 1);
    }
    for (int i = 0; i < n; ++i) {
        ins(i, 0, 0);
        ins(i, m, 0);
    }
    for (int i = 0; i < n; ++i)
        for (int j = 1; j < m; ++j) {
            int v;
            scanf("%d", &v);
            if (v)
                ins(i, j, 0);
        }
    for (int i = 1; i < n; ++i)
        for (int j = 0; j < m; ++j) {
            int v;
            scanf("%d", &v);
            if (v)
                ins(i, j, 1);
        }
    while (q--) {
        int opt, x1, y1, x2, y2;
        scanf("%d%d%d%d%d", &opt, &x1, &y1, &x2, &y2);
        if (opt == 1) {
            if (x1 == x2)
                ins(x1, min(y1, y2), 1);
            else
                ins(min(x1, x2), y1, 0);
        } else if (opt == 2) {
            if (x1 == x2)
                rmv(x1, min(y1, y2), 1);
            else
                rmv(min(x1, x2), y1, 0);
        } else {
            int d1, x3, y3, x4, y4, d2;
            scanf("%d%d%d%d%d%d", &d1, &x3, &y3, &x4, &y4, &d2);
            Pos p1 = gpos(x1, y1, x2, y2, d1), p2 = gpos(x3, y3, x4, y4, d2);
            printf("%d\n", tour(geto(p1), geto(p2)));
        }
    }
    return 0;
}

你可能感兴趣的:(题集/比赛题解)