SOSDP顾名思义就是,救命啊是DP Sum over Subsets(SOS)DP
本文约等于https://codeforces.com/blog/entry/45223的机翻中文版
给一个2^n长度的数组A,现在对于任意x要预处理出函数F(x)的返回值。
F(x)的定义:SUM(A[i] | x & i == i)
即 i 的二进制表示被x包含,F(x)返回所有满足条件 i 的A[ i ]总和。
//枚举每一种x
for(int mask = 0;mask < (1<<N); ++mask){
//枚举每一个i
for(int i = 0;i < (1<<N); ++i){
if((mask&i) == i){
F[mask] += A[i];
}
}
}
这种解法枚举了很多无用的i,对于给定的x,可以用更高效的方法枚举其包含的二进制数。
//依旧枚举所有的x
for (int mask = 0; mask < (1<<n); mask++){
F[mask] = A[0];
// 枚举了所有有效的i
for(int i = mask; i > 0; i = (i-1) & mask){
F[mask] += A[i];
}
}
复杂度证明:枚举i含有k个1位,则有C(n,k)种可能,每种可能有2^k种情况。
暴力的做法最大的浪费在于把x的二进制位混在一起枚举。
要把枚举的x,按照二进制拆分。
设计dp状态:
dp[被数字x包含][且最右边y位和x相同] = 的下标贡献总和
初始状态:
dp[i][0] = A[i];
转移方程:
for(int i = 1 ; i < (1 << maxn) ; i++){
for(int j = 1 ; j < maxn ; j++){
//如果j位是1,例如??1xx,则包含了??(1)xx,??(0)xx
if(i & (1 << (j-1)))
dp[i][j] = dp[i ^ (1 << (j-1))][j - 1] + dp[i][j - 1];
else
dp[i][j] = dp[i][j - 1];
}
}
原作者在举例这个dp时下标到了-1,(因为只是为了过度并不打算实现)
下图为注释图,逗号右边的数字比我实现的小1
可以发现无论第j位的数字是0或者1,都会继承j - 1位的状态。利用滚动数组的思想(滚动的是位数,而不是枚举的数字x),可以把dp优化成1维的,且非常容易实现。
//滚动数组优化而来
for(int j = 1 ; j < maxn ; j++){
for(int i = 1 ; i < (1 << maxn) ; i++){
if(i & (1 << (j-1)))dp[i] += dp[i ^ (1 << (j-1))];
}
}
思考:
因为滚动背包优化后,外循环是当前位数,内循环是枚举的数字。
数字应该从大到小枚举,不然就会成为完全背包,一种状态可能被统计多次,但是这里并没有这么做,为什么?
https://codeforces.ml/gym/102576/problem/B
改进前
#include
using namespace std;
const int maxn = 21;
//dp[数字i][最右j位包含的] 数字个数(包括本身)
//x包含y, 指x & y = y,且最右边j位相同
long long dp[1 << maxn][maxn];
int a[1 << maxn];
int main(){
int T;
scanf("%d",&T);
while(T--){
for(int i = 1 ; i < (1 << maxn) ; i++){
for(int j = 0 ; j < maxn ; j++)dp[i][j] = 0;
}
int n;
scanf("%d",&n);
for(int i = 1 ; i <= n ; i++){
scanf("%d", &a[i]);
dp[a[i]][0]++;
}
for(int i = 1 ; i < (1 << maxn) ; i++){
for(int j = 1 ; j < maxn ; j++){
//如果j位是1,例如??1xx,则包含了??(1)xx,??(0)xx
//
if(i & (1 << (j-1)))
dp[i][j] = dp[i ^ (1 << (j-1))][j - 1] + dp[i][j - 1];
else
dp[i][j] = dp[i][j - 1];
}
}
long long ans = 0;
for(int i = 1 ; i <= n ; i++)ans += dp[a[i]][maxn - 1];
printf("%lld\n",ans);
}
}
改进后
#include
using namespace std;
const int maxn = 21;
long long dp[1 << maxn];
int a[1 << maxn];
int main(){
int T;
scanf("%d",&T);
while(T--){
for(int i = 1 ; i < (1 << maxn) ; i++){
dp[i] = 0;
}
int n;
scanf("%d",&n);
for(int i = 1 ; i <= n ; i++){
scanf("%d", &a[i]);
dp[a[i]]++;
}
//滚动数组优化而来
for(int j = 1 ; j < maxn ; j++){
for(int i = 1 ; i < (1 << maxn) ; i++){
if(i & (1 << (j-1)))dp[i] += dp[i ^ (1 << (j-1))];
}
}
long long ans = 0;
for(int i = 1 ; i <= n ; i++)ans += dp[a[i]];
printf("%lld\n",ans);
}
}