【矩阵快速幂】P10581 [蓝桥杯 2024 国 A] 重复的串|省选-

本文涉及知识点

【矩阵快速幂】封装类及测试用例及样例

P10581 [蓝桥杯 2024 国 A] 重复的串

题目描述

给定一个仅含小写字母的字符串 S S S,问有多少个长度为 n n n 的仅含小写字母的字符串中恰好出现了两次 S S S。答案对 998   244   353 998\ 244\ 353 998 244 353 取模。

输入格式

输入一行包含一个字符串 S S S 和一个整数 n n n,用一个空格分隔。

输出格式

输出一行包含一个整数表示答案。

输入输出样例 #1

输入 #1

aba 6

输出 #1

53

输入输出样例 #2

输入 #2

aba 10

输出 #2

77907666

说明/提示

对于 40 % 40\% 40% 的评测用例, n ≤ 20 n \le 20 n20 ∣ S ∣ ≤ 6 |S| \le 6 S6
另有 10 % 10\% 10% 的评测用例, n ≤ 500 n\le 500 n500 ∣ S ∣ ≤ 2 |S| \le 2 S2
对于 70 % 70\% 70% 的评测用例, n ≤ 1 0 5 n\le 10^5 n105
对于所有评测用例, 1 ≤ n ≤ 1 0 9 1\le n\le 10^9 1n109 1 ≤ ∣ S ∣ ≤ 30 1 \le |S| \le 30 1S30

矩阵指数幂 KMP

M = S的长度
dp[n][j]表示长度为n符合题意的字符串s2数量,j是S2和S匹配的字符数量。
矩阵长度3M+1
前置状态是3M无需枚举,直接mat[3M][3M]=26
第一层循环枚举前置状态是j=to 0 M-1, 第二层循环枚举当前字符ch 。
j2是S.Right(j1)+ch和S的最长公共后缀、前缀。
mat[j+2M][j2+2M]+=1
如果j2 ≠ \neq = M
mat[j][j2]+=1 mat[j+M][j2+m] +=1
否则
x是S的最长公共前后缀
mat[j][j2+x]+=1 mat[j+M][j2+m+x] +=1
建立矩阵时间复杂度:O( ∑ \sum MMM) ∑ \sum 是字符集的大小,本题是26。使用KMP降为O( ∑ \sum MM)。
初始pre=dp[0],pre[0]为1,其它为0.
matAns = pre*matn
matAns[2M…3M-1]便是答案。
时间复杂度:O(lognMMM)

代码

核心代码

#include 
#include 
#include 
#include
#include
#include
#include
#include
#include
#include
#include
#include 
#include
#include
#include 
#include 
#include
#include
#include
#include 

#include 
using namespace std;

template<class T1, class T2>
std::istream& operator >> (std::istream& in, pair<T1, T2>& pr) {
	in >> pr.first >> pr.second;
	return in;
}

template<class T1, class T2, class T3 >
std::istream& operator >> (std::istream& in, tuple<T1, T2, T3>& t) {
	in >> get<0>(t) >> get<1>(t) >> get<2>(t);
	return in;
}

template<class T1, class T2, class T3, class T4 >
std::istream& operator >> (std::istream& in, tuple<T1, T2, T3, T4>& t) {
	in >> get<0>(t) >> get<1>(t) >> get<2>(t) >> get<3>(t);
	return in;
}

template<class T = int>
vector<T> Read() {
	int n;
	scanf("%d", &n);
	vector<T> ret(n);
	for (int i = 0; i < n; i++) {
		cin >> ret[i];
	}
	return ret;
}

template<class T = int>
vector<T> Read(int n) {
	vector<T> ret(n);
	for (int i = 0; i < n; i++) {
		cin >> ret[i];
	}
	return ret;
}

template<int N = 1'000'000>
class COutBuff
{
public:
	COutBuff() {
		m_p = puffer;
	}
	template<class T>
	void write(T x) {
		int num[28], sp = 0;
		if (x < 0)
			*m_p++ = '-', x = -x;

		if (!x)
			*m_p++ = 48;

		while (x)
			num[++sp] = x % 10, x /= 10;

		while (sp)
			*m_p++ = num[sp--] + 48;
		AuotToFile();
	}
	void writestr(const char* sz) {
		strcpy(m_p, sz);
		m_p += strlen(sz);
		AuotToFile();
	}
	inline void write(char ch)
	{
		*m_p++ = ch;
		AuotToFile();
	}
	inline void ToFile() {
		fwrite(puffer, 1, m_p - puffer, stdout);
		m_p = puffer;
	}
	~COutBuff() {
		ToFile();
	}
private:
	inline void AuotToFile() {
		if (m_p - puffer > N - 100) {
			ToFile();
		}
	}
	char  puffer[N], * m_p;
};

