[loj2340][FWT][子集卷积]州区划分

Description

传送门

题解

看懂题需要一会…
朴素的dp就可以列出一个方程
f [ m a s k ] = 1 r [ i ] p ∑ j ∣ k = m a s k f [ j ] ∗ r [ k ] p f[mask]=\frac{1}{r[i]^p}\sum_{j|k=mask} f[j]*r[k]^p f[mask]=r[i]p1jk=maskf[j]r[k]p
其中 r [ i ] r[i] r[i]表示 i i i状态下的人数
那么暴力枚举子集就是 3 n 3^n 3n的噜
然后发现这其实是一个裸的子集卷积
于是就可以直接卷一下完事了…

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#define LL long long
#define mp(x,y) make_pair(x,y)
#define pll pair
#define pii pair
using namespace std;
inline int read()
{
	int f=1,x=0;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
int stack[20];
inline void write(int x)
{
	if(x<0){putchar('-');x=-x;}
    if(!x){putchar('0');return;}
    int top=0;
    while(x)stack[++top]=x%10,x/=10;
    while(top)putchar(stack[top--]+'0');
}
inline void pr1(int x){write(x);putchar(' ');}
inline void pr2(int x){write(x);putchar('\n');}
const int MAXN=22;
const int mod=998244353;
const int MAXMASK=(1<<21);
int A[MAXN],B[MAXN];
void fwt(int *y,int len,int on)
{
	for(int i=1;i<len;i<<=1)
		for(int j=0;j<len;j+=(i<<1))
			for(int k=0;k<i;k++)
			{
				if(on==1)y[j+k+i]=(y[j+k]+y[j+k+i])%mod;
				else y[j+k+i]=(y[j+k+i]-y[j+k]+mod)%mod;
			}
}
int pow_mod(int a,int b)
{
	int ret=1;
	while(b)
	{
		if(b&1)ret=1LL*ret*a%mod;
		a=1LL*a*a%mod;b>>=1;
	}
	return ret;
}
int f[MAXN][MAXMASK],ok[MAXMASK],bin[25],w[MAXN],r[MAXMASK],ct[MAXMASK];

int C[MAXN][MAXMASK],inv[MAXMASK];
int n,m,P,mp[MAXN][MAXN];
int du[MAXN],rt[MAXN];
int findrt(int x){return rt[x]==x?rt[x]:rt[x]=findrt(rt[x]);}
void ad(int &x,int y){x+=y;if(x>=mod)x-=mod;}
int main()
{
//	freopen("walk2.in","r",stdin);
	bin[0]=1;for(int i=1;i<=21;i++)bin[i]=bin[i-1]<<1;
	n=read();m=read();P=read();
	for(int i=1;i<=m;i++)
	{
		int x=read(),y=read();
		mp[x][y]++;mp[y][x]++;
	}
	for(int i=1;i<=n;i++)w[i]=read();
	for(int i=0;i<bin[n];i++)
	{
		ok[i]=1;memset(du,0,sizeof(du));
		int cnt=0;
		for(int j=1;j<=n;j++)rt[j]=j;
		for(int j=1;j<=n;j++)if(i&bin[j-1])r[i]+=w[j],cnt++,ct[i]++;
		for(int j=1;j<=n;j++)if(i&bin[j-1])
			for(int k=j+1;k<=n;k++)if(i&bin[k-1]&&mp[j][k])
			{
				du[j]++,du[k]++;
				int u=findrt(j),v=findrt(k);
				if(u!=v)rt[u]=v,cnt--;
			}
		if(ct[i]==1)ok[i]=0;
		bool tf=false;
		for(int j=1;j<=n;j++)if(du[j]&1){tf=true;break;}
		tf|=(cnt!=1);
		ok[i]&=tf;
		r[i]=pow_mod(r[i],P);
		if(ok[i])ad(C[ct[i]][i],r[i]);
		inv[i]=pow_mod(r[i],mod-2);
	}
	for(int i=1;i<=n;i++)fwt(C[i],bin[n],1);
	f[0][0]=1;fwt(f[0],bin[n],1);
	int ans;
	for(int i=1;i<=n;i++)
	{
		for(int j=0;j<i;j++)
			for(int k=0;k<bin[n];k++)ad(f[i][k],1LL*f[j][k]*C[i-j][k]%mod);
		fwt(f[i],bin[n],-1);
		for(int j=0;j<bin[n];j++)f[i][j]=1LL*f[i][j]*inv[j]%mod;
//		if(i==n)ans=f[i][bin[n]-1];
		if(i!=n)fwt(f[i],bin[n],1);
//			for(int k=0;k
	}
	pr2(f[n][bin[n]-1]);
	return 0;
}

你可能感兴趣的:(loj,FWT)