SDUT 1008 最大公共子串 变维dp

http://acm.sdut.edu.cn/sdutoj/problem.php?action=showproblem&problemid=1008

题意:

给定n个串,求这n个串的的最长公共子序列。

思路:

首先说一下自己的思路,dp确实挺弱的没想到变维dp。

就想了一个贪心的算法,贪心是这样的:把s[0]与后边的串匹配,求出所有满足最大长度的子串,然后依据每个串在后边可能出现的概率取最大的继续往后求最长公共子序列,直到结束。这里出现的概率计算:我们统计出现的字符总数,然后计算每个字符出现的概率,然后就可以求子串在后边未参与匹配的串中出现的概率了。这个算法要回溯出所有可能的解,数据大的话会超时,自己出了组数据想把自己查了,本地机子跑的很慢,可是竟然没有把自己查了。可见我们的服务器跑的多快啊。。。。这里不是正解,只是想留下这个贪心的思想,本人好不容易想出来的...

#include <iostream>

#include <cstdio>

#include <cstdlib>

#include <cstring>

#include <algorithm>

#include <cmath>

#include <queue>

#include <stack>

#include <set>

#include <map>

#include <cstring>



#define CL(a,num) memset((a),(num),sizeof(a))

#define iabs(x)  ((x) > 0 ? (x) : -(x))

#define Min(a,b) (a) > (b)? (b):(a)

#define Max(a,b) (a) > (b)? (a):(b)



#define ll __int64

#define inf 0x7f7f7f7f

#define MOD 1073741824

#define lc l,m,rt<<1

#define rc m + 1,r,rt<<1|1

#define pi acos(-1.0)

#define test puts("<------------------->")

#define maxn 100007

#define M 30007

#define N 107

using namespace std;

//freopen("din.txt","r",stdin);



int dp[N][M],dir[N][M];

char s[N][M],t_ans[M],ans[M];

int num[40],len[40];

double rate[40],t_rate;

char tmp[M];



void getS0(char *S,int i,int j,int lcsL,int curl){

    if (curl == lcsL + 1){

        double sum = 0;

        int ki = 0;

        //记录公共子串

        for (int k = lcsL; k >= 1; --k){

          //  printf("%c ",t_ans[k]);

            ans[ki++] = t_ans[k];

            if (t_ans[k] >= '0' && t_ans[k] <= '9'){

                sum += rate[t_ans[k] - '0'];

            }

            else{

                sum += rate[t_ans[k] - 'a' + 10];

            }

        }

        // printf("\n");

        ans[ki] = '\0';

        //去概率最大的

        if (t_rate == 0 || t_rate < sum){

            t_rate = sum;

            strcpy(tmp,ans);

        }

    }

    if (i >= 1 && j >= 1){

        if (dir[i][j] == 1){

            t_ans[curl] = S[i - 1];

            getS0(S,i - 1,j - 1,lcsL,curl + 1);

        }

        else if (dir[i][j] == 2){

            getS0(S,i - 1,j,lcsL,curl);

        }

        else if (dir[i][j] == 3){

            getS0(S,i,j - 1,lcsL,curl);

        }

        else{

            getS0(S,i - 1,j,lcsL,curl);

            getS0(S,i,j - 1,lcsL,curl);

        }

    }

}

int match(char *ts,char *ps){

    int i,j;

    int len1 = strlen(ts);

    int len2 = strlen(ps);

    CL(dp,0);

    CL(dir,0);//记录方向

    for (i = 1; i <= len1; ++i){

        for (j = 1; j <= len2; ++j){

            if (ts[i - 1] == ps[j - 1]){

                dp[i][j] = dp[i - 1][j - 1] + 1;

                dir[i][j] = 1;

            }

            else{

                if (dp[i - 1][j] > dp[i][j - 1]){

                    dp[i][j] = dp[i - 1][j];

                    dir[i][j] = 2;

                }

                else if (dp[i - 1][j] < dp[i][j - 1]){

                    dp[i][j] = dp[i][j - 1];

                    dir[i][j] = 3;

                }

                else{

                    dp[i][j] = dp[i][j - 1];

                    dir[i][j] = 4;

                }

            }

        }

    }

    getS0(ts,len1,len2,dp[len1][len2],1);

    return dp[len1][len2];

}

