有一棵\(n(1 \leq n \leq 13)\)个节点的树,节点的标号为\(1 \sim n\),它的根节点是\(1\)。
现在已知它的\(m(0 \leq m < n)\)条边,和\(q(0 \leq q \leq 100)\)个\(LCA\)的关系:\(LCA(a_i, \, b_i)=c_i\)
求满足这些要求的树的个数。
为了方便,用\(0 \sim n-1\)来表示树的节点。
用\(d(root, \, mask)\)表示以\(root\)为根,选了\(mask\)这些点,而且满足题中所有要求的树的个数。
其中\(mask\)为所选的这些点的二进制表示。
那么所求的答案为\(d(0, \, 2^n-1)\)
边界情况是,树中只有一个点时,\(f(root, \, mask) = 1\)。
状态转移方程:
\(d(root, \, mask) = \sum ( d(newRoot, \, newMask) \times d(root, mask \bigoplus newMask) )\)
\(\bigoplus\)表示异或运算
其中\(d(newRoot, \, newMask)\)是我们枚举的\(d(root, \, mask)\)状态下的一个子树。
\(newRoot\)是该子树的根,\(newMask\)是子树节点的集合。
其他子树和\(root\)合起来的状态数就是\(d(root, mask \bigoplus newMask)\)。
对于同一棵树,我们枚举了一次它的第一棵子树,接着又枚举了它第二棵子树,这样会有重复计算。
所以我们规定一个特殊点,每次只枚举这个特殊点所在子树的集合。
我们可以规定这个特殊点就是\(mask\)中除\(root\)外,编号最小(或最大)的点。
状态转移条件:
只有满足题中的限制,状态才能转移,所以我们要去掉转移时不符合要求的情况:
最后分析(YY)一下官方题解中时间复杂度\(3^n\)是怎么来的。
因为状压以后有一个\(2^n\),然后考虑每次转移:
对于\(bitcount(mask) = k\),他一共有\(2^k\)个子集,这样有\(k\)个元素的集合有\(C_n^k\)个。
所以时间复杂度为\(\sum\limits_{k=1}^{n}(C_n^k \cdot 2^k)=(2+1)^n=3^n\)
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int maxn = 15;
const int maxs = 10000;
const int maxq = 100 + 10;
int n, m, q;
LL d[maxn][maxs];
int edge[maxn][maxn];
int a[maxq], b[maxq], c[maxq];
int lowbit(int x) { return x & (-x); }
int in(int i, int S) { return ((S >> i) & 1); }
LL DP(int u, int S) {
LL& ans = d[u][S];
if(ans != -1) return ans;
ans = 0;
int St = S - (1 << u);
int t;
for(t = 0; t < n; t++) if(in(t, St)) break;
for(int _S = St; _S; _S = (_S-1)&St) if(in(t, _S)) {
bool flag = true;
for(int i = 0; i < n; i++) if(i != u) {
for(int j = 0; j < n; j++) if(j != u) {
if(edge[i][j] && (in(i, _S) ^ in(j, _S))) {
flag = false;
break;
}
}
if(!flag) break;
}
if(!flag) continue;
int v, cnt = 0;
for(int i = 0; i < n; i++) {
if(edge[u][i] && in(i, _S)) {
cnt++;
v = i;
}
}
if(cnt >= 2) continue;
for(int i = 0; i < q; i++) {
if(c[i] == u && in(a[i], _S) && in(b[i], _S)) {
flag = false; break;
}
if(in(c[i], _S) && (!in(a[i], _S) || !in(b[i], _S))) {
flag = false; break;
}
}
if(!flag) continue;
if(cnt == 1) {
ans += DP(v, _S) * DP(u, S - _S);
} else {
for(v = 0; v < n; v++) if(in(v, _S))
ans += DP(v, _S) * DP(u, S - _S);
}
}
return ans;
}
int main()
{
scanf("%d%d%d", &n, &m, &q);
for(int i = 0; i < m; i++) {
int u, v; scanf("%d%d", &u, &v);
u--; v--;
edge[u][v] = edge[v][u] = 1;
}
for(int i = 0; i < q; i++) {
scanf("%d%d%d", a + i, b + i, c + i);
a[i]--; b[i]--; c[i]--;
}
int all = (1 << n) - 1;
memset(d, -1, sizeof(d));
for(int i = 0; i < n; i++) d[i][1 << i] = 1;
printf("%lld\n", DP(0, all));
return 0;
}