广义后缀自动机是建在 T r i e Trie Trie 树上的后缀自动机,和加特殊字符拼接相比好像就是空间上的优化? 实现上就是每次加入新的模式串的时候,将 l a s t last last 结点重置为 r o o t root root。好像也没啥特殊的。
sol: 寻找在各种置换下本质不同的子串个数。
由于字符集只有 a b c abc abc,可行的置换只有 3 ! 3! 3!种。对原始串在这6种置换下的表示建广义后缀自动机。发现对于只有一种字符的子串,每个重复了3种,这种子串的个数可以通过求最长连续同色子串的长度得到,如 a a a aaa aaa, c c c c c c c ccccccc ccccccc。其余的子串则重复了 3 ! 3! 3!次。则最终的答案为 ( 所 有 子 串 个 数 + 同 字 符 子 串 个 数 ∗ 3 ) / 6 (所有子串个数 + 同字符子串个数 * 3)/6 (所有子串个数+同字符子串个数∗3)/6.
code:
#include
using namespace std;
typedef long long ll;
const int maxn = 8e5+5;
const int s_sz = 4;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
#define fi first
#define se second
#define MP make_pair
#define pii pair
int pos[maxn],sum[maxn],tmp[maxn];
int maxx,cur;
string ss[10];
char s1[maxn],s2[maxn];
void init(){
cur = 0;
ss[0] = "abc";
do{
ss[++cur] = ss[0];
}while(next_permutation(ss[0].begin(),ss[0].end()));
}
struct SAM{
int ch[maxn][s_sz];
int rt,sz,last;
int len[maxn],suf[maxn],r[maxn];
ll val[maxn],pre[maxn];
void init(){
memset(ch,0,sizeof(ch[0]) * (sz+1));
memset(suf,0,sizeof(int)*(sz+1));
memset(r,0,sizeof(int)*(sz+1));
rt = sz = last = 1;
}
inline void add(int x,int c){
int p = last,np = ++sz;
last = np;
len[np] = x;
while(p && !ch[p][c]){
ch[p][c] = np;
p = suf[p];
}
if(!p){
suf[np] = rt;
return;
}
int q = ch[p][c];
if(len[q] == len[p] + 1) suf[np] = q;
else{
int nq = ++ sz;
len[nq] = len[p] + 1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
suf[nq] = suf[q];
suf[np] = suf[q] = nq;
while(ch[p][c] == q){
ch[p][c] = nq;
p = suf[p];
}
}
}
inline int idx(char c){
return c - 'a';
}
inline void build(char* s){
last = rt;
int n = strlen(s);
for(int i = 0;i<n;i++){
add(i+1,idx(s[i]));
}
}
inline void Topsort(int n){
memset(sum,0,sizeof(int)*(n+1));
for(int i = 1;i<=sz;i++) sum[len[i]] ++ ;
for(int i = 1;i<=n;i++) sum[i] += sum[i-1];
for(int i = 1;i<=sz;i++) tmp[sum[len[i]]--] = i;
}
inline void get_right(){
for(int i = sz;i;i--){
int u = tmp[i];
if(suf[u]) r[suf[u]] += r[u];
}
}
inline ll Query(){
ll ret = 0;
for(int i = rt + 1;i<=sz;i++){
ret += len[i] - len[suf[i]];
}
return ret;
}
}sam;
int main(){
int n;
init();
while(~scanf("%d",&n)){
scanf("%s",s1);
s2[n] = '\0';
sam.init();
for(int i = 1;i<=cur;i++){
for(int j = 0;j<n;j++){
int c = s1[j] - 'a';
s2[j] = ss[i][c];
}
sam.build(s2);
}
sam.Topsort(n);
sam.get_right();
ll ans = sam.Query();
int res = 1;
int maxx = 1;
for(int i = 1;i<n;i++){
if(s1[i] == s1[i-1]) res++;
else res = 1;
maxx = max(maxx,res);
}
ans += maxx * 3;
ans /= 6;
printf("%lld\n",ans);
}
return 0;
}
sol:求有n个大串和m个询问,每次给出一个字符串s询问在多少个大串中出现过
广义自动机裸题。。。 记录每个状态在几个大串中出现即可。
code:
#include
using namespace std;
typedef long long ll;
const int maxn = 4e5+5;
const int s_sz = 26;
const int inf = 0x3f3f3f3f;
#define fi first
#define se second
#define MP make_pair
#define pii pair
int sum[maxn],tmp[maxn],pos[maxn];
char str[maxn];
char s1[maxn];
ll f[maxn];
struct SAM{
int ch[maxn][s_sz];
int rt,sz,last;
int len[maxn],suf[maxn],r[maxn];
int cnt[maxn],pre[maxn];
void init(){
memset(ch,0,sizeof(ch[0]) * (sz+1));
memset(suf,0,sizeof(int)*(sz+1));
memset(r,0,sizeof(int)*(sz+1));
rt = sz = last = 1;
}
inline void add(int x,int c){
int p = last,np = ++sz;
last = np;
len[np] = x;
while(p && !ch[p][c]){
ch[p][c] = np;
p = suf[p];
}
if(!p){
suf[np] = rt;
return;
}
int q = ch[p][c];
if(len[q] == len[p] + 1) suf[np] = q;
else{
int nq = ++ sz;
len[nq] = len[p] + 1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
suf[nq] = suf[q];
suf[np] = suf[q] = nq;
while(ch[p][c] == q){
ch[p][c] = nq;
p = suf[p];
}
}
}
inline int idx(char c){
return c - 'a';
}
inline void build(char* s){
last = rt;
int n = strlen(s);
for(int i = 0;i<n;i++){
add(i+1,idx(s[i]));
}
}
inline void Topsort(int n){
memset(sum,0,sizeof(int)*(n+1));
for(int i = 1;i<=sz;i++) sum[len[i]] ++ ;
for(int i = 1;i<=n;i++) sum[i] += sum[i-1];
for(int i = 1;i<=sz;i++) tmp[sum[len[i]]--] = i;
}
inline void work(char* s,int cur){
int n = strlen(s);
int u = rt;
for(int i = 0;i<n;i++){
u = ch[u][idx(s[i])];
int fp = u;
while(fp && pre[fp] != cur){
cnt[fp] ++;
pre[fp] = cur;
fp = suf[fp];
}
}
}
inline ll sol(char* s){
ll ret = 0;
int n = strlen(s);
int u = rt;
for(int i = 0;i<n;i++){
u = ch[u][idx(s[i])];
if(!u) return 0;
}
return cnt[u];
}
}sam;
int main(){
int n,m;
scanf("%d%d",&n,&m);
int Last = 0;
pos[0] = Last;
for(int i = 1;i<=n;i++){
scanf("%s",str+Last);
int len = strlen(str+Last);
Last += len;
str[Last] = '\0';
Last++;
pos[i+1] = Last;
}
sam.init();
for(int i = 1;i<=n;i++) {
sam.build(str+pos[i]);
}
for(int i = 1;i<=n;i++) sam.work(str+pos[i],i);
while(m--){
scanf("%s",s1);
ll ans = sam.sol(s1);
printf("%lld\n",ans);
}
return 0;
}
sol:给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?
对每个状态维护 c n t cnt cnt 表示是几个串的子串和 p r e pre pre 表示上个更新的串是哪个。每插入一个新串,对所有经过的节点,沿 p a r e n t parent parent树暴力向上更新,直到遇到第一个已经被当前串更新的结点为止。这个复杂度不是很懂怎么算啊。
建完广义后缀自动机后,对每个结点 s s s 沿拓扑序计算 s s s 和 s s s 所有祖先的贡献。对每个串,在自动机上扫一遍,把贡献累加起来就行。
upd: 这个题还有一个启发式合并的做法。每个节点维护一个 s e t set set存包含这个子串的模式串的标号 i d ∈ [ 1 , n ] id\in[1,n] id∈[1,n],建好 p a r e n t parent parent树之后从根开始 d f s dfs dfs,启发式合并即可。由于每个节点只有一个父亲,我们只需要 s e t set set集合的大小(用一个数组另外存),而每个节点的 s e t set set在被父节点调用时信息是完整的,所以统计的答案是正确的。
code:
#include
using namespace std;
typedef long long ll;
const int maxn = 4e5+5;
const int s_sz = 26;
const int inf = 0x3f3f3f3f;
#define fi first
#define se second
#define MP make_pair
#define pii pair
int sum[maxn],tmp[maxn],pos[maxn];
char str[maxn];
ll f[maxn];
struct SAM{
int ch[maxn][s_sz];
int rt,sz,last;
int len[maxn],suf[maxn],r[maxn];
int cnt[maxn],pre[maxn];
void init(){
memset(ch,0,sizeof(ch[0]) * (sz+1));
memset(suf,0,sizeof(int)*(sz+1));
memset(r,0,sizeof(int)*(sz+1));
rt = sz = last = 1;
}
inline void add(int x,int c){
int p = last,np = ++sz;
last = np;
len[np] = x;
while(p && !ch[p][c]){
ch[p][c] = np;
p = suf[p];
}
if(!p){
suf[np] = rt;
return;
}
int q = ch[p][c];
if(len[q] == len[p] + 1) suf[np] = q;
else{
int nq = ++ sz;
len[nq] = len[p] + 1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
suf[nq] = suf[q];
suf[np] = suf[q] = nq;
while(ch[p][c] == q){
ch[p][c] = nq;
p = suf[p];
}
}
}
inline int idx(char c){
return c - 'a';
}
inline void build(char* s){
last = rt;
int n = strlen(s);
for(int i = 0;i<n;i++){
add(i+1,idx(s[i]));
}
}
inline void Topsort(int n){
memset(sum,0,sizeof(int)*(n+1));
for(int i = 1;i<=sz;i++) sum[len[i]] ++ ;
for(int i = 1;i<=n;i++) sum[i] += sum[i-1];
for(int i = 1;i<=sz;i++) tmp[sum[len[i]]--] = i;
}
inline void work(char* s,int cur){
int n = strlen(s);
int u = rt;
for(int i = 0;i<n;i++){
u = ch[u][idx(s[i])];
int fp = u;
while(fp && pre[fp] != cur){
cnt[fp] ++;
pre[fp] = cur;
fp = suf[fp];
}
}
}
inline void get_f(int k){
memset(f,0,sizeof(int)*(sz+1));
Topsort(1e5+5);
f[rt] = 0;
for(int i = 1;i<=sz;i++){
int u = tmp[i];
if(cnt[u]>=k) f[u] = len[u] - len[suf[u]];
f[u] += f[suf[u]];
}
}
inline ll sol(char* s,int k){
ll ret = 0;
int n = strlen(s);
int u = rt;
for(int i = 0;i<n;i++){
u = ch[u][idx(s[i])];
ret += f[u];
}
return ret;
}
}sam;
int main(){
int n,k;
scanf("%d%d",&n,&k);
int Last = 0;
pos[0] = Last;
for(int i = 1;i<=n;i++){
scanf("%s",str+Last);
int len = strlen(str+Last);
Last += len;
str[Last] = '\0';
Last++;
pos[i+1] = Last;
}
sam.init();
for(int i = 1;i<=n;i++) {
sam.build(str+pos[i]);
}
for(int i = 1;i<=n;i++) sam.work(str+pos[i],i);
sam.get_f(k);
for(int i = 1;i<=n;i++) {
if(i>1) printf(" ");
printf("%lld",sam.sol(str+pos[i],k));
}
return 0;
}
sol: 给定一颗树,树上每个结点对应一个字符,每个路径对应一个字符串。问所有可能的子串有多少种。
叶子结点最多只有20个。考虑从叶子为起点开始进行爆搜,将路径上所有的子串插入到后缀自动机中。显然不可能将子串全部生成后再插入,考虑在搜索的同时维护上个生成的串对应的结点,那么新的结点就直接在上个结点的基础上扩展。建出后就是模板操作,直接统计即可。
code:
#include
using namespace std;
typedef long long ll;
const int maxn = 4e6+5;
const int s_sz = 10;
const int inf = 0x3f3f3f3f;
#define fi first
#define se second
#define MP make_pair
#define pii pair
int sum[maxn],tmp[maxn];
int sta[maxn],top;
struct SAM{
int ch[maxn][s_sz];
int rt,sz,last;
int len[maxn],suf[maxn],r[maxn];
void init(){
memset(ch,0,sizeof(ch[0]) * (sz+1));
memset(suf,0,sizeof(int)*(sz+1));
memset(r,0,sizeof(int)*(sz+1));
rt = sz = last = 1;
}
inline int add(int pre,int x,int c){
int p = pre,np = ++sz;
last = np;
len[np] = x;
while(p && !ch[p][c]){
ch[p][c] = np;
p = suf[p];
}
if(!p){
suf[np] = rt;
return last;
}
int q = ch[p][c];
if(len[q] == len[p] + 1) suf[np] = q;
else{
int nq = ++ sz;
len[nq] = len[p] + 1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
suf[nq] = suf[q];
suf[np] = suf[q] = nq;
while(ch[p][c] == q){
ch[p][c] = nq;
p = suf[p];
}
}
return last;
}
inline int idx(char c){
return c - 'A';
}
inline void Topsort(int n){
memset(sum,0,sizeof(int)*(n+1));
for(int i = 1;i<=sz;i++) sum[len[i]] ++ ;
for(int i = 1;i<=n;i++) sum[i] += sum[i-1];
for(int i = 1;i<=sz;i++) tmp[sum[len[i]]--] = i;
}
inline void get_right(char* s){
int u = rt;
int n = strlen(s);
for(int i = 0;i<n;i++){
u = ch[u][idx(s[i])];
r[u] = 1;
}
for(int i = sz;i;i--){
int u = tmp[i];
r[suf[u]] += r[u];
}
}
inline ll work(){
ll ret = 0;
for(int i = rt+1;i<=sz;i++){
ret += len[i] - len[suf[i]];
}
return ret;
}
}sam;
int col[maxn];
vector<int> G[maxn];
void DFS(int sta,int u,int fa,int L){
// cout<
int newsta = sam.add(sta,L,col[u]);
for(int i = 0;i<G[u].size();i++){
int v = G[u][i];
if(v == fa) continue;
DFS(newsta,v,u,L+1);
}
}
int main(){
int n;
int coll;
scanf("%d%d",&n,&coll);
for(int i = 1;i<=n;i++) scanf("%d",&col[i]);
for(int i = 1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
sam.init();
for(int i = 1;i<=n;i++){
if(G[i].size()==1) {
DFS(sam.rt,i,0,1);
}
}
ll ans = sam.work();
printf("%lld\n",ans);
return 0;
}
sol: 给一堆模式串和对应的权值。一个子串的权值是它在其中出现过的模式串权值的积,问不超过L的串长的权值的期望。
用类似上上个题的做法维护每个状态对应的权值。注意到一个状态 s t a sta sta 表示的子串是以 r i g h t right right集合为最后一个字符,长度在 ( l e n [ s u f [ s t a ] ] , l e n [ s t a ] ] ( len[suf[sta]],len[sta] ] (len[suf[sta]],len[sta]]之间的后缀。考虑用树状数组去维护差分数组,就可以快速地求出长度为 L L L的子串权值之和。求出每个长度恰好为 L L L的子串贡献之后,再正向递推求出小于等于 L L L的子串的权值之和。长度不超过 L L L的子串一共有 ∑ i = 1 L 2 6 i \sum_{i=1}^{L} 26^i ∑i=1L26i,等比数列求和搞一下就行。upd: 直接线性维护差分数组即可,因为只有最后一次查询,求两遍前缀和即可。用
树状数组的我怕是石乐志。
code:
#include
using namespace std;
typedef long long ll;
const int maxn = 4e5+5;
const int s_sz = 26;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
#define fi first
#define se second
#define MP make_pair
#define pii pair
int pos[maxn];
char str[maxn];
ll h[maxn],A[maxn];
ll f[1000000+10];
int maxx;
void Mul(ll& x,ll y){
x *= y;
if(x>=mod) x%=mod;
}
void Add(ll& x,ll y){
x += y;
if(x>=mod) x%=mod;
}
inline int lowbit(int x){ return x&-x; }
void Modify(int x,ll p){
for(int j = x;j<=maxx;j+=lowbit(j)){
Add(A[j],p);
}
}
ll Query(int x){
ll ret = 0;
for(int j = x;j;j-=lowbit(j)){
Add(ret,A[j]);
}
return ret;
}
struct SAM{
int ch[maxn][s_sz];
int rt,sz,last;
int len[maxn],suf[maxn],r[maxn];
ll val[maxn],pre[maxn];
void init(){
memset(ch,0,sizeof(ch[0]) * (sz+1));
memset(suf,0,sizeof(int)*(sz+1));
memset(r,0,sizeof(int)*(sz+1));
rt = sz = last = 1;
}
inline void add(int x,int c){
int p = last,np = ++sz;
last = np;
len[np] = x;
while(p && !ch[p][c]){
ch[p][c] = np;
p = suf[p];
}
if(!p){
suf[np] = rt;
return;
}
int q = ch[p][c];
if(len[q] == len[p] + 1) suf[np] = q;
else{
int nq = ++ sz;
len[nq] = len[p] + 1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
suf[nq] = suf[q];
suf[np] = suf[q] = nq;
while(ch[p][c] == q){
ch[p][c] = nq;
p = suf[p];
}
}
}
inline int idx(char c){
return c - 'a';
}
inline void build(char* s){
last = rt;
int n = strlen(s);
for(int i = 0;i<n;i++){
add(i+1,idx(s[i]));
}
}
inline void work(char* s,int cur){
int n = strlen(s);
int u = rt;
for(int i = 0;i<n;i++){
u = ch[u][idx(s[i])];
int fp = u;
while(fp && pre[fp] != cur){
Mul(val[fp],h[cur]);
pre[fp] = cur;
fp = suf[fp];
}
}
}
inline void sol(){
for(int i = rt+1;i<=sz;i++){
ll tmp = val[i];
Modify(len[suf[i]]+1,tmp);
Modify(len[i]+1,(mod - tmp) % mod);
}
for(int i = 1;i<=maxx;i++){
f[i] = Query(i);
Add(f[i],f[i-1]);
}
}
}sam;
ll qpow(ll a,ll b){
ll ret = 1;
while(b){
if(b&1) Mul(ret,a);
Mul(a,a);
b>>=1;
}
return ret;
}
ll Inv(ll n){
return qpow(n,mod-2);
}
int main(){
int n;
scanf("%d",&n);
int Last = 0;
pos[0] = Last;
maxx = 0;
for(int i = 1;i<=n;i++){
scanf("%s",str+Last);
int len = strlen(str+Last);
Last += len;
str[Last] = '\0';
Last++;
pos[i+1] = Last;
maxx = max(Last - pos[i] + 5,maxx);
}
for(int i = 1;i<=n;i++) scanf("%lld",&h[i]);
sam.init();
for(int i = 1;i<=n;i++) {
sam.build(str+pos[i]);
}
for(int i = 1;i<=sam.sz;i++) {
sam.val[i] = 1;
}
for(int i = 1;i<=n;i++) {
sam.work(str+pos[i],i);
}
sam.sol();
int m;
scanf("%d",&m);
int invv = Inv(25);
while(m--){
int L;
scanf("%d",&L);
ll ans =(qpow(26,L)-1) * 26 % mod;
Mul(ans,invv);
ans = Inv(ans);
Mul(ans,f[L]);
printf("%lld\n",ans);
}
return 0;
}