给定一个 n n n 个点的树。你需要将它放到一个 2 × n 2 \times n 2×n 的网格里,每个格子至多放一个点,在树上相邻的点在网格中也必须相邻,并且 1 1 1 号点必须放在左上角。
求方案数对 1 0 9 + 7 10^9+7 109+7 取模的结果。
n ≤ 1 0 5 n \leq 10^5 n≤105
设 f u f_u fu表示把 u u u放在左上角的方案数。
以上诸如 k , w , y k,w,y k,w,y节点的查找。可以通过预处理每个点到它子树中第一个度数不是 2 2 2的点(包括这个点)及到这个点的距离 O ( 1 ) O(1) O(1)计算。
复杂度 O ( n ) O(n) O(n)。
#include
using namespace std;
typedef long long ll;
const int maxn = 300005, mod = 1e9 + 7;
inline int gi()
{
char c = getchar();
while (c < '0' || c > '9') c = getchar();
int sum = 0;
while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
return sum;
}
inline int add(int a, int b) {return a + b >= mod ? a + b - mod : a + b;}
int n, cnt;
int to[maxn][4], deg[maxn], seq[maxn], pos[maxn], fa[maxn], nxt[maxn], dis[maxn], f[maxn];
inline void adde(int u, int v)
{
if (deg[u] == 3 || deg[v] == 3) puts("0"), exit(0);
to[u][deg[u]++] = v; to[v][deg[v]++] = u;
}
void dfs(int u)
{
seq[++cnt] = u; pos[u] = n - cnt + 1;
for (int i = 0; i < deg[u]; ++i)
if (to[u][i] != fa[u]) fa[to[u][i]] = u, dfs(to[u][i]);
}
int get(int u, int a, int b, int &v)
{
int flg = 1;
for (int i = 0; i < deg[u]; ++i)
if (to[u][i] != a && to[u][i] != b) {
if (!flg) return 0;
v = to[u][i]; flg = 0;
}
return flg;
}
int main()
{
freopen("E.in", "r", stdin);
freopen("E.out", "w", stdout);
n = gi();
for (int i = 1; i < n; ++i) adde(gi(), gi());
if (deg[1] == 3) return puts("0"), 0;
dfs(1);
reverse(seq + 1, seq + n + 1);
for (int i = 1; i <= n; ++i) {
int u = seq[i], v;
if (deg[u] + (u == 1) == 2) get(u, fa[u], -1, v), nxt[u] = nxt[v], dis[u] = dis[v] + 1;
else nxt[u] = u, dis[u] = 1;
}
for (int t = 1; t <= n; ++t) {
int u = seq[t], v, w, x, y, k;
if (deg[u] + (u == 1) == 1) f[u] = 1;
else if (deg[u] + (u == 1) == 2) {
get(u, fa[u], -1, v); k = nxt[u];
//down
if (deg[v] == 1) ++f[u];
else if (deg[v] == 2) get(v, fa[v], -1, w), f[u] = add(f[u], f[w]);
//right
f[u] = add(f[u], f[v]);
if (deg[k] == 1) f[u] += (~dis[u] & 1) & (dis[u] > 2);
else {
for (int i = 0; w = to[k][i], i < deg[k]; ++i)
if (w != fa[k] && deg[nxt[w]] == 1) {
if (dis[w] == dis[u]) get(k, fa[k], w, x), f[u] = add(f[u], f[x]);
if (dis[w] == dis[u] - 2) get(k, fa[k], w, x), f[u] = add(f[u], f[x]);
} else if (w != fa[k] && deg[w] == 3) {
get(k, w, fa[k], v);
for (int j = 0; x = to[w][j], j < deg[w]; ++j)
if (x != fa[w] && deg[nxt[x]] == 1 && dis[x] == dis[u] - 1) {
get(w, fa[w], x, y);
bool flg = 0;
if (dis[y] < dis[v]) swap(y, v), flg = 1;
if (dis[y] == dis[v] && deg[nxt[y]] == 1 && deg[nxt[v]] == 1) ++f[u];
else if (dis[v] < dis[y] && deg[nxt[v]] == 1) f[u] = add(f[u], f[seq[pos[y] - dis[v]]]);
if (flg) swap(y, v);
}
}
}
} else {
for (int i = 0; v = to[u][i], i < deg[u]; ++i)
if (v != fa[u]) {
get(u, fa[u], v, w);
if (dis[v] + 1 == dis[w] && deg[nxt[v]] == 1 && deg[nxt[w]] == 1) ++f[u];
else if (dis[v] + 1 > dis[w] && deg[nxt[w]] == 1) f[u] = add(f[u], f[seq[pos[v] - dis[w] + 1]]);
else if (dis[v] + 1 < dis[w] && deg[nxt[v]] == 1) f[u] = add(f[u], f[seq[pos[w] - dis[v] - 1]]);
}
}
}
printf("%d\n", f[1]);
return 0;
}