给定方程
X1+X 2+…+Xn=m
我们对第 1… n1 个变量 进行一些限制 :
X1≤A1
X2≤A2
…
Xn1 ≤An1
我们对第 n1+1… n1+1… n1+ n2 个变量 进行一些限制 :
X_(n1+1)≥A_(n1+1)
X_(n1+2)≥A_(n1+2)
…
X_(n1+n2) ≥A_(n1+n2)
求:在满足这些限制的前提下, 该方程正整数解的个数。
答案可能很大,请输出对 p取模 后的答案 ,也即 答案除以 p的余数。
输入含有多组数据 ,第一行两个 正整数 T,p。T表示这个测试点内的 数据 组数 ,p的含义见题目描述 。
对于每组数据,第一行 四个非负 整数 n,n1 ,n2 ,m。
第二行 n1+ n2 个正整数,表示 A1…n1+n2 。请注意,如果n1+n2等于0 ,那么 这一行会成为一个空行。
共 T行,每行一个正整数 表示 取模后的答案。
3 10007
3 1 1 6
3 3
3 0 0 5
3 1 1 3
3 3
3
6
0
对于第一组数 据, 三组解为 (1,3,2 ),(1,4,1) (1,4,1),(2,3,1) 。
对于第二组 数据 ,六组解为 (1,1,3) ,(1,2,2),(1,3,1) ,(2,1,2) ,(2,2,1),(3,1,1) 。
瞎推发现只会弄p=10007的情况,而且还发现n1可以直接容斥
然后就暴力套lucas定理+exgcd。
然后就不知道为什么炸了。10分。
这题是真的秒♂啊。
首先,我们发现那个限制条件奇小无比。
首先大于的条件比较好解决,直接把m减去 ∑ A 2 i \sum A2_i ∑A2i即可
但是那个小于的条件似乎比较棘手。但是这个玩意只有8,所以可以考虑直接暴力容斥。
容斥完后我们现在要求的东西只剩下一个 C m − 1 n − 1 C_{m-1}^{n-1} Cm−1n−1,然后这个玩意由于模数问题和阶乘问题,不能直接暴力求。
所以要引入一个神奇的东东叫做拓展lucas。(exlucas定理)
不会的戳这里
然后就很轻松地解决了。
当然,我的程序似乎打的比较丑,所以加了个预处理优化了下。
#include
#include
#include
#include
using namespace std;
int n1,n2,x[10],y[10],d[10],now,zs[100000],p[100000],q[100000],count;
int n,m,T;
long long ans,sum1,sum2,jc[1000010],an1,an2,an3,answer,mo,jcc[2];
bool bz[1000000];
inline int read() {
int x = 0, f = 0; char c = getchar();
while (c < '0' || c > '9') f = (c == '-') ? 1 : f, c = getchar();
while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return f ? -x : x;
}
__attribute__((optimize("-O3")))
long long qsm(long long a,long long b)
{
long long t=1;
long long y=a;
while (b>0)
{
if ((b&1)==1) t=t*y%mo;
y=y*y%mo;
b/=2;
}
return t;
}
__attribute__((optimize("-O3")))
long long f(long long n,long long mo,long long p)
{
if (n==0) return 1;
long long an1=1;
long long an2=1;
long long je1=n/p;
long long je2=n/mo;
if (mo==10007)
an1=jc[mo-1];
else
if (mo==10201)
an1=jcc[0];
else
for (long long i=1;i<=mo;i++)
{
if (i%p!=0)
{
an1=an1*i%mo;
count++;
}
}
an1=qsm(an1,je2)%mo;
for (long long i=mo*je2+1;i<=n;i++)
{
if (i%p!=0)
{
an2=an2*i%mo;
}
}
long long an3=f(je1,mo,p);
long long ans=an1*an2%mo*an3%mo;
return ans;
}
__attribute__((optimize("-O3")))
long long g(long long n,long long mo,long long p)
{
if (n<p) return 0;
return g(n/p,mo,p)+n/p;
}
__attribute__((optimize("-O3")))
void exgcd(long long &x1,long long &y1,long long aa,long long bb){
if (bb==0)
{
x1=1;
y1=0;
return;
}
long long x0,y0;
exgcd(x0,y0,bb,aa%bb);
x1=y0; y1=x0-aa/bb*y0;
if (x1<0)
{
x1+=bb;
y1-=aa;
}
if (x1>bb)
{
x1-=bb;
y1+=aa;
}
}
__attribute__((optimize("-O3")))
long long ny(long long x,long long y)
{
long long p=0;
long long q=0;
exgcd(p,q,x,y);
return p;
}
__attribute__((optimize("-O3")))
long long C(long long n,long long m,long long mo,long long p)
{
long long jl1=ny(f(n,mo,p),mo);
long long jl2=ny(f(m-n,mo,p),mo);
long long ans=f(m,mo,p)*jl1%mo*jl2%mo;
long long jl3=qsm(p,g(m,mo,p)-g(n,mo,p)-g(m-n,mo,p));
ans=ans*jl3%mo;
return ans;
}
__attribute__((optimize("-O3")))
long long excrt(long long n,long long m,long long mo)
{
long long kk=mo;
now=0;
for (int i=1;i<=zs[0];i++)
{
if (kk%zs[i]==0)
{
now++;
p[now]=zs[i];
q[now]=0;
}
while (kk%zs[i]==0) q[now]++,kk=kk/zs[i];
}
long long ans=0;
for (register int i=1;i<=now;i++)
{
long long op=1;
for (register int j=1;j<=q[i];j++) op=op*p[i];
ans=(ans+C(n,m,op,p[i])*ny(mo/op,op)%mo*(mo/op)%mo)%mo;
}
return ans;
}
__attribute__((optimize("-O3")))
void dfs(int xx,int p)
{
if (xx>n1)
{
int op=0;
int gs=0;
for (register int i=1;i<=n1;i++)
{
gs+=d[i];
if (d[i]==1)
{
op+=x[i];
}
}
if (p-op>=n)
{
int oq=excrt(n-1,p-op-1,mo);
if (gs%2==1) ans=(ans-oq+mo)%mo;
else ans=(ans+oq)%mo;
}
}
else
{
d[xx]=1;
dfs(xx+1,p);
d[xx]=0;
dfs(xx+1,p);
}
}
__attribute__((optimize("-O3")))
void rc(int p)
{
memset(d,0,sizeof(d));
dfs(1,p);
}
__attribute__((optimize("-O3")))
int main()
{
freopen("data.in", "r", stdin);
// freopen(".out", "w", stdout);
for (int i=2;i<=10007;i++)
{
if (bz[i]==0)
{
zs[0]++;
zs[zs[0]]=i;
bz[i]=1;
for (int j=1;j<=1000000/i;j++)
{
bz[i*j]=1;
}
}
}
scanf("%d%d",&T,&mo);
jc[0]=1;
for (long long i=1;i<=10007;i++)
{
jc[i]=jc[i-1]*i%10007;
}
jcc[0]=1;
for (long long i=1;i<=10200;i++)
{
if (i%101!=0)
{
jcc[0]=jcc[0]*i%10201;
}
}
while (T>0)
{
T--;
scanf("%d%d%d%d",&n,&n1,&n2,&m);
sum1=0;
for (register int i=1;i<=n1;i++)
{
x[i]=read();
sum1+=x[i];
}
sum2=0;
for (register int i=1;i<=n2;i++)
{
y[i]=read();
sum2+=y[i];
}
m=m-sum2+n2;
ans=0;
rc(m);
printf("%lld\n",ans);
}
// printf("%d\n",count);
return 0;
}