【数论】逆元求组合数(模板+题)

模板

int fac[N];

int power(int x, int k, int p) {
    int ret = 1;
    while(k) {
        if(k & 1) ret = ret * x % p;
        k >>= 1;
        x = x * x % p;
    }
    return ret % p;
}
int inv(int x, int p) {     //求x关于模p的逆元
    return power(x, p - 2, p) % MOD;
}
int C(int a, int b) {       //直接用逆元求解组合数C(a, b)
    return fac[a] * inv(fac[b], MOD) % MOD * inv(fac[a - b], MOD) % MOD;
}

void solve()
{
	//预处理
	fac[0] = 1;
    for(int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % MOD;

	//求C i k (i在下,k在上)
	cout << C(i, k) << endl;		
}

例题

原题链接:ABC - 151 - E - Max-Min Sums

【数论】逆元求组合数(模板+题)_第1张图片
【数论】逆元求组合数(模板+题)_第2张图片

样例1

输入:

4 2
1 1 3 4

输出:

11

样例2

输入:

6 3
10 10 10 -10 -10 -10

输出:

360

样例3

输入:

3 1
1 1 1

输出:

0

样例4

输入:

10 6
1000000000 1000000000 1000000000 1000000000 1000000000 0 0 0 0 0

输出:

999998537

思路

  1. 如果Ckn暴力枚举元素个数为k的子集,然后找最大值最小值,显然超时。理想状态应该找一个O~(n logn)~的复杂度。
  2. 如果给元素排个序,然后会发现其实每个元素被当成最小值和最大值的次数是可计算的,也就是用组合数学的方式计算每个数作为子集中最小值和最大值的次数,然后最大值加一起,减去最小值加一起,即为最终答案。
  3. 找规律:对于前k - 1个数,不可能作为子集中的最大值。对于其他的数,第i个数作为子集中的最大值的次数为Ck-1i-1,作为子集中的最小值的次数为Ck-1n-i
  4. 枚举所有数ai :ans = (ai * Ck-1i-1) - (ai * Ck-1n-i)。
  5. 如果是打表记录组合数的话,需要记录C [1e5] [1e5] 范围的组合数,显然是存不下的。
  6. 这个数据范围可以用逆元求组合数,注意这里需要对 MOD=1e9+7取模处理

代码

#include
#pragma GCC optimize(2)
#define int long long
#define ld long double
#define gcd __gcd
#define PII pair<int, int>

using namespace std;
const int N = 1e5 + 10;
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
bool cmp(PII a, PII b) { return a.second > b.second; } // 大到小
struct edge { int v, w, next; } e[N];

int ans, n, k, a[N];
int fac[N];


int power(int x, int k, int p) {
    int ret = 1;
    while(k) {
        if(k & 1) ret = ret * x % p;
        k >>= 1;
        x = x * x % p;
    }
    return ret % p;
}
int inv(int x, int p) {     //求x关于模p的逆元
    return power(x, p - 2, p) % MOD;
}
int C(int a, int b) {       //直接用逆元求解组合数C(a, b)
    return fac[a] * inv(fac[b], MOD) % MOD * inv(fac[a - b], MOD) % MOD;
}

void solve()
{
    cin >> n >> k;
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
    }

    //预处理
    fac[0] = 1;
    for(int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % MOD;

    sort(a + 1, a + n + 1);

    //用作max
    for (int i = k; i <= n; ++i) {
        ans = (ans + a[i] * C(i - 1, k - 1) % MOD) % MOD ;
        ans %= MOD;
    }
    //用作min
    for (int i = 1; i <= n - k + 1; ++i) {
        ans = (ans - a[i] * C(n - i, k - 1) % MOD) % MOD;
    }

    cout << (ans + MOD) % MOD << endl;
}

int32_t main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int _ = 1;
//    cin >> _;
    while(_--) solve();
    return 0;
}

你可能感兴趣的:(算法竞赛,算法,数据结构,笔记)