取球问题——解题报告

题目链接:

http://192.168.1.251/contest/9/problem/25

题目大意:

n n n个不同的小球,编号为 1 ∼ n 1\sim n 1n。求从这n个小球中取出至少m个小球的方案数。
数据范围: 3 ≤ n ≤ 1 0 9 3\leq n\leq10^9 3n109 3 ≤ m ≤ 1 0 5 3\leq m\leq10^5 3m105

题目分析:

首先分析题目的数据范围,可知时间复杂度大致与 m m m l o g n logn logn有关。
如果直接按照题目的要求得到答案,那么就是求出 C n m + . . . + C n n C_n^m+...+C_n^n Cnm+...+Cnn,而在现在的数据范围下,显然是不太合适的。因为 m ≪ n m\ll n mn,所以很显然,根据二项式定理算 2 n − ( C n 0 + . . . + C n m − 1 ) 2^n-(C_n^0+...+C_n^{m-1}) 2n(Cn0+...+Cnm1)显然更划算一些。所以我们只要在 O ( m ) O(m) O(m)的时间内求出 C n 0 + . . + C n m − 1 C_n^0+..+C_n^{m-1} Cn0+..+Cnm1的值即可。而通过观察我们可以发现一些规律。
A n m A_n^m Anm的含义是从 n n n个中有序选择 m m m个。而 A n m A_n^m Anm A n m − 1 A_n^{m-1} Anm1的关系在哪里呢?实际上 A n m = A n 1 × A n − 1 m − 1 = n × A n − 1 m − 1 A_n^m=A_n^1\times A_{n-1}^{m-1}=n\times A_{n-1}^{m-1} Anm=An1×An1m1=n×An1m1。以此类推 A n − 1 m − 1 = A n − 1 1 × A n − 2 m − 2 = ( n − 1 ) × A n − 2 m − 2 A_{n-1}^{m-1} = A_{n-1}^1 \times A_{n-2}^{m-2}=(n-1) \times A_{n-2}^{m-2} An1m1=An11×An2m2=(n1)×An2m2,后面的项都可以这样拆开来。因为显然每一项都可以展开,所以这里很明显的展现出了递归的特性。
因为题目要求求和,所以:
A n 1 + A n 2 + . . . + A n m − 1 = n × ( 1 + A n − 1 1 . . . + A n − 1 m − 2 ) = n × ( 1 + ( n − 1 ) × ( 1 + . . . + A n − 2 m − 3 ) ) A_n^1+A_n^2+...+A_n^{m-1}=n\times(1+A_{n-1}^1...+A_{n-1}^{m-2})=n\times(1+(n-1)\times(1+...+A_{n-2}^{m-3})) An1+An2+...+Anm1=n×(1+An11...+An1m2)=n×(1+(n1)×(1+...+An2m3))
转换成递归的表达形式就是
f ( p o s ) = p o s × ( 1 + d f s ( p o s − 1 ) ) f(pos) = pos \times (1 + dfs(pos-1)) f(pos)=pos×(1+dfs(pos1))
而从 A n x A_n^x Anx C n x C_n^x Cnx仅仅只需要除以 x ! x! x!即可。所以很容易转换过来:
d f s ( p o s ) = p o s × ( 1 n − p o s + 1 + d f s ( p o s − 1 ) ) dfs(pos) = pos \times(\frac{1}{n-pos+1}+dfs(pos-1)) dfs(pos)=pos×(npos+11+dfs(pos1))
取模的逆元问题在这里不做赘述了。最后时间复杂度为 O ( m + l o g n ) O(m+logn) O(m+logn)。在本题的可接受范围内。

正解程序:

#include 
#include 
#include 
#include 
#define mod 1000000007

using namespace std;
typedef long long ll;
ll n,m,fac[1000010],inv[1000010];
ll dfs(ll pos)
{
	if(pos==n-m+2)
		return pos*inv[n-pos+1]%mod;
	return pos*((inv[n-pos+1]+dfs(pos-1))%mod)%mod;
}
ll quickmi(ll a,ll p)
{
	ll ans=1;
	while(p)
	{
		if(p&1)
			ans=ans*a%mod;
		a=a*a%mod;
		p>>=1;
	}
	return ans;
}
int main()
{
	scanf("%lld%lld",&n,&m);
	fac[0]=1;
    for(ll i=1;i<=m;i++)
        fac[i]=fac[i-1]*i%mod;
   	inv[1]=1;
    for(ll i=2;i<=m;i++)
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    inv[0]=1;
    for(ll i=1;i<=m;i++)
        inv[i]=inv[i]*inv[i-1]%mod;
	ll ans=dfs(n)+1;
	printf("%lld\n",(quickmi(2,n)-ans+mod)%mod);
	
	return 0;
}

你可能感兴趣的:(#,排列组合,递归,数学,排列组合)