给定一棵树,每个点上有一种颜色。求包含所有颜色的路径数。
颜色数<=10
求树上路径条数,点分治跑不了了。
关键在于怎么快速统计。其实就是需要快速求出能够使得并集为全集的路径数。
这就要处理出路径颜色的子集了。
用高维前缀和搞一搞。
那么这是个什么东西呢。其实就是个状压。具体来说似乎我知道有以下几种写法。
假设我们要处理一个状态S的子集的存在性,那么可以这么写。
for(int k=0;k<=S;++k) if((S|k)==S) exist[k]=1;
或者
for(int j=S;j>=0;--j) for(int i=0;i//k为最高位数
if((1<1<
或:
for(int i=0;ifor(int j=S;j>=0;--j){
if(!((1<1<
P.S. :显然写第一种(因为只有一个状态)
一开始有多个状态时先初始化,然后写下面的两种方法中的一种方法。
但一般我们要用到的是统计一个集合的子集状态数量或超集状态数量。
这个时候这么写:
for(int k=0;k<=S;++k) if((S|k)==S) cnt[k]+=cnt[S];//适用于状态一个个加入
或着
//k为最高位数
for(int i=0;ifor(int j=S;j>=0;--j){ //适用于状态一次性加入
if(!((1<1<
也可以写成这样:
for(int j=S;j>=0;--j) for(int i=0;iif((1<1<
P.S. 如果一次次加但加入较多状态时可新用一个数组存起来,再加 到原统计数组上面(不然第二、三种方法就肯定算重了)。不过求方便就直接写第一种了,当然要保证不 TLE。
会统计子集了这题就秒了,上代码:
#include
#include
#include
#include
#include
#include
#include
#include
#define Set(a,b) memset(a,b,sizeof(a))
using namespace std;
const int N=5e4+10;
const int MAXN=(1<<10)+100;
typedef long long ll;
struct edge{
int to,next;
}a[N<<1];
int head[N];int cnt=0;
int val[N];
inline void add(int x,int y)
{
a[++cnt]=(edge){y,head[x]};head[x]=cnt;
}
int n,K;int SZ;
int rt;int size[N];
int f[N];bool vis[N];
void Find(int u,int fa)
{
size[u]=1;f[u]=0;
for(register int v,i=head[u];i;i=a[i].next){
v=a[i].to;
if(v==fa||vis[v]) continue;
Find(v,u);
size[u]+=size[v];
f[u]=max(f[u],size[v]);
}
f[u]=max(f[u],SZ-size[u]);
if(rt==-1||f[u]return;
}
ll ans=0;
int dp[MAXN];
int st[N];int top=0;
int full;
void dfs(int u,int S,int fa)
{
if(vis[u]) return;
st[++top]=S;
for(register int v,i=head[u];i;i=a[i].next)
{
v=a[i].to;
if(v==fa) continue;
dfs(v,S|(1<return;
}
inline void Sum()
{
for(register int i=1;i<=top;++i)
for(register int k=0;k<=st[i];++k)
if((st[i]|k)==st[i]) dp[k]++;//好像跑得蛮快的
}
void Div(int u)
{
if(vis[u]) return;vis[u]=1;
for(register int i=1;i<=full;++i) dp[i]=0;dp[1<1;dp[0]=1;
for(register int v,i=head[u];i;i=a[i].next){
v=a[i].to;if(vis[v]) continue;top=0;
dfs(v,(1<1<0);
for(register int j=1;j<=top;++j) ans+=dp[full^st[j]];
Sum();
}
for(register int v,i=head[u];i;i=a[i].next){
v=a[i].to;
if(vis[v]) continue;
SZ=size[v];rt=-1;
Find(v,0);
Div(rt);
}
return;
}
int main()
{
while(scanf("%d",&n)!=EOF){
scanf("%d",&K);Set(head,0);cnt=0;rt=-1;Set(vis,0);Set(dp,0);ans=0;
for(register int i=1;i<=n;++i) scanf("%d",&val[i]),--val[i];
register int x,y;full=(1<1;
for(register int i=1;iscanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
if(K==1){printf("%lld\n",1ll*n*n);continue;}
SZ=n;
Find(1,0);
Div(rt);
printf("%lld\n",ans*2);
}
}