2020牛客多校第二场A题 All with Pairs Hash+KMP

All with Pairs

题意

f ( s , t ) f(s,t) f(s,t)为最大的 i i i使得 s 1... i = t ∣ t ∣ − i + 1... ∣ t ∣ s_{1...i} =t_{\left|t\right|-i+1...\left|t\right|} s1...i=tti+1...t
n n n个串 s 1 , s 2 , . . . , s n s_1,s_2,...,s_n s1,s2,...,sn,求 ∑ i = 1 n ∑ j = 1 n f ( s i , s j ) 2 \displaystyle\sum_{i = 1} ^ n\displaystyle\sum_{j = 1} ^ n f(s_i,s_j)^2 i=1nj=1nf(si,sj)2

题解

统计所有串每一个后缀出现次数,这个可以用哈希来实现

map<ull, int> mp;
void insert(string &s) {
    ull hash = 0, b = 1;//unsigned long long就可以自然溢出
    for (int i = s.length() - 1; i >= 0; i--, b *= base) {
        hash += b * (s[i] - 'a' + 1);
        mp[hash]++;
    }
}

对于一个串 s s s来说,记 c n t [ i ] cnt[i] cnt[i]为所有串中后缀等于 s 1... i s_{1...i} s1...i的数量
那么串 s s s的贡献就是 ∑ i = 1 ∣ s ∣ i 2 c n t [ i ] \displaystyle\sum_{i=1}^{|s|}i^2cnt[i] i=1si2cnt[i]

这里有一个问题,如果我们按上面 H a s h Hash Hash的方法来求得所有串每一个后缀出现次数
会有重复计算,如 a b a aba aba的后缀有 a , b a , a b a a,ba,aba a,ba,aba, 如果我们当前串 s s s能够匹配的最大长度为 3 3 3,即匹配的串为 a b a aba aba时, a a a这个串也一定是匹配的,所以要对 c n t cnt cnt进行容斥
容斥只需要从前往后令 c n t [ n e x t [ i ] ] = c n t [ n e x t [ i ] ] − c n t [ i ] cnt[next[i]] = cnt[next[i]]-cnt[i] cnt[next[i]]=cnt[next[i]]cnt[i]即可

因为如果从后往前容斥就要不断跳 n e x t next next数组,前面每一个重复的串都要减去 c n t [ i ] cnt[i] cnt[i]的贡献
但是正向进行就只用减一次,因为 c n t [ n e x t [ i ] ] = c n t [ n e x t [ i ] ] − c n t [ i ] cnt[next[i]] = cnt[next[i]]-cnt[i] cnt[next[i]]=cnt[next[i]]cnt[i]中, c n t [ i ] cnt[i] cnt[i]本就包含了后面重复的贡献

代码

#include
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int MAX = 1e5 + 10;
const int base = 233;
const int mod = 998244353;

vector<int> getNext(string &s) {
    int n = s.length();
    vector<int> nxt(n);
    for (int i = 1; i < n; i++) {
        int j = nxt[i - 1];
        while (j > 0 && s[i] != s[j]) j = nxt[j - 1];
        if (s[i] == s[j]) j++;
        nxt[i] = j;
    }
    return nxt;
}

map<ull, int> mp;
void insert(string &s) {
    ull hash = 0, b = 1;//unsigned long long就可以自然溢出
    for (int i = s.length() - 1; i >= 0; i--, b *= base) {
        hash += b * (s[i] - 'a' + 1);
        mp[hash]++;
    }
}

int N;
string s[MAX];
int cnt[MAX * 10];

int main() {

    cin >> N;
    for (int i = 1; i <= N; i++) {
        cin >> s[i];
        insert(s[i]);
    }
    ll ans = 0;
    for (int i = 1; i <= N; i++) {
        vector<int> nxt = getNext(s[i]);
        ull hash = 0;
        for (int j = 0; j < s[i].length(); j++) {
            hash = hash * base + s[i][j] - 'a' + 1;
            cnt[nxt[j]] -= (cnt[j + 1] = mp[hash]);
            //我这里j是从0开始的, 但是cnt数组是从1开始的,有点不一样
        }
        for (int j = 1; j <= s[i].length(); j++)
            ans = (ans + 1ll * cnt[j] * j % mod * j % mod) % mod;
    }
    printf("%lld\n", ans);


    return 0;
}

你可能感兴趣的:(字符串,#,KMP,#,Hash)