第一道独立解决的Div1F
,嘿嘿,幸好没看题解
把串分为以下几类
不包含star
的串
太简单,略
star
在最前面的串
略
star
在最后面的串
略
单独一个star
答案++
单独一个空串
答案++
star
在中间的串
注意到,假设star
的位置是pos
,实际上相当于选择一个右端点为pos-1
的串s1
,再选择一个左端点为pos+1
的串s2
,问这样的pair(s1,s2)
有多少个
也就是选两个原串的子串,并且这两个子串要满足上面那个条件,问方案数
对原串的前n-2
个字符建SAM
,称为sam
对原串的后n-2
个字符倒过来再建SAM
,称为rsam
注意到,SAM
上每个点表示的本质不同子串数量是len[u] - len[pa[u]]
,其中len
是点u
所表示字符串的最长长度,pa
是点u
在后缀树上的父亲,记这个值为val[u]
也就是说,问题变成了:枚举sam
里面的一个点u
,枚举rsam
里面的一个点v
,如果v
的end_pos
集合存在一个数字,等于u
的end_pos
集合里面的某个数字+2
,那么ans += val[u] * val[v]
考虑sam
里面的每个点u
,假设u
的end_pos
集合是{a1, a2, a3, ..., ak}
,那么在rsam
里面,有哪些点可以和u
产生贡献?所有end_pos
集合包含某一个ai+2
的点可以和u
产生贡献,这在rsam
的后缀树上,是k
条树链的并
在sam
的后缀树上跑DSU on Tree
,维护上述end_pos
集合,并时刻维护集合中所有点的树链的并
总复杂度两个log
#include
using namespace std;
typedef long long ll;
const int N = 200010;
int _w;
struct SAM {
int ch[N][26];
int len[N];
int pa[N];
int idx;
void init() {
memset(ch, 0, sizeof ch);
memset(len, 0, sizeof len);
memset(pa, 0, sizeof pa);
idx = 1;
pa[0] = -1;
}
int append( int p, int c ) {
int np = idx++;
len[np] = len[p] + 1;
while( p != -1 && !ch[p][c] )
ch[p][c] = np, p = pa[p];
if( p == -1 ) pa[np] = 0;
else {
int q = ch[p][c];
if( len[q] == len[p] + 1 ) pa[np] = q;
else {
int nq = idx++;
memcpy(ch[nq], ch[q], sizeof ch[nq]);
len[nq] = len[p] + 1;
pa[nq] = pa[q];
pa[q] = pa[np] = nq;
while( p != -1 && ch[p][c] == q )
ch[p][c] = nq, p = pa[p];
}
}
return np;
}
};
int n;
char str[N];
SAM sam, rsam;
ll solve_origin() {
sam.init();
int last = 0;
for( int i = 1; i <= n; ++i )
last = sam.append(last, str[i] - 'a');
ll ans = 0;
for( int i = 1; i < sam.idx; ++i )
ans += sam.len[i] - sam.len[sam.pa[i]];
return ans;
}
ll solve_before() {
sam.init();
int last = 0;
for( int i = 2; i <= n; ++i )
last = sam.append(last, str[i] - 'a');
ll ans = 0;
for( int i = 1; i < sam.idx; ++i )
ans += sam.len[i] - sam.len[sam.pa[i]];
return ans;
}
ll solve_after() {
sam.init();
int last = 0;
for( int i = 1; i <= n-1; ++i )
last = sam.append(last, str[i] - 'a');
ll ans = 0;
for( int i = 1; i < sam.idx; ++i )
ans += sam.len[i] - sam.len[sam.pa[i]];
return ans;
}
struct Graph {
int head[N], nxt[N], to[N], eid;
void init() {
eid = 0;
memset(head, -1, sizeof head);
}
void link( int u, int v ) {
to[eid] = v, nxt[eid] = head[u], head[u] = eid++;
}
};
Graph g, rg;
namespace HLD {
int dfn[N], dfnc, top[N], dep[N];
int pa[N], sz[N], son[N], val[N];
int rdfn[N];
void dfs1( int u, int fa, int d ) {
sz[u] = 1, dep[u] = d, pa[u] = fa;
val[u] = rsam.len[u];
for( int i = rg.head[u]; ~i; i = rg.nxt[i] ) {
int v = rg.to[i];
dfs1(v, u, d+1);
sz[u] += sz[v];
if( son[u] == -1 || sz[v] > sz[son[u]] )
son[u] = v;
}
}
void dfs2( int u, int tp ) {
dfn[u] = ++dfnc, top[u] = tp;
rdfn[dfnc] = u;
if( son[u] != -1 )
dfs2( son[u], tp );
for( int i = rg.head[u]; ~i; i = rg.nxt[i] ) {
int v = rg.to[i];
if( v != son[u] )
dfs2(v, v);
}
}
void init() {
memset(son, -1, sizeof son);
dfs1(0, -1, 1);
dfs2(0, 0);
}
int lca( int u, int v ) {
while( top[u] != top[v] ) {
if( dep[top[u]] < dep[top[v]] )
swap(u, v);
u = pa[top[u]];
}
return dep[u] < dep[v] ? u : v;
}
}
int mark[N], rmark[N], rmark2nod[N];
ll solve_ans = 0, now = 0;
set st;
void ins_node( int u ) {
u = mark[u];
if( !u ) return;
u = rmark2nod[u+2];
u = HLD::dfn[u];
if( st.empty() ) {
st.insert(u);
u = HLD::rdfn[u];
now += HLD::val[u];
} else {
auto after = st.lower_bound(u);
auto before = after;
--before;
if( after == st.end() ) {
int L = *before;
L = HLD::rdfn[L];
u = HLD::rdfn[u];
int lca = HLD::lca(L, u);
now -= HLD::val[lca];
now += HLD::val[u];
u = HLD::dfn[u];
st.insert(u);
} else if( after == st.begin() ) {
int R = *after;
R = HLD::rdfn[R];
u = HLD::rdfn[u];
int lca = HLD::lca(R, u);
now -= HLD::val[lca];
now += HLD::val[u];
u = HLD::dfn[u];
st.insert(u);
} else {
int L = *before;
int R = *after;
L = HLD::rdfn[L];
R = HLD::rdfn[R];
now += HLD::val[HLD::lca(L, R)];
u = HLD::rdfn[u];
now -= HLD::val[HLD::lca(L, u)];
now -= HLD::val[HLD::lca(R, u)];
now += HLD::val[u];
u = HLD::dfn[u];
st.insert(u);
}
}
}
void ins_tree( int u ) {
ins_node(u);
for( int i = g.head[u]; ~i; i = g.nxt[i] )
ins_tree( g.to[i] );
}
int sz[N], son[N];
void init_sack( int u ) {
sz[u] = 1, son[u] = -1;
for( int i = g.head[u]; ~i; i = g.nxt[i] ) {
int v = g.to[i];
init_sack(v);
sz[u] += sz[v];
if( son[u] == -1 || sz[v] > sz[son[u]] )
son[u] = v;
}
}
void sack( int u, bool clr ) {
// printf( "u = %d\n", u );
for( int i = g.head[u]; ~i; i = g.nxt[i] )
if( g.to[i] != son[u] )
sack( g.to[i], true );
if( son[u] != -1 )
sack( son[u], false );
for( int i = g.head[u]; ~i; i = g.nxt[i] )
if( g.to[i] != son[u] )
ins_tree( g.to[i] );
ins_node(u);
// printf( "u = %d, now = %lld\n", u, now );
if( u )
solve_ans += 1LL * now * (sam.len[u] - sam.len[sam.pa[u]]);
if( clr ) st.clear(), now = 0;
}
ll solve() {
sam.init();
int last = 0;
for( int i = 1; i <= n-2; ++i )
last = sam.append(last, str[i] - 'a');
g.init();
for( int i = 1; i < sam.idx; ++i )
g.link( sam.pa[i], i );
last = 0;
for( int i = 1; i <= n-2; ++i ) {
last = sam.ch[last][str[i] - 'a'];
mark[last] = i;
}
rsam.init();
last = 0;
for( int i = n; i >= 3; --i )
last = rsam.append(last, str[i] - 'a');
rg.init();
for( int i = 1; i < rsam.idx; ++i )
rg.link( rsam.pa[i], i );
last = 0;
for( int i = n; i >= 3; --i ) {
last = rsam.ch[last][str[i] - 'a'];
rmark[last] = i;
rmark2nod[i] = last;
}
HLD::init();
init_sack(0);
sack(0, false);
return solve_ans;
}
int main() {
_w = scanf( "%s", str+1 );
n = (int)strlen(str+1);
ll ans = 0;
ans += solve_origin();
// printf( "after origin = %lld\n", ans );
if( n >= 2 ) {
ans += solve_before();
ans += solve_after();
}
// printf( "before after = %lld\n", ans );
if( n >= 3 ) {
ans += solve();
}
printf( "%lld\n", ans+2 );
return 0;
}