题意:给你 n n n个病毒的DNA序列,现在要造出一个长度为 m m m的DNA序列,问你有多少种不含病毒DNA序列的方案。
首先可以看到要构造的序列长度很大,达到了2e9(20亿),遍历一遍都会超时,肯定得写一个时间复杂度在 O ( l o g n ) O(logn) O(logn)以下的算法。
怎么解决这一问题呢?先直接说结论吧:
用AC自动机构造出邻接矩阵,然后跑矩阵快速幂,最后取矩阵第0行元素之和即可。
看了这句话是不是一头雾水?(我也一样) 下面来分析一下为什么要构造邻接矩阵。
在离散数学中有这样一个结论:
说人话 就是:从 u u u点到 v v v点恰好经过 k k k步的方案数,为邻接矩阵的 k k k次幂得到的矩阵(假设是 a n s ans ans)中的元素 a n s [ u ] [ v ] ans[u][v] ans[u][v]。(具体解法详见这篇文章)
那么这一结论对本题有什么启示呢?
所谓构造一个序列,其实就是让一个点从根节点开始走,保证走到的第一个点是序列的第一个元素,第二个点是第二个元素,走的过程实际上就是在Trie图中进行状态的转移。先看看暴力的想法:直接让一个点从根节点出发,在Trie图中“随便乱走”,由于Trie图在Trie树的基础上补全了不存在的出边节点, 那么每个点在下一步都有四个点(A,T,C,G)的选择,走到m步就停止,当然走的过程中不能经过病毒串终点,这样就是合法的序列。但是之前已经说过,不可能走m步,因为m太大了。
一次性走m步求不出,但是可以求每个点只走一步能转移到哪些点,这实际上就是求邻接矩阵。
如果我们能求出邻接矩阵 A A A,那么再求出 A m A^m Am,就得到了走m步能到达的点的所有情况。在 A m A^m Am矩阵中,第0行,第 j 列元素表示从0点开始走m步走到 j 点可能的方案数,求和即可。
注意ac自动机中的cnt[]成为病毒串标记,在构造fail指针时记得将标记向下传递。
AC代码:
#include
#include
#include
#include
using namespace std;
typedef long long ll;
const int N=105,M=4,K=10,mod=1e5;
int to_int(char c) // A,T,C,G -> 0,1,2,3
{
if(c=='A')return 0;
else if(c=='T')return 1;
else if(c=='C')return 2;
else return 3;
}
struct matrix
{
ll m[N][N];
matrix() // 构造函数,初始化
{
memset(m,0,sizeof(m));
};
};
struct trie
{
int ch[N][M];
int fail[N];
bool cnt[N];
int tot;
queue<int>q;
void ins(char s[])
{
int u=0;
for(int i=0;s[i];i++)
{
int x=to_int(s[i]);
if(!ch[u][x])ch[u][x]=++tot;
u=ch[u][x];
}
// 病毒串终点标记(题目应该保证了不会出现两个相同病毒串)
cnt[u]=1;
}
void build_fail()
{
for(int i=0;i<M;i++)
{
if(ch[0][i])
q.push(ch[0][i]);
}
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0;i<M;i++)
{
int &v=ch[u][i];
int f=ch[fail[u]][i];
if(v)
{
fail[v]=f;
cnt[v]|=cnt[f];// 等价于 if(cnt[f])cnt[v]=1;
// 病毒串终点标记向下传递(不要写反了!是f传递到v!)
q.push(v);
}
else v=f;
}
}
}
matrix build_matrix() // 构建邻接矩阵
{
matrix ans=matrix();
for(int i=0;i<=tot;i++) // tot+1个点
{
if(cnt[i])continue; // u不能是病毒串终点
for(int j=0;j<M;j++) // 每个点有M条出边
{
int v=ch[i][j]; // 走到的下一节点为v
if(!cnt[v]) // v不能是病毒串终点
ans.m[i][v]++;
}
}
return ans;
}
}ac;
matrix mul(matrix s1,matrix s2) // 两矩阵相乘
{
matrix ans=matrix();
int sz=ac.tot+1;
for(int i=0;i<sz;i++)
{
for(int j=0;j<sz;j++)
{
for(int k=0;k<sz;k++)
{
ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];
// ans.m[i][j]=(ans.m[i][j]+s1.m[i][k]*s2.m[k][j]%mod)%mod;
// 取模这样写容易超时,在保证不会爆long long的情况下
// 应该先加起来存long long里,最后再取模
}
ans.m[i][j]%=mod;
}
}
return ans;
}
matrix matrix_pow(matrix a,int b) // 矩阵a的b次幂
{
matrix ans=matrix();
int sz=ac.tot+1;
for(int i=0;i<sz;i++)
ans.m[i][i]=1; // 单位矩阵
while(b)
{
if(b&1)ans=mul(ans,a);
b/=2;
a=mul(a,a);
}
return ans;
}
int n,m;
char t[K];
int main()
{
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=1;i<=n;i++)
{
cin>>t;
ac.ins(t);
}
ac.build_fail();
matrix a=ac.build_matrix(); // 得到邻接矩阵a
matrix ans=matrix_pow(a,m); // 得到矩阵a^m
ll sum=0;
for(int i=0;i<=ac.tot;i++) // 累加 0 —> 0,1,...tot 的所有方案数
sum=(sum+ans.m[0][i])%mod;
printf("%lld\n",sum);
return 0;
}
/*
2 32
A
T
ans:67296
10 100
AGAGAGT
CGTATTG
AAAATTTCGC
GCGTA
TCGA
AATTGGA
TAGATAGC
AGCGTATT
TTCGA
TACGTATTG
ans:35771
*/
和上题差不多,这个是要求<=m的所有方案,构造一个矩阵[{E,E},{0,A}]进行快速幂即可得到A0+A1+…+Am。
然后对264取模的意思就是定义成unsigned long long,计算过程中会自动对264取模 (我先还以为是大数取模呢)
#include
using namespace std;
typedef unsigned long long ll;
const int N=105,M=26,K=5;
struct matrix
{
ll m[N][N];
matrix() // 构造函数,初始化
{
memset(m,0,sizeof(m));
};
};
struct trie
{
int ch[N][M];
bool cnt[N];
int fail[N];
int tot;
queue<int>q;
void init()
{
tot=0;
memset(cnt,0,sizeof(cnt));
memset(fail,0,sizeof(fail));
memset(ch,0,sizeof(ch));
}
void ins(char s[])
{
int u=0;
for(int i=0;s[i];i++)
{
int x=s[i]-'a'; // a~z -> 0~25
if(!ch[u][x])ch[u][x]=++tot;
u=ch[u][x];
}
cnt[u]=1;
}
void build_fail()
{
for(int i=0;i<M;i++)
{
if(ch[0][i])
q.push(ch[0][i]);
}
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0;i<M;i++)
{
int &v=ch[u][i];
int f=ch[fail[u]][i];
if(v)
{
fail[v]=f;
cnt[v]|=cnt[f];
q.push(v);
}
else v=f;
}
}
}
matrix build_matrix() // 得到邻接矩阵
{
matrix ans=matrix();
for(int i=0;i<=tot;i++)
{
if(cnt[i])continue;
for(int j=0;j<M;j++)
{
int v=ch[i][j];
if(!cnt[v])
ans.m[i][v]++;
}
}
return ans;
}
}ac;
matrix mul(matrix s1,matrix s2,int sz)
{
matrix ans=matrix();
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
for(int k=0;k<sz;k++)
ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];// 不用取模!
return ans;
}
matrix matrix_pow(matrix a,int b,int sz)
{
matrix ans=matrix();
for(int i=0;i<sz;i++)
ans.m[i][i]=1; // 单位矩阵
while(b)
{
if(b&1)ans=mul(ans,a,sz);
a=mul(a,a,sz);
b/=2;
}
return ans;
}
int n,m;
char t[K];
int main()
{
ios::sync_with_stdio(false);
while(cin>>n>>m)
{
ac.init();
for(int i=1;i<=n;i++)
{
cin>>t;
ac.ins(t);
}
ac.build_fail();
matrix tmp=matrix();
tmp.m[0][0]=1,tmp.m[0][1]=1,tmp.m[1][1]=26;
matrix s1=matrix_pow(tmp,m+1,2);
ll sum1=s1.m[0][1]-1; // 总方案数sum1
matrix a=ac.build_matrix(); // 邻接矩阵
int sz=ac.tot+1;
matrix b=matrix();
for(int i=0;i<sz;i++)
{
b.m[i][i]=1;
b.m[i][i+sz]=1;
}
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
b.m[i+sz][j+sz]=a.m[i][j];
matrix s2=matrix_pow(b,m+1,2*sz);
ll sum2=0;
for(int i=sz;i<2*sz;i++)
sum2+=s2.m[0][i];
sum2--; //不含模式串的方案数sum2
//printf("sum1=%I64u sum2=%I64u\n",sum1,sum2); // debug
printf("%I64u\n",sum1-sum2);
}
return 0;
}
/*
2 3
aa
ab
sum1=18278 sum2=18174
ans:104
2 13
aa
ab
sum1=2580398988131886038 sum2=2493353857086648626
ans:87045131045237412
2 2000000000
aa
ab
sum1=8116567392432202710 sum2=14915077526685486680
ans:11648233939456267646
*/
这题需要用到高精度,然后因为幂次比较小可以不用快速幂,直接DP就行,我想着java有大数,那就拿java写一下大数的矩阵快速幂吧。
注意编码格式,因为java没有unsigned char。将输入写成Scanner cin = new Scanner(new BufferedInputStream(System.in), "ISO-8859-1");
即可。
Java代码:
//package Main; //package信息一定要去掉,否则RE
import java.io.BufferedInputStream;
import java.math.BigInteger;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Scanner;
class Matrix {
BigInteger m[][];
Matrix(int sz, int type) { // type控制零矩阵/单位矩阵
m = new BigInteger[sz][sz];
for(int i = 0; i < sz; i++)
for(int j = 0; j < sz; j++)
m[i][j] = BigInteger.ZERO;
if(type == 1) {
for(int i = 0; i < sz; i++)
m[i][i] = BigInteger.ONE;
}
}
}
class Trie {
static final int N = 105, M = 55, K = 256;
int ch[][] = new int[N][M];
int fail[] = new int[N];
boolean cnt[] = new boolean[N];
int tot = 0;
int len;
Queue<Integer> q = new LinkedList<Integer>();
int mp[] = new int[K];
void ins(String s) {
int u = 0;
for(int i = 0; i < s.length(); i++)
{
int x=mp[s.charAt(i)]; // ASCII码 -> 0~len-1(len是字母表长度)
if(ch[u][x] == 0) ch[u][x] = ++tot;
u = ch[u][x];
}
cnt[u] = true;
}
void build_fail() {
for(int i = 0; i < len; i++) {
if(ch[0][i] != 0)
q.offer(ch[0][i]);
}
while(!q.isEmpty()) { // 队列非空
int u = q.poll(); // 取出并删除队头的元素
for(int i = 0; i < len; i++) {
int v = ch[u][i];
int f = ch[fail[u]][i];
if(v != 0) {
if(cnt[f] == true) cnt[v] = true;
fail[v] = f;
q.offer(v);
}
else ch[u][i] = f;
}
}
}
Matrix build_matrix() {
int sz = tot + 1;
Matrix ans = new Matrix(sz, 0);
for(int i = 0; i < sz; i++) {
if(cnt[i]) continue;
for(int j = 0;j < len; j++) {
int v = ch[i][j];
if(!cnt[v])
ans.m[i][v] = ans.m[i][v].add(BigInteger.ONE);
}
}
return ans;
}
Matrix mul(Matrix s1, Matrix s2) {
int sz = tot + 1;
Matrix ans = new Matrix(sz, 0);
for(int i = 0; i < sz; i++)
for(int j = 0; j < sz; j++)
for(int k = 0; k < sz; k++)
ans.m[i][j] = ans.m[i][j].add(s1.m[i][k].multiply(s2.m[k][j]));
return ans;
}
Matrix matrix_pow(Matrix a, int b) {
int sz = tot + 1;
Matrix ans = new Matrix(sz, 1);
while(b != 0) {
if(b%2 == 1) ans = mul(ans,a);
a = mul(a,a);
b /= 2;
}
return ans;
}
}
public class Main {
public static void main(String[] args) {
Scanner cin = new Scanner(new BufferedInputStream(System.in), "ISO-8859-1");
//Scanner cin = new Scanner(System.in); 会RE
Trie ac = new Trie();
ac.len = cin.nextInt();
int m = cin.nextInt();
int n = cin.nextInt();
String s = cin.next();
for(int i = 0; i < ac.len; i++) {
ac.mp[s.charAt(i)] = i;
}
for(int i = 1; i <= n; i++) {
String t = cin.next(); // 病毒串
ac.ins(t);
}
ac.build_fail();
Matrix a = ac.build_matrix(); // 邻接矩阵
Matrix ans = ac.matrix_pow(a,m);
BigInteger sum = BigInteger.ZERO;
int sz = ac.tot + 1;
for(int i = 0; i < sz; i++) {
sum = sum.add(ans.m[0][i]);
}
System.out.println(sum);
}
}
C++代码(没写大数,不能AC,只是作为java代码的“翻译”):
#include
#include
#include
#include
#include
using namespace std;
typedef unsigned long long ll;
typedef unsigned char uc;
const int N=105,M=55,K=256;
int mp[K];
int len,n,m;
struct matrix
{
ll m[N][N];
matrix()
{
memset(m,0,sizeof(m));
}
};
struct trie
{
int ch[N][M];
int fail[N];
bool cnt[N];
int tot;
queue<int>q;
void init()
{
memset(cnt,0,sizeof(cnt));
memset(fail,0,sizeof(fail));
memset(ch,0,sizeof(ch));
tot=0;
}
void ins(uc s[])
{
int u=0;
for(int i=0;s[i];i++)
{
int x=mp[s[i]]; // ASCII码 -> 0~len-1(len是字母表长度)
if(!ch[u][x])ch[u][x]=++tot;
u=ch[u][x];
}
cnt[u]=1;
}
void build_fail()
{
for(int i=0;i<len;i++)
{
if(ch[0][i])
q.push(ch[0][i]);
}
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0;i<len;i++)
{
int &v=ch[u][i];
int f=ch[fail[u]][i];
if(v)
{
cnt[v]|=cnt[f];
fail[v]=f;
q.push(v);
}
else v=f;
}
}
}
matrix build_matrix()
{
int sz=tot+1;
matrix ans=matrix();
for(int i=0;i<sz;i++)
{
if(cnt[i])continue;
for(int j=0;j<len;j++)
{
int v=ch[i][j];
if(!cnt[v])
ans.m[i][v]++;
}
}
return ans;
}
}ac;
matrix mul(matrix s1,matrix s2)
{
int sz=ac.tot+1;
matrix ans=matrix();
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
for(int k=0;k<sz;k++)
ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];
return ans;
}
matrix matrix_pow(matrix a,int b)
{
int sz=ac.tot+1;
matrix ans=matrix();
for(int i=0;i<sz;i++)
ans.m[i][i]=1;
while(b)
{
if(b&1)ans=mul(ans,a);
a=mul(a,a);
b/=2;
}
return ans;
}
int main()
{
ios::sync_with_stdio(false);
uc c,t[12];
cin>>len>>m>>n;
memset(mp,0,sizeof(mp));
ac.init();
for(int i=0;i<len;i++)
{
cin>>c;
mp[c]=i;
}
for(int i=1;i<=n;i++)
{
cin>>t;
ac.ins(t);
}
ac.build_fail();
int sz=ac.tot+1;
matrix a=ac.build_matrix();
matrix ans=matrix_pow(a,m);
ll sum=0;
for(int i=0;i<sz;i++)
sum+=ans.m[0][i];
printf("%llu",sum);
return 0;
}