最小斯坦纳树

$dp[i][state]$ 表示以$i$为根,指定集合中的点的连通状态为state的生成树的最小总权值

有两种转移方向:

1、先通过连通状态的子集进行转移。

2、在当前枚举的连通状态下,对该连通状态进行松弛操作。

P4294 [WC2008]游览计划

注意景点的个数不超过10个。

$dp[i][j][state]$ 表示在$[i, j]$这个点与state中对应点连通的最小代价。

那么就可以用状压DP + spfa求解。

由于要输出方案,可以记录每个状态的前一个状态,最后dfs跑一遍就行了。

// #pragma GCC optimize(2)
// #pragma GCC optimize(3)
// #pragma GCC optimize(4)
// Standard headers.
// NOTE(review): the header names were lost when this listing was extracted
// (only <string> and <set> survived). The list below is a reconstruction that
// covers everything this program uses; confirm against the original post.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>
// #include <ext/pb_ds/assoc_container.hpp>
// using namespace __gnu_pbds;
using namespace std;

// Shorthand used throughout the solution.
#define pb push_back
#define fi first
#define se second
#define debug(x) cerr<<#x << " := " << x << endl;
#define bug cerr<<"-----------------------"<<endl;
#define FOR(a, b, c) for(int a = b; a <= c; ++ a)

typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;

// "Infinity" sentinels: every byte is 0x3f, so memset(..., 0x3f, ...) yields inf
// and inf + inf still fits in the type.
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9+7;

/**
 * Fast integer reader built on getchar().
 *
 * Skips every character that is not a decimal digit (a '-' seen while skipping
 * marks the number as negative), then accumulates the following run of digits
 * into x. The character right after the number is consumed as well.
 *
 * @param x  receives the parsed value; set to 0 when input ends before a digit.
 * @return   the value stored in x.
 */
template <typename T>
inline T read(T&x){
    x=0;
    bool negative=false;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    // Stop at EOF as well; otherwise exhausted input would spin forever.
    while (ch!=EOF&&(ch<'0'||ch>'9')) { negative|=(ch=='-'); ch=getchar(); }
    while (ch>='0'&&ch<='9') { x=x*10+(ch-'0'); ch=getchar(); }
    return x=negative?-x:x;
}

/**********showtime************/
            // Grid solver state (cells are addressed 1-based in both coordinates).
            const int maxn = 12;
            int a[maxn][maxn];          // cost of each cell as read from input; 0 marks a terminal
            int dp[maxn][maxn][1055];   // dp[x][y][mask]: best cost found for cell (x, y) and terminal mask
            // Back-pointer for dp[x][y][mask]: the (cell, mask) it was derived from.
            // Zero-initialised, so x == 0 / y == 0 means "no predecessor".
            struct node{
                int x, y, state;
            } pre[maxn][maxn][1055];

            // Work queue of cells for spfa(); the element type is the same as pii.
            std::queue<std::pair<int, int>> que;
            // The four axis-aligned neighbour offsets.
            int nx[4][2] = {
                {0, 1}, {1, 0},{-1,0},{0,-1}
            };
            int n,m;                    // grid dimensions
            int vis[maxn][maxn];        // in-queue flag during spfa(); reused by dfs() to mark chosen cells
            void spfa(int now) {
                while(!que.empty()) {
                    pii tmp = que.front(); que.pop();
                    for(int i=0; i<4; i++) {
                        int x = tmp.fi + nx[i][0];
                        int y = tmp.se + nx[i][1];
                        if(x < 1 || x > n || y < 1 || y > m) continue;
                        if(dp[x][y][now] > dp[tmp.fi][tmp.se][now] + a[x][y]) {
                            dp[x][y][now] = dp[tmp.fi][tmp.se][now] + a[x][y];
                            pre[x][y][now] = node{tmp.fi, tmp.se, now};
                            if(!vis[x][y]) {
                                que.push(pii(x, y));
                                vis[x][y] = 1;
                            }
                        }
                    }
                    vis[tmp.fi][tmp.se] = 0;
                }
            }
            void dfs(int x, int y, int now) {
                if(x == 0 || y == 0) return;
                vis[x][y] = 1;
                node tmp = pre[x][y][now];
                dfs(tmp.x, tmp.y, tmp.state);
                if(tmp.x == x && tmp.y == y)
                    dfs(tmp.x, tmp.y, now - tmp.state);
            }