template<int N = 1'000'000>
class CInBuff
{
public:
	inline CInBuff() {}
	inline CInBuff<N>& operator>>(char& ch) {
		FileToBuf();
		ch = *S++;
		return *this;
	}
	inline CInBuff<N>& operator>>(int& val) {
		FileToBuf();
		int x(0), f(0);
		while (!isdigit(*S))
			f |= (*S++ == '-');
		while (isdigit(*S))
			x = (x << 1) + (x << 3) + (*S++ ^ 48);
		val = f ? -x : x; S++;//忽略空格换行		
		return *this;
	}
	inline CInBuff& operator>>(long long& val) {
		FileToBuf();
		long long x(0); int f(0);
		while (!isdigit(*S))
			f |= (*S++ == '-');
		while (isdigit(*S))
			x = (x << 1) + (x << 3) + (*S++ ^ 48);
		val = f ? -x : x; S++;//忽略空格换行
		return *this;
	}
	template<class T1, class T2>
	inline CInBuff& operator>>(pair<T1, T2>& val) {
		*this >> val.first >> val.second;
		return *this;
	}
	template<class T1, class T2, class T3>
	inline CInBuff& operator>>(tuple<T1, T2, T3>& val) {
		*this >> get<0>(val) >> get<1>(val) >> get<2>(val);
		return *this;
	}
	template<class T1, class T2, class T3, class T4>
	inline CInBuff& operator>>(tuple<T1, T2, T3, T4>& val) {
		*this >> get<0>(val) >> get<1>(val) >> get<2>(val) >> get<3>(val);
		return *this;
	}
	template<class T = int>
	inline CInBuff& operator>>(vector<T>& val) {
		int n;
		*this >> n;
		val.resize(n);
		for (int i = 0; i < n; i++) {
			*this >> val[i];
		}
		return *this;
	}
	template<class T = int>
	vector<T> Read(int n) {
		vector<T> ret(n);
		for (int i = 0; i < n; i++) {
			*this >> ret[i];
		}
		return ret;
	}
	template<class T = int>
	vector<T> Read() {
		vector<T> ret;
		*this >> ret;
		return ret;
	}
private:
	inline void FileToBuf() {
		const int canRead = m_iWritePos - (S - buffer);
		if (canRead >= 100) { return; }
		if (m_bFinish) { return; }
		for (int i = 0; i < canRead; i++)
		{
			buffer[i] = S[i];//memcpy出错			
		}
		m_iWritePos = canRead;
		buffer[m_iWritePos] = 0;
		S = buffer;
		int readCnt = fread(buffer + m_iWritePos, 1, N - m_iWritePos, stdin);
		if (readCnt <= 0) { m_bFinish = true; return; }
		m_iWritePos += readCnt;
		buffer[m_iWritePos] = 0;
		S = buffer;
	}
	int m_iWritePos = 0; bool m_bFinish = false;
	char buffer[N + 10], * S = buffer;
};


template<class T = long long>
class CMatMul
{
public:
	CMatMul(T llMod = 1e9 + 7) :m_llMod(llMod) {}
	// 矩阵乘法
	vector<vector<T>> multiply(const vector<vector<T>>& a, const vector<vector<T>>& b) {
		const int r = a.size(), c = b.front().size(), iK = a.front().size();
		assert(iK == b.size());
		vector<vector<T>> ret(r, vector<T>(c));
		for (int i = 0; i < r; i++)
		{
			for (int j = 0; j < c; j++)
			{
				for (int k = 0; k < iK; k++)
				{
					ret[i][j] = (ret[i][j] + a[i][k] * b[k][j]) % m_llMod;
				}
			}
		}
		return ret;
	}

	// 矩阵快速幂
	vector<vector<T>> pow(const vector<vector<T>>& a, vector<vector<T>> b, T n) {
		vector<vector<T>> res = a;
		for (; n; n /= 2) {
			if (n % 2) {
				res = multiply(res, b);
			}
			b = multiply(b, b);
		}
		return res;
	}
	vector<vector<T>> pow(vector<vector<T>> pre, vector<vector<T>> mat, const string& str)
	{
		for (int i = str.length() - 1; i >= 0; i--) {
			const int t = str[i] - '0';
			pre = pow(pre, mat, t);
			mat = pow(mat, mat, 9);
		}
		return pre;
	}
	vector<vector<T>> TotalRow(const vector<vector<T>>& a)
	{
		vector<vector<T>> b(a.front().size(), vector<T>(1, 1));
		return multiply(a, b);
	}
protected:
	const  T m_llMod;
};

