【UOJ348】【WC2018】州区划分 状压DP FWT

题目大意

  给定⼀个 n n 个点的⽆向图,对于⼀种 n n 个点的划分 {S1,S2,,Sk} { S 1 , S 2 , … , S k } ,定义它是合法的,当且仅当每个点都在其中的一个集合中且对于任何的 i[1,k] i ∈ [ 1 , k ] ,点集 Si S i ⾮空,且导出⼦图不存在欧拉回路。

  给定数组 wi w i ,求对于所有合法的划分 {s1,s2,,sk} { s 1 , s 2 , … , s k } ,下面的式子之和

(i=1kxSiwxij=1xSjwx)p ( ∏ i = 1 k ∑ x ∈ S i w x ∑ j = 1 i ∑ x ∈ S j w x ) p

   n21 n ≤ 21

题解

  先用 O(n22n) O ( n 2 2 n ) 判断每个点集是否合法,并计算 g(S)=(xSwx)p g ( S ) = ( ∑ x ∈ S w x ) p 。如果 S S 不合法。那么 g(S)=0 g ( S ) = 0

  很容易想到一个 O(3n) O ( 3 n ) 的做法。

  设 f(S) f ( S ) 为当前选择的集合为 S S 的答案

f(S)=TSg(T)f(ST)g(S) f ( S ) = ∑ T ⊆ S g ( T ) f ( S ∖ T ) g ( S )

  可以发现这是一个子集卷积。

  一种可行的做法是把子集卷积转化为子集或卷积。

  定义 f˜ f ~ f f 的集合占位幂级数,当且仅当对于所有的 S S 满足 f˜(S) f ~ ( S ) 是一个 |S| | S | 次多项式,且 [x|S|]f˜(S)=f(S) [ x | S | ] f ~ ( S ) = f ( S )

  可以发现,若 p(S)=f˜(S)g˜(S) p ( S ) = f ~ ( S ) ⋅ g ~ ( S ) ,则 p p f×g f × g 的占位多项式。

  所以我们可以在 O(n22n) O ( n 2 2 n ) 内计算子集卷积了。

  回到这道题,我们假设每个城市可以出现在多个州里,记 hi,S h i , S 为每个州的城市个数之和为 i i ,每个州的城市的并集为 S S 的方案数。那么 F(S)=ni=1hi,Sxi F ( S ) = ∑ i = 1 n h i , S x i 就是 f(S) f ( S ) 的集合占位幂级数。

  所以 f(S)=h|S|,S f ( S ) = h | S | , S ,状态转移方程是

hi,S=j=1i|T|=jA[A|T=S]hij,Ag(T)g(S) h i , S = ∑ j = 1 i ∑ | T | = j ∑ A [ A | T = S ] h i − j , A g ( T ) g ( S )

  记 G(i)=|S|=ig(S)xS G ( i ) = ∑ | S | = i g ( S ) x S

  我们先枚举 i i ,再枚举 j j ,然后计算 hi=hij×G(j) h i = h i − j × G ( j )

  这里我们可以全程用莫比乌斯变换后的值,卷积一次就是 O(2n) O ( 2 n )

  还有,我们这里要除以 g(S) g ( S ) ,可以在做完一层( hi h i )之后变换回去,除以 g(S) g ( S ) ,再变换回来。

  这样时间复杂度就是 O(n22n) O ( n 2 2 n ) 的了。

代码

#include
#include
#include
#include
#include
#include
#include
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
void open(const char *s)
{
#ifndef ONLINE_JUDGE
    char str[100];
    sprintf(str,"%s.in",s);
    freopen(str,"r",stdin);
    sprintf(str,"%s.out",s);
    freopen(str,"w",stdout);
#endif
}
const ll p=998244353;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
int lx[100010];
int ly[100010];
int n,m,o;
void fwt(int *a)
{
    int i,j;
    for(j=1;j<1<1)
        for(i=0;i<1<if(i&j)
                a[i]=(a[i]+a[i^j])%p;
}
void ifwt(int *a)
{
    int i,j;
    for(j=1;j<1<1)
        for(i=0;i<1<if(i&j)
                a[i]=(a[i]-a[i^j])%p;
}
int f[22][1<<21];
int g[22][1<<21];
ll inv[100010];
ll fw[1<<21];
ll fwi[1<<21];
int w[30];
int fa[100010];
int num[1<<21];
int find(int x)
{
    return fa[x]==x?x:fa[x]=find(fa[x]);
}
int c[1<<21];
int d[100010];
int main()
{
    open("walk");
    int i,j,k;
    inv[0]=inv[1]=1;
    for(i=2;i<=10000;i++)
        inv[i]=-p/i*inv[p%i]%p;
    scanf("%d%d%d",&n,&m,&o);
    for(i=1;i<=m;i++)
        scanf("%d%d",&lx[i],&ly[i]);
    for(i=1;i<=n;i++)
        scanf("%d",&w[i]);
    int all=(1<1;
    for(i=1;i<=all;i++)
    {
        for(j=1;j<=n;j++)
        {
            d[j]=0;
            fa[j]=j;
        }
        for(j=1;j<=m;j++)
            if(((i>>(lx[j]-1))&1)&&((i>>(ly[j]-1))&1))
            {
                d[lx[j]]^=1;
                d[ly[j]]^=1;
                int fx=find(lx[j]);
                int fy=find(ly[j]);
                if(fx!=fy)
                    fa[fx]=fy;
            }
        fw[i]=0;
        for(j=1;j<=n;j++)
            if((i>>(j-1))&1)
                fw[i]+=w[j];
        fwi[i]=inv[fw[i]];
        fw[i]=fp(fw[i],o);
        fwi[i]=fp(fwi[i],o);
        int cnt=0;
        c[i]=0;
        for(j=1;j<=n;j++)
            if((i>>(j-1))&1)
            {
                num[i]++;
                if(fa[j]==j)
                {
                    cnt++;
                    if(cnt>=2)
                        c[i]=1;
                }
                if(d[j])
                    c[i]=1;
            }
        fw[i]*=c[i];
        g[num[i]][i]=fw[i];
    }
    f[0][0]=1;
    for(i=1;i<=n;i++)
        fwt(g[i]);
    fwt(f[0]);
    for(i=1;i<=n;i++)
    {
        for(j=1;j<=i;j++)
            for(k=0;k<=all;k++)
                f[i][k]=(f[i][k]+(ll)f[i-j][k]*g[j][k])%p;
        ifwt(f[i]);
        for(j=0;j<=all;j++)
            f[i][j]=f[i][j]*fwi[j]%p;
        fwt(f[i]);
    }
    ifwt(f[n]);
    ll ans=f[n][all];
    ans=(ans+p)%p;
    printf("%lld\n",ans);
    return 0;
}

你可能感兴趣的:(FWT,DP--状压DP,DP)