/**
 * Grid Steiner tree with plan output.
 *
 * Reads an n x m grid of cell costs; cells with cost 0 are the terminals.
 * Runs the bitmask DP (subset merge at a cell, then spfa() relaxation per
 * mask), prints the minimum total cost, and prints the grid with 'x' for
 * terminals, 'o' for other chosen cells and '_' for the rest.
 */
int main(){
    scanf("%d%d", &n, &m);
    memset(dp, inf, sizeof(dp));   // low byte 0x3f repeated -> every int equals inf

    int num = 0;                   // terminals seen so far; also the next bit index
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            scanf("%d", &a[i][j]);
            if (a[i][j] == 0) {
                dp[i][j][1 << num] = 0;
                num++;
            }
        }
    }
    const int all = (1 << num) - 1;

    for (int state = 0; state <= all; state++) {
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                // Enumerate proper non-empty subsets of `state`; the walk must
                // start from `state` itself, not from the loop variable.
                for (int s0 = (state - 1) & state; s0; s0 = (s0 - 1) & state) {
                    // Both halves include cell (i, j), so its cost is counted twice.
                    const int merged = dp[i][j][s0] + dp[i][j][state - s0] - a[i][j];
                    if (dp[i][j][state] > merged) {
                        dp[i][j][state] = merged;
                        pre[i][j][state] = node{i, j, s0};
                    }
                }
                if (dp[i][j][state] < inf) que.push(pii(i, j)), vis[i][j] = 1;
            }
        }
        spfa(state);
    }

    // ax/ay default to the dfs() sentinel so a grid without terminals is safe.
    int ax = 0, ay = 0, mn = inf;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            if (dp[i][j][all] < mn) {
                mn = dp[i][j][all];
                ax = i;
                ay = j;
            }
        }
    }
    if (num == 0) mn = 0;          // nothing to connect costs nothing
    printf("%d\n", mn);

    memset(vis, 0, sizeof(vis));
    dfs(ax, ay, all);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            if (a[i][j] == 0) printf("x");
            else if (vis[i][j]) printf("o");
            else printf("_");
        }
        puts("");
    }
    return 0;
}
View Code

HDU-4085 Peach Blossom Spring

 给定一个 $n \le 50$、$m \le 1000$ 的无向图,要求用最小的修路总花费,使得 1 到 k 号点($k \le 5$)与最后 k 个点相连,也就是说 1 到 k 号点中每个点都有一个匹配点,且匹配点两两不同。

 

斯坦纳树,但是题目没有要求这 $2 \times k$ 个点全部连通,所以我们先用斯坦纳树 DP 求出各个点集的小斯坦纳树,再用子集 DP 把它们组合成斯坦纳森林。

 

// #pragma GCC optimize(2)
// #pragma GCC optimize(3)
// #pragma GCC optimize(4)
// Standard headers.
// NOTE(review): the header names were lost when this listing was extracted
// (only <string> and <set> survived). The list below is a reconstruction that
// covers everything this program uses; confirm against the original post.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>
// #include <ext/pb_ds/assoc_container.hpp>
// using namespace __gnu_pbds;
using namespace std;

// Shorthand used throughout the solution.
#define pb push_back
#define fi first
#define se second
#define debug(x) cerr<<#x << " := " << x << endl;
#define bug cerr<<"-----------------------"<<endl;
#define FOR(a, b, c) for(int a = b; a <= c; ++ a)

typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;

