题目大意:统计一棵无根树的DFS序中字典序小于B的方案数。
我们先考虑这样一个问题:一棵有根树,从根开始的DFS序有几种。
显然一棵树的DFS序可以视为一个排列。
我们发现一颗有根树的DFS序数与所有节点的儿子数有关。
如下图,du数组记录儿子数量。
DFS是深度优先,必须遍历完当前子树才能离开当前子树,所以在DFS序中,每颗子树是连续的。
设f[x]表示遍历以x为根的子树的方案数,du为当前节点的儿子数,S为当前节点的儿子集合,则有:
$f[x]=A_{du}^{du}* \prod_{y \in S}dp[y]$
进而发现,除了根结点,所有节点的儿子数量为其度数减一,设每个点的度数为d,则以x为根时DFS序的个数个g[x]为:
$g[x]=A_{d[x]}^{d[x]}* \prod _{i!=x} A_{d[i]-1}^{d[i]-1}$
打表求阶乘及逆元,$O(N)$算出g[1],然后发现,换根时改变的仅有两个点,可以$O(1)$求出以任意点为根的方案数。
然后考虑有限制的情况。
我们可以用类似数位DP的思想进行统计。
DFS遍历整棵树,记录一个时间戳dfn,表示当前已匹配到了B数组中的第dfn个。
每次在当前节点的子结点中搜索B数组中的下一个元素,如果没有搜到,并且当前节点还有未遍历过的点,则匹配失败,立即结束搜索,若当前节点已全部遍历,跳回到父节点继续搜索。
整个搜索的过程可以视为删除其中的一些边,然后在删除的树上统计方案数。
用类似上文的思想,当一条边被删去时,父节点的儿子数减一,子节点的儿子数不变。
用全局变量res记录方案数,设删除的边指向的父节点为x,其剩下的儿子个数为du[x],则:
$res=res \div du[x]=res*(A_{du[x]}^{du[x]})^{-1}*A_{du[x]-1}^{du[x]-1}$
用数位DP的思想,我们要统计所有字典序小于B的方案,所以合法的DFS序只有当前面有一位比B小时,后面才可以比B大。
设DFS序构成一个数列A,则若A中有一位小于B,就会对下面解除限制,就求出当前节点中小于B[dfn+1]的子结点数量,乘上res累加到ans里。
由于要查询一些数中小于某个数的数的个数,还要支持动态修改,也就是删边所以要用数据结构维护。
可用的数据结构有树状数组维护vector,动态开点权值线段树,c++扩展库pb_ds,或者手打平衡树(其实用不着平衡树,普通BST就能解决)。
由于一个节点可以被多次扩展,所以要加个while(1)之类的。
若搜不到结果,标记一个全局变量op,一路回溯至结束,不然受while(1)影响,DFS可能继续扩展。
最重要的一点,别忘了取模。
时间复杂度$O(N*log_2N)$
Code:
1 #include2 #include 3 #include 4 #include 5 #define LL long long 6 #define rint register int 7 using namespace std; 8 const int N=300010; 9 const LL mod=1e9+7; 10 const LL modd=2333333; 11 int n,m=0,dfn=0,op=0; 12 LL ans=0,res=1,now; 13 int fi[N],b[N],du[N],rt[N]; 14 LL dp[N],jc[N],inv[N]; 15 struct edge{ 16 int v,ne; 17 }e[N<<1]; 18 struct node{ 19 int lc,rc,val,si,w; 20 }; 21 struct BST{ 22 int cnt; 23 node t[N]; 24 void update(int p) 25 { 26 t[p].si=t[p].w; 27 if(t[p].lc!=0) t[p].si+=t[t[p].lc].si; 28 if(t[p].rc!=0) t[p].si+=t[t[p].rc].si; 29 } 30 void insert(int &p,int val) 31 { 32 if(p==0){ 33 p=++cnt;t[p].val=val; 34 t[p].w=t[p].si=1; 35 t[p].lc=t[p].rc=0; 36 return; 37 } 38 if(t[p].val>val) insert(t[p].lc,val); 39 else insert(t[p].rc,val); 40 update(p); 41 } 42 void remove(int p,int val) 43 { 44 if(t[p].val==val){ 45 t[p].w=0; 46 update(p); 47 return; 48 } 49 if(t[p].val>val) remove(t[p].lc,val); 50 else remove(t[p].rc,val); 51 update(p); 52 } 53 int find(int p,int val){ 54 if(p==0) return 0; 55 else if(t[p].val==val) return t[t[p].lc].si; 56 else if(t[p].val>val) return find(t[p].lc,val); 57 else return t[t[p].lc].si+t[p].w+find(t[p].rc,val); 58 } 59 }s; 60 struct hash_map{ 61 int cnt; 62 int he[modd],ne[N],w[N]; 63 LL key[N]; 64 int find(LL x) 65 { 66 LL pos=x%modd; 67 for(int i=he[pos];i!=0;i=ne[i]){ 68 if(key[i]==x) return w[i]; 69 } 70 return 0; 71 } 72 void insert(LL x) 73 { 74 LL pos=x%modd; 75 key[++cnt]=x;w[cnt]=1; 76 ne[cnt]=he[pos];he[pos]=cnt; 77 } 78 void change(LL x) 79 { 80 LL pos=x%modd; 81 for(int i=he[pos];i!=0;i=ne[i]){ 82 if(key[i]==x){ 83 w[i]=2;return; 84 } 85 } 86 } 87 }h; 88 inline void add(int x,int y) 89 { 90 e[++m].v=y; 91 e[m].ne=fi[x];fi[x]=m; 92 } 93 inline int read() 94 { 95 int s=0;char c=getchar(); 96 while(c<'0'||c>'9') c=getchar(); 97 while(c>='0'&&c<='9'){ 98 s=(s<<3)+(s<<1)+c-'0'; 99 c=getchar(); 100 } 101 return s; 102 } 103 inline LL qpow(LL x,LL y) 104 { 105 LL ans=1; 106 while(y>0){ 107 if((y&1)==1) ans=ans*x%mod; 108 x=x*x%mod; 109 y>>=1; 110 } 111 return ans; 112 } 113 void dfs1(int x,int p) 114 { 115 int cn=0;dp[x]=1; 116 for(rint i=fi[x];i!=0;i=e[i].ne){ 117 int y=e[i].v; 118 if(y==p) continue; 119 h.insert((LL)N*x+y); 120 s.insert(rt[x],y); 121 dfs1(y,x); 122 cn++;dp[x]=dp[x]*dp[y]%mod; 123 } 124 dp[x]=dp[x]*jc[cn]%mod; 125 } 126 void dfs2(int x,int p) 127 { 128 LL p1=0,p2=0; 129 for(rint i=fi[x];i!=0;i=e[i].ne){ 130 int y=e[i].v; 131 if(y==p) continue; 132 p1++; 133 } 134 while(1){ 135 if(p1==0) return; 136 p2=s.find(rt[x],b[dfn+1]); 137 now=now*inv[p1]%mod*jc[p1-1]%mod; 138 ans=(ans+now*p2%mod)%mod; 139 if(h.find((LL)N*x+b[dfn+1])!=0){ 140 dfn++;p1--; 141 h.change((LL)N*x+b[dfn]); 142 s.remove(rt[x],b[dfn]); 143 dfs2(b[dfn],x); 144 if(op==1) return; 145 } 146 else{ 147 op=1;return; 148 } 149 } 150 } 151 int main() 152 { 153 n=read();h.cnt=s.cnt=s.t[0].si=0; 154 for(rint i=1;i<=n;i++) 155 b[i]=read(); 156 jc[1]=inv[1]=jc[0]=inv[0]=1; 157 for(rint i=2;i<=n;i++){ 158 int x=read(),y=read(); 159 add(x,y);add(y,x); 160 du[x]++;du[y]++; 161 jc[i]=jc[i-1]*(LL)i%mod; 162 inv[i]=qpow(jc[i],mod-2); 163 } 164 for(rint i=1;i<=n;i++){ 165 if(i==1) res=res*jc[du[i]]%mod; 166 else res=res*jc[du[i]-1]%mod; 167 } 168 if(b[1]!=1) ans=(ans+res)%mod; 169 for(rint i=2;i1];i++){ 170 res=res*inv[du[i-1]]%mod; 171 res=res*inv[du[i]-1]%mod; 172 res=res*jc[du[i-1]-1]%mod; 173 res=res*jc[du[i]]%mod; 174 ans=(ans+res)%mod; 175 } 176 dfs1(b[1],0); 177 now=dp[b[1]];dfn++; 178 dfs2(b[1],0); 179 printf("%lld\n",ans); 180 return 0; 181 }
PS:我打的hash表和BST,别的数据结构也行。BST常数小,方便打,如果数据没有特殊构造时推荐使用。