jzoj3214. 【SDOI2013】方程

Description

给定方程

X1+X 2+…+Xn=m

我们对第 1… n1 个变量 进行一些限制 :

X1≤A1

X2≤A2

Xn1 ≤An1

我们对第 n1+1… n1+1… n1+ n2 个变量 进行一些限制 :

X_(n1+1)≥A_(n1+1)

X_(n1+2)≥A_(n1+2)

X_(n1+n2) ≥A_(n1+n2)

求:在满足这些限制的前提下, 该方程正整数解的个数。

答案可能很大,请输出对 p取模 后的答案 ,也即 答案除以 p的余数。

Input

输入含有多组数据 ,第一行两个 正整数 T,p。T表示这个测试点内的 数据 组数 ,p的含义见题目描述 。

对于每组数据,第一行 四个非负 整数 n,n1 ,n2 ,m。

第二行 n1+ n2 个正整数,表示 A1…n1+n2 。请注意,如果n1+n2等于0 ,那么 这一行会成为一个空行。

Output

共 T行,每行一个正整数 表示 取模后的答案。

Sample Input

3 10007
3 1 1 6
3 3
3 0 0 5

3 1 1 3
3 3

Sample Output

3
6
0

Data Constraint

jzoj3214. 【SDOI2013】方程_第1张图片

Hint

对于第一组数 据, 三组解为 (1,3,2 ),(1,4,1) (1,4,1),(2,3,1) 。

对于第二组 数据 ,六组解为 (1,1,3) ,(1,2,2),(1,3,1) ,(2,1,2) ,(2,2,1),(3,1,1) 。

赛时

瞎推发现只会弄p=10007的情况,而且还发现n1可以直接容斥
然后就暴力套lucas定理+exgcd。
然后就不知道为什么炸了。10分。

题解

这题是真的秒♂啊。
首先,我们发现那个限制条件奇小无比。
首先大于的条件比较好解决,直接把m减去 ∑ A 2 i \sum A2_i A2i即可
但是那个小于的条件似乎比较棘手。但是这个玩意只有8,所以可以考虑直接暴力容斥。
容斥完后我们现在要求的东西只剩下一个 C m − 1 n − 1 C_{m-1}^{n-1} Cm1n1,然后这个玩意由于模数问题和阶乘问题,不能直接暴力求。
所以要引入一个神奇的东东叫做拓展lucas。(exlucas定理)
不会的戳这里
然后就很轻松地解决了。

当然,我的程序似乎打的比较丑,所以加了个预处理优化了下。

代码

#include 
#include 
#include 
#include 
using namespace std;

int n1,n2,x[10],y[10],d[10],now,zs[100000],p[100000],q[100000],count;
int n,m,T;
long long ans,sum1,sum2,jc[1000010],an1,an2,an3,answer,mo,jcc[2];
bool bz[1000000];

inline int read() {
	int x = 0, f = 0; char c = getchar();
	while (c < '0' || c > '9') f = (c == '-') ? 1 : f, c = getchar();
	while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
	return f ? -x : x;
}

__attribute__((optimize("-O3")))
long long qsm(long long a,long long b)
{
	long long t=1;
	long long y=a;
	while (b>0)
	{
		if ((b&1)==1) t=t*y%mo;
		y=y*y%mo;
		b/=2;
	}
	return t;
}

__attribute__((optimize("-O3")))
long long f(long long n,long long mo,long long p)
{
	if (n==0) return 1;
	long long an1=1;
	long long an2=1;
	long long je1=n/p;
	long long je2=n/mo;
	if (mo==10007)
	an1=jc[mo-1];
	else
	if (mo==10201)
	an1=jcc[0];
	else
	for (long long i=1;i<=mo;i++)
	{
		if (i%p!=0)
		{
			an1=an1*i%mo;
			count++;
			
		}
	}
	an1=qsm(an1,je2)%mo;
	for (long long i=mo*je2+1;i<=n;i++)
	{
		if (i%p!=0)
		{
			an2=an2*i%mo;
		}
	}
	long long an3=f(je1,mo,p);
	long long ans=an1*an2%mo*an3%mo;
	return ans;
}