// "Infinity" sentinels chosen so that inf + inf still fits in the type.
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9+7;

/**
 * Fast integer reader built on getchar().
 *
 * Skips every character that is not a decimal digit (a '-' seen while skipping
 * marks the number as negative), then accumulates the following run of digits
 * into x. The character right after the number is consumed as well.
 *
 * @param x  receives the parsed value; set to 0 when input ends before a digit.
 * @return   the value stored in x.
 */
template <typename T>
inline T read(T&x){
    x=0;
    bool negative=false;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    // Stop at EOF as well; otherwise exhausted input would spin forever.
    while (ch!=EOF&&(ch<'0'||ch>'9')) { negative|=(ch=='-'); ch=getchar(); }
    while (ch>='0'&&ch<='9') { x=x*10+(ch-'0'); ch=getchar(); }
    return x=negative?-x:x;
}

/**********showtime************/
            // Graph solver state (vertices are numbered from 1).
            const int maxn = 55;
            std::vector<std::pair<int, int>> mp[maxn];  // adjacency list: (neighbour, edge weight); same element type as pii
            std::queue<int> que;                        // work queue for spfa()
            int vis[55];                                // in-queue flag for spfa()
            // dp[v][mask]: best cost found for vertex v and terminal mask;
            // g[mask]: best cost of a forest covering exactly `mask` (filled in main).
            int dp[maxn][1055],g[1055];
            void spfa(int now) {
                while(!que.empty()) {
                    int u = que.front(); que.pop();
                    for(pii p : mp[u]) {
                        if(dp[p.fi][now] > dp[u][now] + p.se)
                        {
                            dp[p.fi][now] = dp[u][now] + p.se;
                            if(vis[p.fi] == 0) {
                                vis[p.fi] = 1;
                                que.push(p.fi);
                            }
                        }
                    }
                    vis[u] = 0;
                }
            }
            // Problem sizes read in main: n vertices, m edges, k terminals per side.
            int n,m,k;

            // Returns true when `state` has as many set bits among its lowest k bits
            // as among the next k bits, i.e. both halves of the mask are equally
            // represented.
            bool check(int state ){
                int balance = 0;
                // half 0 covers the low k bits (+1 each), half 1 the next k bits (-1 each).
                for (int half = 0; half < 2; ++half) {
                    const int delta = (half == 0) ? 1 : -1;
                    for (int bit = 0; bit < k; ++bit) {
                        if (state % 2 == 1) balance += delta;
                        state /= 2;
                    }
                }
                return balance == 0;
            }
/**
 * Steiner forest driver for the multi-test-case graph problem.
 *
 * Per test case: reads n, m, k and m undirected weighted edges, treats vertices
 * 1..k and the last k vertices as terminals (2k mask bits), runs the Steiner
 * tree DP, combines balanced sub-trees into a forest, and prints the minimum
 * cost or "No solution".
 */
int main(){
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d%d", &n, &m, &k);
        for (int e = 0; e < m; ++e) {
            int u, v, w;
            scanf("%d%d%d", &u, &v, &w);
            mp[u].pb(pii(v, w));
            mp[v].pb(pii(u, w));
        }

        const int num = 2 * k;
        const int all = (1 << num) - 1;
        for (int state = 0; state <= all; ++state) {
            for (int i = 1; i <= n; ++i) dp[i][state] = inf;
        }
        // Vertex t+1 owns bit t; vertex n-t owns bit 2k-1-t.
        for (int t = 0; t < k; ++t) {
            dp[t + 1][1 << t] = 0;
            dp[n - t][1 << (num - 1 - t)] = 0;
        }

        // Steiner tree DP: merge two sub-masks at the same vertex, then relax
        // along edges for the current mask.
        for (int state = 0; state <= all; ++state) {
            for (int i = 1; i <= n; ++i) {
                int &best = dp[i][state];
                for (int sub = (state - 1) & state; sub != 0; sub = (sub - 1) & state) {
                    const int joined = dp[i][sub] + dp[i][state ^ sub];
                    if (joined < best) best = joined;
                }
                if (best < inf) {
                    que.push(i);
                    vis[i] = 1;
                }
            }
            spfa(state);
        }

        // The answer need not be a single Steiner tree but a Steiner forest, so
        // combine the small Steiner trees. Masks are handled in increasing order:
        // every proper sub-mask is numerically smaller and therefore already final.
        for (int state = 0; state <= all; ++state) {
            g[state] = inf;
            if (!check(state)) continue;

            for (int i = 1; i <= n; ++i) {
                g[state] = min(g[state], dp[i][state]);
            }
            for (int sub = (state - 1) & state; sub != 0; sub = (sub - 1) & state) {
                if (!check(sub)) continue;
                g[state] = min(g[state], g[sub] + g[state ^ sub]);
            }
        }

        if (g[all] >= inf) printf("No solution\n");
        else printf("%d\n", g[all]);

        for (int i = 1; i <= n; ++i) mp[i].clear();
    }
    return 0;
}
View Code

 

