codeforce 1332 F. Independent Set(树形 dp,树上计数问题)

codeforce 1332 F. Independent Set(树形 dp,树上计数问题)_第1张图片
codeforce 1332 F. Independent Set(树形 dp,树上计数问题)_第2张图片
codeforce 1332 F. Independent Set(树形 dp,树上计数问题)_第3张图片


题目大意:给一棵树,求这棵树的所有边生成子图的所有独立集数,边生成子图的定义为:原图拿条0条或多条边并删掉孤立点(没有和其它点连边的点)的图。


写过的为数不多的树上计数问题:

我的做法将状态分为五种: d p [ u ] [ 0 / 1 ] [ 0 / 1 ] dp[u][0/1][0/1] dp[u][0/1][0/1] r e s [ u ] res[u] res[u],分别表示:
dp[u][0/1][0/1]:以 u 为根的子树中,u 点在边生成子图中(u 点可以孤立,其它点满足边生成子图定义,允许 u 点孤立是为了合法解的转移),u 点有连子树 / 没连子树, u 点被选为独立集中的一点 / u 点不被选为独立集中的一点的方案数
res[u]:以 u 为根的子树中,u 点不在边生成子图中的方案数(显然 u 点不能选为独立集,也不能向儿子连边,所以把这个状态单独拿出来)

初始化及转移方程:
初始化:res[u] = dp[u][0][0] = dp[u][0][1] = 1res[u] = 1 代表空子图的方案数,这种状态的答案在子树合并考虑第 i 棵子树时需要用到,dp[u][0][0],dp[u][0][1] 可以赋值为 1,因为状态中定义了 u u u 点可以孤立。

res[u]:由于子树间的方案独立,用子树合并的方式,将子树中的所有合法解乘起来就得到 r e s [ u ] res[u] res[u],子树 v 的合法解为:dp[v][1][0] + dp[v][1][1] + res[v];很容易理解:v 点不包括在子图的方案数 + v 点包括在子图中的方案, v 点包括在子图中的方案里 v 不能孤立,因此要去掉 dp[v][0][0/1]

dp[u][0][0/1]:转移方程显然和 res[u] 相同,因为 u 点独立,u 是否作为独立集中的一点不会影响子树的合法解。

dp[u][1][0]:由于要和儿子连边,若 u 已经和前 i 个儿子有连边,则当前子儿子枚举连和不连进行转移,这种情况方案数为:dp[u][1][0] * (sum[v] + (dp[v][1][0] + dp[v][1][1] + res[v])),sum[v] 是 dp[v] 的四种状态之和,sum[v] 是连这个儿子的方案,(dp[v][1][0] + dp[v][1][1] + res[v]) 是不连这个儿子的方案(也就是这个儿子的合法方案),类似的,如果 u 没有和前 i 个儿子连过边,则当前儿子一定要连边,方案数为:dp[u][0][0] * sum[v]

dp[u][1][1]:和 dp[u][1][0] 类似,如果连边,因为 u u u 点已选为独立集中的一点, v v v 点则不能作为独立集中的一点,将 sum[v] 修改为 dp[v][0][0] + dp[v][1][0]

最后总答案 ans = dp[1][1][0] + dp[1][1][1] + res[1] - 1,扣掉1 扣掉的是空子图的方案,在这题中空子图不能作为一种合法解。


代码:

#include
using namespace std;
const int mod = 998244353;
typedef long long ll;
const int maxn = 3e5 + 10;
int n;
vector<int> g[maxn];
ll fpow(ll a,ll b) {
	ll r = 1;
	while (b) {
		if (b & 1) r = r * a % mod;
		b >>= 1;
		a = a * a % mod;
	}
	return r;
}
ll dp[maxn][2][2],res[maxn],tp[maxn][2][2],sum[maxn],ans[maxn];
void dfs(int u,int fa) {
	// dp[u][0/1][0/1]:以 u 为根的子树中,u 点比存在的子图的答案 
	// res[u] :以 u 为根的子树中,u 点不存在的子图的答案 
	res[u] = 1;			// 代表空集,也作为一种答案 
	dp[u][0][0] = dp[u][0][1] = 1;	// 只有自己一个点 
	sum[u] = 0;
	for (auto it : g[u]) {
		if (it == fa) continue;
		dfs(it,u);
		res[u] = (res[u] * (dp[it][1][0] + dp[it][1][1] + res[it]) % mod + mod) % mod;
		for (int i = 0; i < 2; i++)			
			for (int j = 0; j < 2; j++)			 
				tp[u][i][j] = dp[u][i][j];
		for (int i = 0; i < 2; i++) {				//选不选儿子 
			for (int j = 0; j < 2; j++) {			//u点要不要选 
				if (i == 0) {
					tp[u][i][j] = (dp[u][i][j] * (dp[it][1][0] + dp[it][1][1] + res[it]) % mod) % mod;
				} else {
					if (j == 0) {
						tp[u][i][j] = (dp[u][i][j] * (sum[it] + (dp[it][1][0] + dp[it][1][1] + res[it])) % mod + 
										dp[u][0][j] * sum[it] % mod) % mod;
					} else {
						tp[u][i][j] = (dp[u][i][j] * (dp[it][0][0] + dp[it][1][0] + dp[it][1][0] + dp[it][1][1] + res[it]) % mod + 
										dp[u][0][j] * (dp[it][0][0] + dp[it][1][0]) % mod) % mod;
					}
				}
			}
		}
		for (int i = 0; i < 2; i++)			
			for (int j = 0; j < 2; j++)			 
				dp[u][i][j] = tp[u][i][j];		
	}
	for (int i = 0; i < 2; i++)
		for (int j = 0; j < 2; j++)
			sum[u] = (sum[u] + dp[u][i][j]) % mod;
}
int main() {
	scanf("%d",&n);
	for (int i = 1; i < n; i++) {
		int u,v; scanf("%d%d",&u,&v);	
		g[u].push_back(v);
		g[v].push_back(u);
	}
	dfs(1,0);
	printf("%lld\n",(dp[1][1][0] + dp[1][1][1] + res[1] - 1 + mod) % mod);			//减一是减掉空集
	return 0;
}

你可能感兴趣的:(计数类DP)