题目大意:给出一棵树,求异或和为[0..m-1]的非空连通子图的个数。
FWT+树形DP
f[i][j] 表示以i为根异或和为j的连通子树的个数(注意必须是i的子树中)
f[x][j ^ k]=f[x][j ^ k]+f[x][j]∗f[son][k] 这个转移方程的瓶颈在于 f[x][j]∗f[son][k] ,转移是 O(m2)
可以发现转移实际上就是异或卷积,可以用FWT优化。
FWT这东西第一次接触,感觉原理就算懂了也会忘,所以直接背的板子。。。
时间复杂度 O(nmlogm)
#include
#include
#include
#include
#include
#define N 2003
#define p 1000000007
#define LL long long
using namespace std;
const LL ret=(p+1)/2;
LL dp[N][N],ans[N],tmp[N];
int tot,n,m,nxt[N],v[N],point[N],val[N];
void add(int x,int y)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void FWT(LL *a,int n)
{
for (int i=1;i1)
for (int p1=i<<1,j=0;jfor (int k=0;kx=a[j+k]; LL y=a[j+k+i];
a[j+k]=(x+y)%p;
a[j+k+i]=(x-y+p)%p;
}
}
void UFWT(LL *a,int n)
{
for (int i=1;i1)
for (int p1=i<<1,j=0;jfor (int k=0;kx=a[j+k]; LL y=a[j+k+i];
a[j+k]=(x+y)%p*ret%p;
a[j+k+i]=((x-y)*ret%p+p)%p;
}
}
void solve(LL *a,LL *b,int n)
{
FWT(a,n); FWT(b,n);
for (int i=0;i*b[i]%p;
UFWT(a,n);
}
void dfs(int x,int fa)
{
dp[x][val[x]]=1;
for (int i=point[x];i;i=nxt[i]) {
if (v[i]==fa) continue;
dfs(v[i],x);
for (int j=0;j<m;j++) tmp[j]=dp[x][j];
solve(dp[x],dp[v[i]],m);
for (int j=0;j<m;j++)
dp[x][j]=(tmp[j]+dp[x][j])%p;
}
for (int i=0;i<m;i++)
ans[i]=(ans[i]+dp[x][i])%p;
}
int main()
{
freopen("a.in","r",stdin);
// freopen("my.out","w",stdout);
int T; scanf("%d",&T);
while (T--) {
tot=0;
memset(point,0,sizeof(point));
memset(dp,0,sizeof(dp));
memset(ans,0,sizeof(ans));
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",&val[i]);
for (int i=1;iint x,y; scanf("%d%d",&x,&y);
add(x,y);
}
dfs(1,0);
for (int i=0;i<m-1;i++) printf("%I64d ",ans[i]);
printf("%I64d\n",ans[m-1]);
}
}
点分治+树形DP
对于每个节点,点分治到他的时候直接做树形依赖即可。
令 f[son][j ^ val[son]]=f[x][j] ,带入计算,最后用计算后f[x]数组更新总答案
这样每个点都之后被计算logn次,时间复杂度 O(nmlogn)
#include
#include
#include
#include
#include
#define N 2003
#define p 1000000007
#define LL long long
using namespace std;
LL dp[N][N],ans[N];
int tot,n,m,nxt[N],v[N],point[N],val[N],f[N],size[N],root,vis[N],sum;
void add(int x,int y)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void getroot(int x,int fa)
{
f[x]=0; size[x]=1;
for (int i=point[x];i;i=nxt[i]) {
if (v[i]==fa||vis[v[i]]) continue;
getroot(v[i],x);
size[x]+=size[v[i]];
f[x]=max(f[x],size[v[i]]);
}
f[x]=max(f[x],sum-size[x]);
if (f[x]void calc(int x,int fa)
{
for (int i=point[x];i;i=nxt[i]) {
if (v[i]==fa||vis[v[i]]) continue;
for (int j=0;jfor (int j=0;jfor (int j=0;j0;
}
}
void solve(int x)
{
vis[x]=1;
dp[x][val[x]]=1; calc(x,0);
for (int i=0;ifor (int i=0;i0;
for (int i=point[x];i;i=nxt[i]) {
if (vis[v[i]]) continue;
sum=size[v[i]]; root=0;
getroot(v[i],x);
solve(root);
}
}
int main()
{
freopen("a.in","r",stdin);
// freopen("my.out","w",stdout);
int T; scanf("%d",&T);
while (T--) {
tot=0;
memset(point,0,sizeof(point));
memset(dp,0,sizeof(dp));
memset(ans,0,sizeof(ans));
memset(vis,0,sizeof(vis));
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",&val[i]);
for (int i=1;iint x,y; scanf("%d%d",&x,&y);
add(x,y);
}
f[0]=p; root=0; sum=n;
getroot(1,0);
solve(root);
for (int i=0;i1;i++) printf("%I64d ",ans[i]);
printf("%I64d\n",ans[m-1]);
}
}