UOJ #275: 组合数问题 题解

Description

组合数 Cmn C n m 表示的是从 n 个物品中选出 m 个物品的方案数。举个例子,从 (1,2,3) 三个物品中选择两个物品可以有 (1,2),(1,3),(2,3)这三种选择方法。根据组合数的定义,我们可以给出计算组合数 Cmn C n m 的一般公式:

Cmn=n!m!(nm)! C n m = n ! m ! ( n − m ) !

其中 n!=1×2×⋯×n(额外的,当 n=0时, n!=1)

小葱想知道如果给定 n,m和 k,对于所有的 0≤i≤n,0≤j≤min(i,m)有多少对 (i,j) 满足 Cji C i j 是 k 的倍数。

答案对 1e9+7取模。

Input

第一行有两个整数 t,k,其中 t 代表该测试点总共有多少组测试数据。
接下来 t 行每行两个整数 n,m。

Output

t 行,每行一个整数代表所有的 0≤i≤n,0≤j≤min(i,m) 中有多少对 (i,j) 满足 Cji C i j 是 k 的倍数。

Sample Input 1

1 2
3 3

Sample Output 1

1

Explanation

在所有可能的情况中,只有 C12=2 C 2 1 = 2 是2的倍数。

Sample Input 2

2 5
4 5
6 7

Sample Output 2

0
7

Sample Input 3

3 23
23333333 23333333
233333333 233333333
2333333333 2333333333

Sample Output 3

851883128
959557926
680723120

Hint

对于 20%的测试点,1≤n,m≤100;
对于另外 15%的测试点,n≤m;
对于另外 15%的测试点, k=2;
对于另外 15%的测试点, m≤10;
对于 100%的测试点, 1≤n,m≤1e18,1≤t,k≤100,且 k是一个质数。
时间限制:1s
空间限制:512MB


非常好的题目,感觉这个lucas的应用算是常见套路了吧
题目要求 CMN0(modk) C N M ≡ 0 ( m o d k ) ,因为k是质数,所以根据Lucas定理, CM/kN/kCMmodkNmodk0(modk) C N / k M / k ∗ C N m o d k M m o d k ≡ 0 ( m o d k )
我们发现这个东西很像10进制转k进制,把这个东西用lucas不断展开之后,就相当于把N和M转成k进制数后每位对着算C,那么是k的倍数的充要条件就是N和M的k进制表示中存在一位N比M小
所以就转成一个数位dp了,按理来说转移是可以做到 O(1) O ( 1 ) 的,但非常麻烦,考虑到k比较小,就写比较蠢的转移了

#include 
using namespace std;

#define LL long long
#define LB long double
#define ull unsigned long long
#define x first
#define y second
#define pb push_back
#define pf push_front
#define mp make_pair
#define Pair pair
#define pLL pair
#define pii pair

const int INF=2e9;
const LL LINF=2e16;
const int magic=348;
const int MOD=1e9+7;
const double eps=1e-10;
const double pi=acos(-1);

inline int getint()
{
    bool f;char ch;int res;
    while (!isdigit(ch=getchar()) && ch!='-') {}
    if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return f?res:-res;
}

inline LL getLL()
{
    bool f;char ch;LL res;
    while (!isdigit(ch=getchar()) && ch!='-') {}
    if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return f?res:-res;
}

int dp[148][48];
int limi[148],limj[148],tot1,tot2;
LL n,m;int k;

inline int add(int x) {if (x>=MOD) x-=MOD;return x;}
inline int sub(int x) {if (x<0) x+=MOD;return x;}

int main ()
{
    //freopen ("a.in","r",stdin);
    //freopen ("a.out","w",stdout);
    int ca;ca=getint();k=getint();
    while (ca--)
    {
        n=getLL();m=getLL();LL tmp;int i,p,q,il,ir,jl,jr,iLim,jLim,equal,smaller,Mask,toMask;
        m=min(n,m);tot1=tot2=0;
        tmp=n;while (tmp) limi[++tot1]=tmp%k,tmp/=k;
        tmp=m;while (tmp) limj[++tot2]=tmp%k,tmp/=k;
        for (i=tot2+1;i<=tot1;i++) limj[i]=0;
        reverse(limi+1,limi+tot1+1);reverse(limj+1,limj+tot1+1);
        memset(dp,0,sizeof(dp));
        dp[0][15]=1;
        for (i=0;i<=tot1-1;i++)
            for (Mask=0;Mask<=15;Mask++)
                if (dp[i][Mask])
                    for (toMask=0;toMask<=15;toMask++)
                    {
                        if (!(toMask&8) && (toMask&4 || (!(toMask&4) && (Mask&4)))) continue;
                        if ((toMask&Mask)==toMask)
                        {
                            int coef=0;
                            if (toMask&1) il=ir=limi[i+1];
                            else if (Mask&1) il=0,ir=limi[i+1]-1; else il=0,ir=k-1;
                            if (toMask&2) jl=jr=limj[i+1];
                            else if (Mask&2) jl=0,jr=limj[i+1]-1; else jl=0,jr=k-1;
                            for (p=il;p<=ir;p++)
                                for (q=jl;q<=jr;q++)
                                {
                                    if ((toMask&4) && p!=q) continue;
                                    if (!(toMask&4) && (Mask&4) && p<=q) continue;
                                    if ((toMask&8) && pcontinue;
                                    if (!(toMask&8) && (Mask&8) && p>=q) continue;
                                    coef++;
                                }
                            dp[i+1][toMask]=add(dp[i+1][toMask]+(1ll*coef*dp[i][Mask])%MOD);
                        }
                    }
        int ans=0;
        for (Mask=0;Mask<=15;Mask++)
            if (!(Mask&8)) ans=add(ans+dp[tot1][Mask]);
        printf("%d\n",ans);
    }
    return 0;
}

你可能感兴趣的:(数位dp,Lucas定理)