class KMP
{
public:
	virtual int Find(const string& s, const string& t)
	{
		CalLen(t);
		for (int i1 = 0, j = 0; i1 < s.length(); )
		{
			for (; (j < t.length()) && (i1 + j < s.length()) && (s[i1 + j] == t[j]); j++);
			//i2 = i1 + j 此时s[i1,i2)和t[0,j)相等 s[i2]和t[j]不存在或相等
			//t[0,j)的结尾索引是j-1,所以最长公共前缀为m_vLen[j-1],简写为y 则t[0,y)等于t[j-y,j)等于s[i2-y,i2)
			if (0 == j)
			{
				i1++;
				continue;
			}
			const int i2 = i1 + j;
			j = m_vLen[j - 1];
			i1 = i2 - j;//i2不变
		}
		return -1;
	}
	//vector m_vSameLen;//m_vSame[i]记录 s[i...]和t[0...]最长公共前缀,增加可调试性 部分m_vSameLen[i]会缺失
	//static vector Next(const string& s)
	//{// j = vNext[i] 表示s[0,i]的最大公共前后缀是s[0,j]
	//	const int len = s.length();
	//	vector vNext(len, -1);
	//	for (int i = 1; i < len; i++)
	//	{
	//		int next = vNext[i - 1];
	//		while ((-1 != next) && (s[next + 1] != s[i]))
	//		{
	//			next = vNext[next];
	//		}
	//		vNext[i] = next + (s[next + 1] == s[i]);
	//	}
	//	return vNext;
	//}

	const vector<int> CalLen(const string& str)
	{
		m_vLen.resize(str.length());
		for (int i = 1; i < str.length(); i++)
		{
			int next = m_vLen[i - 1];
			while (str[next] != str[i])
			{
				if (0 == next)
				{
					break;
				}
				next = m_vLen[next - 1];
			}
			m_vLen[i] = next + (str[next] == str[i]);
		}
		return m_vLen;
	}
protected:
	int m_c;
	vector<int> m_vLen;//m_vLen[i] 表示str[0,i]的最长公共前后缀的长度
};

class Solution {
public:
	int Ans(long long n, string S) {
		const int M = S.length();
		KMP kmp;
		auto tmp = kmp.CalLen(S);
		vector<vector<long long>> mat(3 * M + 1, vector<long long>(3 * M + 1));
		auto Len = [&](int preLen, char ch) {
			while ((0 != preLen) && (ch != S[preLen])) {
				preLen = tmp[preLen - 1];
			}
			return (ch != S[preLen]) ? 0 : (preLen + 1);
		};
		mat[3 * M][3 * M] = 26;
		for (int j = 0; j < M; j++) {
			for (char ch = 'a'; ch <= 'z'; ch++) {
				int j2 = Len(j, ch);
				if (M == j2) {
					mat[j][M + tmp.back()] ++; mat[j + M][2 * M + tmp.back()] ++;
				}
				else {
					mat[j][j2]++; mat[j + M][j2 + M]++;
				}
				mat[j + 2 * M][j2 + 2 * M]++;
			}
		}
		vector<vector<long long>> pre(1, vector<long long>(3 * M + 1));
		pre[0][0] = 1;
		CMatMul<> matMul(998244353);
		auto matAns = matMul.pow(pre, mat, n);
		long long ans = accumulate(matAns[0].begin() + 2 * M, matAns[0].begin() + 3 * M, 0LL);
		return ans % 998244353;
	}
};

int main() {
#ifdef _DEBUG
	freopen("a.in", "r", stdin);
#endif // DEBUG	
	ios::sync_with_stdio(0);
	string s; int n;
	cin >> s >> n ;
		auto res = Solution().Ans(n,s);
		cout <<res << "\n";
		
#ifdef _DEBUG		
	//printf("start=%d,end=%d,T=%d", start,end,T);
	//Out(edge, "edge=");
	//Out(fish, ",fish=");
	/*Out(edge, "edge=");
	Out(que, "que=");*/
#endif // DEBUG	
	
	return 0;
}

扩展阅读

我想对大家说的话
工作中遇到的问题,可以按类别查阅鄙人的算法文章,请点击《算法与数据汇总》。
学习算法:按章节学习《喜缺全书算法册》,大量的题目和测试用例,打包下载。重视操作
有效学习:明确的目标 及时的反馈 拉伸区(难度合适) 专注
闻缺陷则喜(喜缺)是一个美好的愿望,早发现问题,早修改问题,给老板节约钱。
子墨子言之:事无终始,无务多业。也就是我们常说的专业的人做专业的事。
如果程序是一条龙,那算法就是他的是睛
失败+反思=成功 成功+反思=成功

视频课程

先学简单的课程,请移步CSDN学院,听白银讲师(也就是鄙人)的讲解。
https://edu.csdn.net/course/detail/38771
如何你想快速形成战斗了,为老板分忧,请学习C#入职培训、C++入职培训等课程
https://edu.csdn.net/lecturer/6176

测试环境

操作系统:win7 开发环境: VS2019 C++17
或者 操作系统:win10 开发环境: VS2022 C++17
如无特殊说明,本算法用**C++**实现。

你可能感兴趣的:(蓝桥杯,线性代数,c++,洛谷,数学,矩阵快速幂,重复)