http://acm.hdu.edu.cn/showproblem.php?pid=7131
给定一个序列s,求其[前缀是,后缀是>=1个a]的子序列个数
先求出子序列为nunhehheh的个数,定义dp(i,j)为s的前i个字符中和nunhehheh匹配到第j个个数.然后预处理出i后面有多少个a,记为a[i],对于每个dp(i,9)乘 2 a [ i ] 2^{a[i]} 2a[i]再相加即可得到所有方案数
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#pragma GCC optimize(2)
#pragma GCC optimize("inline")
#pragma GCC optimize("-fgcse")
#pragma GCC target("avx","sse2")
#pragma GCC optimize("-fgcse-lm")
#pragma GCC optimize("-fipa-sra")
#pragma GCC optimize("-ftree-pre")
#pragma GCC optimize("-ftree-vrp")
#pragma GCC optimize("-fpeephole2")
#pragma GCC optimize("-ffast-math")
#pragma GCC optimize("-fsched-spec")
#pragma GCC optimize("unroll-loops")
using namespace std;
#define ll long long
#define PII pair<int,int>
#define PLL pair<ll,ll>
#define PIII pair<int,PII>
#define PLLL pair<ll,PLL>
#define fi first
#define se second
#define pb push_back
#define debug(a) cout << #a << " " << a << '\n';
const int N = 1e5 + 5;
const int M = 1e5 + 5;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 998244353;
inline ll read();
int n, m, t;
ll dp[N][15];
ll a[N];
ll poww[N];
void solve() {
char s[N];
string p = "@nunhehheh";
cin >> (s + 1);
ll len =strlen(s+1);
for (int i = 0; i <= len+1; i++) {//初始化
for (int j = 0; i <= 10; i++)dp[i][j] = 0;
a[i] = 0;
}
for (int i = len; i >= 0; i--) {
dp[i][0] = 1;//与s中第i个字符一个都不匹配的数量是1
if (s[i] == 'a')a[i] = (a[i + 1] + 1) % mod;
else {
a[i] = a[i + 1];//预处理
}
}
ll ans = 0;
for (int i = 1; i <= len; i++) {
for (int j = 1; j <= 9; j++) {
if (s[i] == p[j])dp[i][j] = (dp[i - 1][j - 1] + dp[i - 1][j]) % mod;
else {
dp[i][j] = dp[i - 1][j] % mod;//算公共序列个数
}
}
}
for (int i = 0; i <= len; i++) {
if (s[i] == 'h') {
ans += (dp[i][8] * (poww[a[i]] - 1)) % mod;//注意这里是dp[i][8].如用dp[i][9]算答案会重复
}
}
cout << ans % mod << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin >> t;
poww[0]=1;
for(int i=1;i<=1e5;i++){
poww[i] =(poww[i-1]*2)%mod;
}
while (t--) {
solve();
}
return 0;
}
inline ll read() {
char ch = getchar();
ll p = 1, data = 0;
while (ch < '0' || ch > '9') {
if (ch == '-')p = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
data = data * 10 + (ch ^ 48);
ch = getchar();
}
return p * data;
}