int main(){

    //freopen("din.txt","r",stdin);

    int i,j;

    int T,n;

    int sum;

    scanf("%d",&T);

    while (T--){

        scanf("%d",&n);

        sum = 0;

        CL(num,0);//统计每个字符出现的个数

        for (i = 0; i < n; ++i){

            scanf("%s",s[i]);

            int len = strlen(s[i]);

            sum += len;

            if (i == 0 || i == 1) continue;

            for (j = 0; j < len; ++j){

                if (s[i][j] >= '0' && s[i][j] <= '9'){

                    num[s[i][j] - '0']++;

                }

                else{

                    num[s[i][j] - 'a' + 10]++;

                }

            }

        }

        for (i = 0; i <= 36; ++i){

            rate[i] = (1.0*num[i])/(1.0*sum);//计算概率

        }

        int mk;

        for (i = 1; i < n; ++i){

            t_rate = 0;

            mk = match(s[0],s[i]);//匹配求值

            if (mk == 0) break;

            if (i + 1 < n)//更新概率,因为前边参与的字符串已经不需要了

            {

                int len = strlen(s[i + 1]);

                sum -= len;

                for (j = 0; j < len; ++j){

                    if (s[i][j] >= '0' && s[i][j] <= '9'){

                        num[s[i][j] - '0']--;

                    }

                    else{

                        num[s[i][j] - 'a' + 10]--;

                    }

                }

                for (j = 0; j <= 36; ++j){

                    rate[j] = (1.0*num[j])/(1.0*sum);

                }

            }

            strcpy(s[0],tmp);//将得到的最长公共子序列给s[0]

        }

        if (i < n) printf("0\n");

        else printf("%d\n",mk);

    }

    return 0;

}

 

 

  

此题的正解是变维DP,两个串比较时,我们用到的是二维的dp[i][j] n个串的话我们可以变换到n维。这里主要突破点是N个串的长度的乘积不会超过30000.   这样我们就可以把每个状态表示出来然后推,记忆化搜索+倒推。

#include <iostream>

#include <cstdio>

#include <cstdlib>

#include <cstring>

#include <algorithm>

#include <cmath>

#include <queue>

#include <stack>

#include <set>

#include <map>

#include <string>



#define CL(a,num) memset((a),(num),sizeof(a))

#define iabs(x)  ((x) > 0 ? (x) : -(x))

#define Min(a,b) (a) > (b)? (b):(a)

#define Max(a,b) (a) > (b)? (a):(b)



#define ll long long

#define inf 0x7f7f7f7f

#define MOD 1073741824

#define lc l,m,rt<<1

#define rc m + 1,r,rt<<1|1

#define pi acos(-1.0)

#define test puts("<------------------->")

#define maxn 100007

#define M 107

#define N 30007

using namespace std;

//freopen("data.in","r",stdin);



int dp[N];

int len[N],end[N];

char s[M][N];

int n;



int getR(int i,int k){

    if (i >= n) return 0;

    return getR(i + 1,k*len[i]) + k*end[i];

}

int DP(){

    int i,j;

    int p = getR(0,1);//得到每一个状态

    if (dp[p] == -1){

        dp[p] = 0;

        for (i = 0; i < n - 1; ++i){

            if (s[i][end[i]] != s[i + 1][end[i + 1]]) break;

        }

        //n个串在一点匹配

        if (i >= n - 1){

            for (i = 0; i < n; ++i){

                if (end[i]) end[i]--;

                else break;

            }

            if (i >= n){

                j = DP() + 1;

                dp[p] = max(dp[p],j);

            }

            else dp[p] = 1;

            while ((--i) >= 0) end[i]++;

        }

        //不匹配我们就枚举

        else{

            for (i = 0; i < n; ++i){

                if (end[i]){

                    end[i]--;

                    j = DP();

                    dp[p] = max(dp[p],j);

                    end[i]++;

                }

            }

        }

    }

    return dp[p];

}

int main(){

    //freopen("din.txt","r",stdin);

    int T,i;

    scanf("%d",&T);

    while (T--){

        CL(dp,-1);

        scanf("%d",&n);

        for (i = 0; i < n; ++i){

            scanf("%s",s[i]);

            len[i] = strlen(s[i]);//记录每个串的长度

            end[i] = len[i] - 1;//倒推记录最后的下标

        }

        printf("%d\n",DP());

    }

    return 0;

}

  

 

你可能感兴趣的:(dp)