链接:https://ac.nowcoder.com/acm/contest/881/E
来源:牛客网
题目大意:求长度为2*(n+m)的字符串数量,要求满足其中有n个'AB'子串,m个'BA'子串。
例如:
给出n=1,m=2的合法序列:
ABABAB
ABABBA
ABBAAB
ABBABA
ABBBAA
BAABBA
BAABAB
BABAAB
BABABA
BABBAA
BBAAAB
BBAABA
BBABAA
仔细观察,我们会发现每个字符串的前n个A一定是属于'AB'子串里的A,前m个B一定是属于'BA'子串里的B。
比如BAABAB,前两个B属于'BA'里的B,前一个A属于'AB'里的A。自然能组合出1个'AB',2个‘BA’。
接下来我们讨论解法:
一、动态规划(dp)
根据题目条件我们可以建立一个二维数组dp[i][j]代表i个A和j个B时满足条件的字符串数量,可以得出dp[i][j]都是由dp[i-1][j]和dp[i][j-1]的状态转移过来的,需要一直推到i=n+m,j=m+n。
一开始,空字符串dp[0][0]=1
当i+1≤n,dp[i+1][j]+=dp[i][j] ///对于上述观察法,我们可以得出当i+1≤n时,相比dp[i][j]多出了1个A,满足情况显然可以放
当i+1<=j+n,dp[i+1][j]+=dp[i][j] ///由上一步i+1大于n时,可以得出此时m中有(i+1-n)个‘A’还未配对,由上述观察法,得出此时至少需要有j个'B'和它配对
同理可得
当j+1<=m,dp[i][j+1]+=dp[i][j]
当j+1<=i+m,dp[i][j+1]+=dp[i][j]
所以我们就可以O((n+m)*(n+m))的求了
二、组合数学
首先考虑包含全部可能的解,对于n个'AB'子串,m个'BA'子串共有C(2*n+2*m,n+m)组解(包含不合法的解)。
接下来考虑只要把不合法的解减掉,剩下的便是合法的解了!
取A为1,B为-1,便可以将字符串处理成前缀和的形式,所以任意点的前缀和便为-m≤sum[i]≤n(极限情况)。
假设现在有不符合情况的字符串s,sum[i]=n+1。
从左往右枚举,考虑两种极限情况,第一种是sum[i]=-(m+1),第二种是sum[i]=n+1。
这里我们讨论第二种,第一种可由读者自己推。
对于sum[i]=n+1,就是存在一个最小的 i(i≥n+1),在第i项之前(包括第i项)有i个A,i-(n+1)个B。
我们把这个字符串设为s,此时s中有n+m个A,n+m个B(假设它符合条件),由于是假设的,这里不能通过直接算s中A,B的数量来求出不满足的情况,我们可以通过将前i项的A和B互换,将转换后的字符串记为t,通过t串间接求出s中A,B的数量。
例如:n=1,m=2。
对于s串AABABB,存在最小的i=2使得sum[i]=n+1。转换后的t串为BBBABB。
那么对于s串和t串的前i项中A和B的差值都为n+1,转换关系可以写成f(s)=t,g(t)=s;
可以得出这个过程是可逆的,所以我们只要求出t中A,B的数量,就能通过这个转换关系(A和B相差的数量)求出s串中A,B的数量。
通过转换,我们能得到t串的特点:
存在n+m-(n+1)个A,n+m+(n+1)个B,使得存在sum[j]=-(n+1)(一定存在,极限情况下前len个全是B,有sum[len]=-2*(n+1)≤-(n+1))。通过观察我们发现j与s串中的i相等。
所以t中有n+m-(n+1)个A,n+m+(n+1)个B,通过A,B的原本数量的差值为(n+1)个,能得出s串中有n+m+(n+1)个A,n+m-(n+1)个B,所以s串不符合原本含有n+m个A,n+m个B的假设,证毕。
对于这种情况(这是上述两种极限情况中的第二种情况),共有C(2*n+2*m,m-1)组解。
同理可得第一种极限情况有C(2*n+2*m,n-1)组解。
所以答案为C(2*n+2*m,n+m)-C(2*n+2*m,n-1)-C(2*n+2*m,m-1)。
代码可以通过组合数打表预处理O(4000*log(1e9+7)),最终对于每个询问O(1)输出答案。
动态规划AC代码:
语言:C++ 代码长度:1210 运行时间: 216 ms 占用内存:31832K
1 #include2 #define numm ch-48 3 #define pd putchar(' ') 4 #define pn putchar('\n') 5 #define pb push_back 6 #define mp make_pair 7 #define fi first 8 #define se second 9 #define fi first 10 #define se second 11 #define fre1 freopen("1.txt","r",stdin) 12 #define fre2 freopen("2.txt","w",stdout) 13 using namespace std; 14 template 15 void read(T &res) { 16 bool flag=false;char ch; 17 while(!isdigit(ch=getchar())) (ch=='-')&&(flag=true); 18 for(res=numm;isdigit(ch=getchar());res=(res<<1)+(res<<3)+numm); 19 flag&&(res=-res); 20 } 21 template 22 void write(T x) { 23 if(x<0) putchar('-'),x=-x; 24 if(x>9) write(x/10); 25 putchar(x%10+'0'); 26 } 27 const int maxn=2010; 28 typedef long long ll; 29 typedef long double ld; 30 const ll mod=1e9+7; 31 ll dp[maxn][maxn]; 32 int main() 33 { 34 int n,m; 35 while(scanf("%d%d",&n,&m)!=EOF) { 36 for(int i=0;i<=n+m;i++) 37 for(int j=0;j<=m+n;j++) 38 dp[i][j]=0; 39 dp[0][0]=1; 40 for(int i=0;i<=n+m;i++) 41 for(int j=0;j<=m+n;j++) { 42 if(j>=(i+1)-n) dp[i+1][j]=(dp[i+1][j]+dp[i][j])%mod; 43 if(i>=(j+1)-m) dp[i][j+1]=(dp[i][j+1]+dp[i][j])%mod; 44 } 45 write(dp[n+m][n+m]);pn; 46 } 47 return 0; 48 }
组合数学AC代码:
语言:C++ 代码长度:1416 运行时间: 5 ms 占用内存:480K
1 #include2 #define numm ch-48 3 #define pd putchar(' ') 4 #define pn putchar('\n') 5 #define pb push_back 6 #define mp make_pair 7 #define fi first 8 #define se second 9 #define fi first 10 #define se second 11 #define fre1 freopen("1.txt","r",stdin) 12 #define fre2 freopen("2.txt","w",stdout) 13 using namespace std; 14 template 15 void read(T &res) { 16 bool flag=false;char ch; 17 while(!isdigit(ch=getchar())) (ch=='-')&&(flag=true); 18 for(res=numm;isdigit(ch=getchar());res=(res<<1)+(res<<3)+numm); 19 flag&&(res=-res); 20 } 21 template 22 void write(T x) { 23 if(x<0) putchar('-'),x=-x; 24 if(x>9) write(x/10); 25 putchar(x%10+'0'); 26 } 27 const int maxn=4010; 28 typedef long long ll; 29 typedef long double ld; 30 const ll mod=1e9+7; 31 ll jie[maxn]; 32 ll inv[maxn]; 33 ll quickpow(ll a,ll b) { 34 ll ans=1; 35 while(b) { 36 if(b&1) ans=(ans*a)%mod; 37 a=(a*a)%mod; 38 b>>=1; 39 } 40 return ans; 41 } 42 void init() { 43 jie[0]=1; 44 inv[0]=quickpow(jie[0],mod-2); 45 for(int i=1;i<=4000;i++) { 46 jie[i]=(jie[i-1]*(ll)i)%mod; 47 inv[i]=quickpow(jie[i],mod-2); 48 } 49 } 50 ll C(int n,int m) { 51 return jie[n]*inv[n-m]%mod*inv[m]%mod; 52 } 53 int main() 54 { 55 int n,m; 56 init(); 57 while(scanf("%d%d",&n,&m)!=EOF) { 58 ll ans=C((n<<1)+(m<<1),n+m); 59 if(n) ans-=C((n<<1)+(m<<1),n-1); 60 if(m) ans-=C((n<<1)+(m<<1),m-1); 61 write((ans%mod+mod)%mod); 62 pn; 63 } 64 return 0; 65 }