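// Problem summary: two sequences A and B, each of length n, whose elements are positive integers no greater than 100. A is treated as a ring, so every starting index i gives a rotation a[i], a[i+1], ..., a[i-1] of A.
// Output three numbers: how many rotations of A are lexicographically greater than B, how many are smaller than B, and the total number of element comparisons a naive left-to-right comparison of every rotation against B would make.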
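// Approach: build a suffix automaton (SAM) whose suffix-link tree serves as a suffix tree. Use it to get, for each i, the LCP of the rotation starting at a[i] with B (starting at b[1]), and accumulate the three answers from those LCPs.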
#include <bits/stdc++.h>
using namespace std;
// Inclusive ascending loop: i takes the values a, a+1, ..., b.
#define rep(i,a,b) for(int i=a;i<=b;i++)
// Inclusive descending loop: i takes the values a, a-1, ..., b.
#define per(i,a,b) for(int i=a;i>=b;i--)
// Capacity of per-position arrays indexed 1..n (a, b, pos).
#define maxn 100007
// Capacity of per-automaton-node arrays. input() builds the automaton over
// 3n+1 symbols, and a suffix automaton over L symbols has at most 2L-1 states.
#define maxm 600007
// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before it are skipped, so a leading '-' is ignored.
// Returns 0 if end of input is reached before any digit is seen.
inline int rd() {
    // Must be int, not char: getchar() returns EOF out of band. A char cannot
    // hold EOF distinctly (the skip loop would then never terminate at end of
    // input), and a negative char passed to isdigit is undefined behaviour.
    int c = getchar();
    while (c != EOF && !isdigit(c)) c = getchar();
    if (c == EOF) return 0;
    int x = c - '0';
    while ((c = getchar()) != EOF && isdigit(c)) x = x * 10 + (c - '0');
    return x;
}
// Fixed-size array shapes: `arr` holds one int per sequence position
// (indices 1..n), and `sam` holds one int per automaton node.
typedef int arr[maxn];
typedef int sam[maxm];
typedef long long ll;

// Suffix-automaton storage; as globals, every entry starts at zero.
// fa[u][j]: the 2^j-th ancestor of node u in the suffix-link tree; fa[u][0] is the suffix link.
int fa[maxm][19];
// go[u][c]: transition of node u on symbol c (0 is used as a separator, data values go up to 100).
int go[maxm][101];
// tp: node ids sorted by increasing len (filled by tpsort).
sam tp;
// len: length of the longest string in each node's class.
sam len;
// dep: depth of each node in the suffix-link tree (the root gets depth 1).
sam dep;
// a, b: the two input sequences, 1-indexed.
arr a, b;
// pos[k][i]: node created when the i-th element of the recorded sequence was appended (see input()).
arr pos[2];
int n;    // sequence length
int m;    // largest value read; NOTE(review): not read anywhere else in this file
int ed;   // node of the whole string appended so far (the last one add() created for a symbol)
int tot;  // number of nodes allocated so far; node 1 is the root
// Counting-sort buckets indexed by len.
int cnt[(maxm >> 1) + 4];
// Running maximum: afterwards `a` holds the larger of its old value and b.
inline void upmax(int& a, int b) {
    a = (b > a) ? b : a;
}
// Appends symbol c to the suffix automaton (standard online SAM extension).
//   c - symbol to append (input() uses 0 as a separator, data symbols are 1..100)
//   l - 1-based position of the symbol in its sequence, or 0 to skip recording
//   k - which pos[] table to record into when l != 0
// The node created for the new whole string becomes `ed`. When l != 0 it is
// also stored in pos[k][l].
void add(int c , int l , int k) {
int p = ed , np = ed = ++ tot;
len[np] = len[p] + 1;
if (l) pos[k][l] = np ;
// Walk the suffix links from the previous last node, giving every node
// without a c-transition one that leads to np.
for(;p && !go[p][c];p = fa[p][0]) go[p][c] = np;
// Walked past the root: np's suffix link is the root (node 1).
if (!p)
{ fa[np][0] = 1 ; return ; }
int q = go[p][c] ;
// q's longest string is exactly p's plus c, so np can link straight to q.
if (len[q] == len[p] + 1)
{ fa[np][0] = q ; return ; }
// Otherwise split q. The clone r keeps q's transitions and suffix link but has len(p)+1.
int r = ++ tot ; len[r] = len[p] + 1;
memcpy(go[r] , go[q] , sizeof go[q]);
// q and np now link to r, and c-transitions into q along p's suffix-link
// chain are redirected to r. There is no `p &&` guard here: the walk stops
// because node 0 is never given transitions (the loop above checks p first,
// and this loop stops before reaching it), so go[0][c] == 0 != q.
for(fa[r][0] = fa[q][0] , fa[q][0] = fa[np][0] = r;go[p][c] == q;p = fa[p][0]) go[p][c] = r;
}
// Orders all automaton nodes by increasing len (counting sort into
// tp[1..tot]). Then, in that order, it sets each node's depth in the
// suffix-link tree and fills the binary-lifting table fa[u][1..18].
// Increasing len is a valid order because a node's suffix link always has a
// strictly smaller len, so every ancestor is finished before its descendants.
void tpsort() {
    // Size the prefix sum by the largest len actually present. The old bound
    // (n + n + n + 2) was tied to input()'s string layout; the resulting tp
    // order is the same.
    int max_len = 0;
    rep(i, 1, tot) {
        cnt[len[i]]++;
        if (len[i] > max_len) max_len = len[i];
    }
    rep(i, 1, max_len) cnt[i] += cnt[i - 1];
    per(i, tot, 1) tp[cnt[len[i]]--] = i;
    rep(i, 1, tot) {
        int u = tp[i];
        // fa[root][0] is 0 and dep[0] is 0, so the root gets depth 1.
        dep[u] = dep[fa[u][0]] + 1;
        rep(j, 1, 18) fa[u][j] = fa[fa[u][j - 1]][j - 1];
    }
}
void input() {
n = rd();
rep(i , 1 , n) upmax(m , a[i] = rd());
rep(i , 1 , n) upmax(m , b[i] = rd());
ed = tot = 1 ;
per(i , n , 1) add(b[i] , i , 1);add(0 , 0 , 0);
per(i , n , 1) add(a[i] , 0 , 0);
per(i , n , 1) add(a[i] , i , 0);
}
// Returns len of the lowest common ancestor of nodes u and v in the
// suffix-link tree, found by binary lifting over fa[][0..18].
// input() builds the automaton over reversed sequences, so this value is the
// longest common prefix of the forward suffixes that u and v stand for.
int lcp(int u , int v) {
    if (dep[u] < dep[v]) swap(u, v);

    // Lift u up to v's depth, one bit of the depth difference at a time.
    const int diff = dep[u] - dep[v];
    for (int k = 0; k <= 18; k++) {
        if ((diff >> k) & 1) u = fa[u][k];
    }
    if (u == v) return len[u];

    // Lift both nodes while their ancestors differ; they end just below the LCA.
    for (int k = 18; k >= 0; k--) {
        if (fa[u][k] == fa[v][k]) continue;
        u = fa[u][k];
        v = fa[v][k];
    }
    return len[fa[u][0]];
}
// Successor of x on the ring 1..n: x + 1, wrapping past n by subtracting n once.
// Subtracting only once means callers must keep x below 2n.
inline int nxt(int x) {
    const int y = x + 1;
    return y > n ? y - n : y;
}
void solve() {
tpsort();
ll ans_1 = 0 , ans_2 = 0 , ans_3 = 0;
rep(i , 1 , n) {
int l = lcp(pos[0][i] , pos[1][1]);
ans_3 += l;
if (l != n) {
ans_3 ++;
if (a[nxt(i + l - 1)] > b[l + 1]) ans_1 ++;
else ans_2 ++;
}
}
printf("%lld %lld %lld\n" , ans_1 , ans_2 , ans_3);
}
// Entry point: read the input and build the automaton, then compute and print
// the three answers. Falling off the end of main returns 0.
int main() {
    input();
    solve();
}