__attribute__((optimize("-O3")))
long long g(long long n,long long mo,long long p)
{
	if (n<p) return 0;
	return g(n/p,mo,p)+n/p;
}

__attribute__((optimize("-O3")))
void exgcd(long long &x1,long long &y1,long long aa,long long bb){  
	if (bb==0)
	{
		x1=1; 
		y1=0; 
		return;
	} 
	long long x0,y0;
	exgcd(x0,y0,bb,aa%bb);       
	x1=y0; y1=x0-aa/bb*y0;   
	if (x1<0) 
	{
		x1+=bb;
		y1-=aa; 
	}
	if (x1>bb) 
	{
		x1-=bb;
		y1+=aa;  
	}
}

__attribute__((optimize("-O3")))
long long ny(long long x,long long y)
{
	long long p=0;
	long long q=0;
	exgcd(p,q,x,y);
	return p;
}

__attribute__((optimize("-O3")))
long long C(long long n,long long m,long long mo,long long p)
{
	long long jl1=ny(f(n,mo,p),mo);
	long long jl2=ny(f(m-n,mo,p),mo);
	long long ans=f(m,mo,p)*jl1%mo*jl2%mo;
	long long jl3=qsm(p,g(m,mo,p)-g(n,mo,p)-g(m-n,mo,p));
	ans=ans*jl3%mo;
	return ans;
}

__attribute__((optimize("-O3")))
long long excrt(long long n,long long m,long long mo)
{
	long long kk=mo;
	now=0;
	for (int i=1;i<=zs[0];i++)
	{
		if (kk%zs[i]==0)
		{
			now++;
			p[now]=zs[i];
			q[now]=0;
		}
		while (kk%zs[i]==0) q[now]++,kk=kk/zs[i];
	}
	long long ans=0;
	for (register int i=1;i<=now;i++)
	{
		long long op=1;
		for (register int j=1;j<=q[i];j++) op=op*p[i];
		ans=(ans+C(n,m,op,p[i])*ny(mo/op,op)%mo*(mo/op)%mo)%mo;
	}
	return ans;
}

__attribute__((optimize("-O3")))
void dfs(int xx,int p)
{
	if (xx>n1)
	{
		int op=0;
		int gs=0;
		for (register int i=1;i<=n1;i++)
		{
			gs+=d[i];
			if (d[i]==1)
			{
				op+=x[i];
			}
		}
		if (p-op>=n)
		{
			int oq=excrt(n-1,p-op-1,mo);
			if (gs%2==1) ans=(ans-oq+mo)%mo;
			else ans=(ans+oq)%mo;
		}
	}
	else
	{
		d[xx]=1;
		dfs(xx+1,p);
		d[xx]=0;
		dfs(xx+1,p);
	}
}

__attribute__((optimize("-O3")))
void rc(int p)
{
	memset(d,0,sizeof(d));
	dfs(1,p);
}

__attribute__((optimize("-O3")))
int main()
{
	freopen("data.in", "r", stdin);
//	freopen(".out", "w", stdout);
	for (int i=2;i<=10007;i++)
	{
		if (bz[i]==0)
		{
			zs[0]++;
			zs[zs[0]]=i;
			bz[i]=1;
			for (int j=1;j<=1000000/i;j++)
			{
				bz[i*j]=1;
			}
		}
	}
	scanf("%d%d",&T,&mo);
	jc[0]=1;
	for (long long i=1;i<=10007;i++)
	{
		jc[i]=jc[i-1]*i%10007;
	}
	jcc[0]=1;
	for (long long i=1;i<=10200;i++)
	{
		if (i%101!=0)
		{
			jcc[0]=jcc[0]*i%10201;
		}
	}
	while (T>0)
	{
		T--;
		scanf("%d%d%d%d",&n,&n1,&n2,&m);
		sum1=0;
		for (register int i=1;i<=n1;i++)
		{
			x[i]=read();
			sum1+=x[i];
		}
		sum2=0;
		for (register int i=1;i<=n2;i++)
		{
			y[i]=read();
			sum2+=y[i];
		}
		m=m-sum2+n2;
		ans=0;
		rc(m);
		printf("%lld\n",ans);
	}
//	printf("%d\n",count);
	return 0;
}

你可能感兴趣的:(数学杂论)