组合数取模之逆元方法+模板

参自:

http://www.cnblogs.com/liziran/p/6804803.html

https://baike.baidu.com/item/%E8%B4%B9%E9%A9%AC%E5%B0%8F%E5%AE%9A%E7%90%86/4776158?fr=aladdin


现在目标是求 Cnm%p,p为素数(经典p=1e9+7)

虽然有 Cnm=n!m!(n−m)!,但由于取模的性质对于除法不适用,所以 Cnm%p (n!%pm!%p∗(n−m)!%p)%p

所以需要把“除法”转换成“乘法”,才能借助取模的性质在不爆long long的情况下计算组合数。这时候就需要用到“逆元”!

  逆元:对于a和p,若a*b%p≡1,则称b为a%p的逆元。

那这个逆元有什么用呢?试想一下求 (ab)%p,如果你知道b%p的逆元是c,那么就可以转变成 (ab)%p = a*c%p = (a%p)(c%p)%p

那怎么求逆元呢?这时候就要引入强大的费马小定理!

  费马小定理(Fermat's little theorem)数论中的一个重要定理,在1636年提出,其内容为: 假如p是质数,且gcd(a,p)=1,那么 a(p-1)≡1(mod p),即:假如a是整数,p是质数,且a,p互质(即两者只有一个公约数1),那么a的(p-1)次方除以p的余数恒等于1。

接着因为 ap−1 =  ap−2∗a,所以有 ap−2∗a%p≡1!对比逆元的定义可得, ap−2是a的逆元!

所以问题就转换成求解 ap−2,即变成求快速幂的问题了(当然这需要满足p为素数)。

现在总结一下求解 Cnm%p的步骤:

  1. 通过循环,预先算好所有小于max_number的阶乘(%p)的结果,存到fac[max_number]里 (fac[i] = i!%p)
  2. 求m!%p的逆元(即求fac[m]的逆元):根据费马小定理,x%p的逆元为 xp−2,因此通过快速幂,求解 fac[m]p−2%p,记为M
  3. 求(n-m)!%p的逆元:同理为求解 fac[n−m]p−2%p,记为NM
  4. Cnm%p = ((fac[n]*M)%p*NM)%p

模板:

#include ///codeforces 869C代码 主函数三个循环可以合并成一个循环+三个if
using namespace std;
const int MAXN = 5050;
const int mod = 998244353;
typedef unsigned long long LL;

LL a,b,c,aa,bb,cc;

LL inv[MAXN],fac[MAXN];

inline int Inv(int x){///x^(mod-2)
	int res = 1;
	int p = mod - 2;
	while (p) {
		if (p & 1) res = LL(res) * x % mod;
		p >>= 1;
		x = LL(x) * x % mod;
	}
	return res;
}

inline int C(int n, int k){
	if (n < 0 || k < 0 || k > n) return 0;
	return LL(fac[n]) * inv[k] % mod * inv[n - k] % mod;
}

void init(){
	fac[0] = inv[0] = 1;
	for (int i = 1; i < MAXN; i++) {
		fac[i] = LL(fac[i - 1]) * i % mod;
		inv[i] = Inv(fac[i]);///预处理fac[i]^(p-2)
	}
}

int main(){
    init();
    cin>>a>>b>>c;
    LL ans=0LL,ans1=0LL,ans2=0LL,ans3=0LL;
    LL tmp;
    aa=min(a,b);
    for(LL i=0;i<=aa;++i){
        tmp=(LL)C(a,i);
        tmp=tmp*(LL)C(b,i)%mod;
        tmp=tmp*(LL)fac[i]%mod;
        ans1=(ans1+tmp)%mod;
    }
    bb=min(c,b);
    for(LL i=0;i<=bb;++i){
        tmp=(LL)C(c,i);
        tmp=tmp*(LL)C(b,i)%mod;
        tmp=tmp*(LL)fac[i]%mod;
        ans2=(ans2+tmp)%mod;
    }
    cc=min(a,c);
    for(LL i=0;i<=cc;++i){
        tmp=(LL)C(a,i);
        tmp=tmp*(LL)C(c,i)%mod;
        tmp=tmp*(LL)fac[i]%mod;
        ans3=(ans3+tmp)%mod;
    }
    ans=(ans1*ans2)%mod*ans3%mod;
    printf("%d\n",ans%mod);
    return 0;
}





你可能感兴趣的:(数论)