ZOJ-3613 Wormhole Transport

题意:

有n个星球,其中最多四个星球是资源星球,最多四个星球有不同个数的工厂。

一个资源星球只能供给一个工厂。

有不同类型的路可以修,问在最多工厂被供给的前提下,最小的修路费用。

思路:

朴素的斯坦纳树转移。

然后需要通过森林DP,转移条件是工厂个数 $\ge$ 资源个数

 

// #pragma GCC optimize(2)
// #pragma GCC optimize(3)
// #pragma GCC optimize(4)
// Standard headers.
// NOTE(review): the header names were lost when this listing was extracted
// (only <string> and <set> survived). The list below is a reconstruction that
// covers everything this program uses; confirm against the original post.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>
// #include <ext/pb_ds/assoc_container.hpp>
// using namespace __gnu_pbds;
using namespace std;

// Shorthand used throughout the solution.
#define pb push_back
#define fi first
#define se second
#define debug(x) cerr<<#x << " := " << x << endl;
#define bug cerr<<"-----------------------"<<endl;
#define FOR(a, b, c) for(int a = b; a <= c; ++ a)

typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;

// "Infinity" sentinels chosen so that inf + inf still fits in the type.
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9+7;

/**
 * Fast integer reader built on getchar().
 *
 * Skips every character that is not a decimal digit (a '-' seen while skipping
 * marks the number as negative), then accumulates the following run of digits
 * into x. The character right after the number is consumed as well.
 *
 * @param x  receives the parsed value; set to 0 when input ends before a digit.
 * @return   the value stored in x.
 */
template <typename T>
inline T read(T&x){
    x=0;
    bool negative=false;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    // Stop at EOF as well; otherwise exhausted input would spin forever.
    while (ch!=EOF&&(ch<'0'||ch>'9')) { negative|=(ch=='-'); ch=getchar(); }
    while (ch>='0'&&ch<='9') { x=x*10+(ch-'0'); ch=getchar(); }
    return x=negative?-x:x;
}

