CodeForces - 1264D2 Beautiful Bracket Sequence(生成函数 + 组合计数)

CodeForces - 1264D2 Beautiful Bracket Sequence(生成函数 + 组合计数)_第1张图片
CodeForces - 1264D2 Beautiful Bracket Sequence(生成函数 + 组合计数)_第2张图片

大致题意

给你一个由左右括号和?组成的字符串,现在?可以替换成左右括号的任意一个。定义一个字符串的深度为最大的左右括号嵌套数。现在问,所有的替换方案产生的字符串的深度总和是多少。

做法

如果有 n n n括号,那么就会有 2 n 2^n 2n个字符串,显然直接计算不可以。

考虑一个字符为 ‘(’ 的位置 i i i,如果他要对最后的深度产生影响,当且仅当它和它左边的 ‘(’ 数目小于等于它右边的 ‘)’ 数目。那么,我们就可以考虑枚举每一个位置,看每一个位置对最后答案的贡献。

考虑位置 i i i,它和它左边有 a a a个左括号, b b b个问号,它右边有 c c c个右括号, d d d个问号。那么显然,它的贡献就是:
∑ a + x < c + y C ( b , x ) ∗ C ( d , y ) \sum_{a+xa+x<c+yC(b,x)C(d,y)
z = d − y z=d-y z=dy,那么可以写成:
∑ x + z < c + d − a C ( b , x ) ∗ C ( d , z ) \sum_{x+zx+z<c+daC(b,x)C(d,z)
我们知道,对于组合数 C ( n , m ) C(n,m) C(n,m),他的生成函数是 ( x + 1 ) n (x+1)^n (x+1)n,其中第 m m m项系数就是 C ( n , m ) C(n,m) C(n,m)的答案。那么上式可以写成两个生成函数的乘积。
( x + 1 ) b ∗ ( x + 1 ) d = ( x + 1 ) b + d (x+1)^b*(x+1)^d=(x+1)^{b+d} (x+1)b(x+1)d=(x+1)b+d
由于 x + z < c + d − a x+zx+z<c+da,所以答案就是:
∑ k < c + d − a C ( b + d , k ) \sum_{kk<c+daC(b+d,k)
然后我们发现, b + d b+d b+d相当于是所有问号的个数是一个固定的值,也就只需要求一次组合数的前缀和即可。因此直接对每个为左括号的位置统计贡献即可。对于为问号的位置,由于它也可以变成左括号,因此我们要把他们变成左括号统计他们的贡献,这时 b + d b+d b+d会减一。相当于总的只需要求两次组合数的前缀和即可,时间复杂度 O ( N ) O(N) O(N)

代码

懒得写线性求逆了,实际复杂度 O ( N l o g N ) O(NlogN) O(NlogN),但是可以写到 O ( N ) O(N) O(N)

#include
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define LL long long
#define pb push_back
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<"      :   "<
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;

const int mod = 998244353;
const int N = 1e6 + 7;

unordered_map<int,vector<int> > mp;
char s[N];

LL qpow(LL x,LL n)
{
    LL res=1;
    while(n)
    {
        if (n&1) res=res*x%mod;
        x=x*x%mod; n>>=1;
    }
    return res;
}

LL cal(LL l,LL r,LL L,LL R)
{
    LL n=L+R,m=r+R-l;
    if (m<0) return 0;
    if (n<m) m=n;
    if (mp.count(n)) return mp[n][m];
    LL res=1,sum=1;
    std::vector<int> v;
    for(int i=0;i<=n;i++)
    {
        v.pb(sum);
        res=res*(n-i)%mod*qpow(i+1,mod-2)%mod;
        sum=(sum+res)%mod;
    }
    mp[n]=v;
    return mp[n][m];
}

int main(int argc, char const *argv[])
{
    LL ans=0;
    scanf("%s",s);
    int len=strlen(s);
    int l=0,L=0,r=0,R=0;
    for(int i=len-1;i>=0;i--)
    {
        if (s[i]==')') r++;
        if (s[i]=='?') R++;
    }
    for(int i=0;i<len;i++)
    {
        if (s[i]=='(') l++;
        if (s[i]=='?') L++,R--;
        if (s[i]==')') r--;
        if (s[i]=='(')ans=(ans+cal(l,r,L,R))%mod;
        if (s[i]=='?')
        {
            l++,L--;
            ans=(ans+cal(l,r,L,R))%mod;
            l--,L++;
        }
    }
    printf("%lld\n",ans);
    return 0;
}

你可能感兴趣的:(CodeForces,母函数,组合计数)