/**********showtime************/

            // Graph solver state (planets are numbered from 1).
            const int maxn = 205;
            // Per mask bit: pt[] is the value summed by check() for non-flagged bits,
            // flag[] marks the bits counted by cal(); both are filled in main.
            int pt[10], flag[10];
            int vis[maxn];                               // in-queue flag for spfa()
            std::queue<int> que;                         // work queue for spfa()
            std::pair<int, int> p[maxn];                 // the two integers read per planet (same type as pii)
            std::vector<std::pair<int, int>> mp[maxn];   // adjacency list: (neighbour, edge weight)
            int dp[maxn][300];                           // dp[v][mask]: best cost found for vertex v and mask
            int g[300];                                  // g[mask]: best forest cost for mask (filled in main)
            void spfa(int now) {
                while(!que.empty()) {
                    int u = que.front(); que.pop();
                    for(pii p : mp[u]) {
                        if(dp[p.fi][now] > dp[u][now] + p.se) {
                            dp[p.fi][now] = dp[u][now] + p.se;
                            if(vis[p.fi] == 0) {
                                que.push(p.fi);
                                vis[p.fi] = 1;
                            }
                        }
                    }
                    vis[u] = 0;
                }
            }
            // Number of mask bits (special planets) registered by main for the
            // current test case.
            int cnt = 0;

            // Feasibility test for a mask: walks its lowest cnt bits, counting the
            // set bits whose flag[] is non-zero and summing pt[] over the other set
            // bits. Returns true when the flagged count does not exceed that sum.
            bool check(int now) {
                int flagged = 0;    // set bits with flag[i] != 0
                int capacity = 0;   // sum of pt[i] over the remaining set bits
                for (int i = 0; i < cnt; i++) {
                    if (now % 2 == 1) {
                        if (flag[i])
                            flagged++;
                        else
                            capacity += pt[i];
                    }
                    now = now / 2;
                }
                return flagged <= capacity;
            }
            int cal(int now) {
                int res = 0;
                for(int i=0; i) {
                    if(now % 2 == 1) {
                        res += flag[i];
                    }
                    now = now / 2;
                }
                return res;
            }
int main(){
            int n;
            while(~scanf("%d", &n)) {
                for(int i=1; i<=n; i++) {
                    for(int j=0; j<300; j++)
                        dp[i][j] = inf;
                }
                cnt = 0;
                int res1 = 0;
                for(int i=1; i<=n; i++) {
                    scanf("%d%d", &p[i].fi, &p[i].se);
                    if(p[i].fi && p[i].se)
                    {
                        res1++;
                        p[i].se = 0;
                        p[i].fi--;
                    }
                    if(p[i].se) {
                        flag[cnt] = 1; //
                        pt[cnt] = p[i].fi;
                        dp[i][1<0;
                        cnt++;
                    }
                    else if(p[i].fi) {
                        pt[cnt] = p[i].fi;
                        flag[cnt] = 0;
                        dp[i][1<0;
                        cnt++;
                    }
                }

                int m;  scanf("%d", &m);
                for(int i=1; i<=m; i++) {
                    int u,v,w;
                    scanf("%d%d%d", &u, &v, &w);
                    mp[u].pb(pii(v, w));
                    mp[v].pb(pii(u, w));
                }

                int all = (1 << cnt) - 1;
                for(int state = 0; state <= all; state++) {
                    for(int i=1; i<=n; i++) {
                        for(int s0 = (state-1) & state; s0; s0 = (s0-1) & state) {
                            dp[i][state] = min(dp[i][state], dp[i][s0] + dp[i][state - s0]);
                        }
                        if(dp[i][state] < inf) que.push(i);
                    }
                    spfa(state);
                }
                for(int state=0; state<=all; state++) {
                    g[state] = inf;
                    if(check(state) == 0) continue;
                    for(int i=1; i<=n; i++) {
                        g[state] = min(g[state], dp[i][state]);
                    }
                }

                int res2 = 0, ans = 0;
                for(int state = 0; state <= all; state ++) {
                    if(check(state) == 0) continue;
                    for(int s0 = (state - 1) & state; s0; s0 = (s0-1) & state) {
                        if(check(s0) && check(state - s0)) {
                            g[state] = min(g[state], g[s0] + g[state - s0]);
                        }
                    }

                    if(cal(state) > res2) {
                        res2 = cal(state);
                        ans = g[state];
                    }
                    else if(cal(state) == res2) {
                        ans = min(ans, g[state]);
                    }
                }
                printf("%d %d\n", res1 + res2, ans);

                for(int i=1; i<=n; i++) mp[i].clear();
            }
            return 0;
}
View Code

 

 

参考和学习:

https://www.cnblogs.com/clno1/p/10990936.html

你可能感兴趣的:(最小斯坦纳树)