Acwing 322. 消木块(区间dp)

题目传送门

题意: 给你一个长度为n的序列,你每次可以选择一个连续块 [ i , j ] [i,j] [i,j]消去,消去之后 i − 1 i-1 i1 j + 1 j+1 j+1连接,这一操作的得分为 ( j − i + 1 ) 2 (j-i+1)^2 (ji+1)2,问你最多能得多少分。

思路: 很容易想到 f [ i ] [ j ] f[i][j] f[i][j]表示区间 [ i , j ] [i,j] [i,j]的最大分数,但是如果第 j + 1 j+1 j+1个和第j个一样的话,我们就不能只消掉第 j j j个而不消掉第 j + 1 j+1 j+1个。所以我们需要多加一个维度表示状态,即 f [ i ] [ j ] [ k ] f[i][j][k] f[i][j][k]表示消除 [ i , j ] [i,j] [i,j]的方块并且连带消除了 j j j之后(不包括 j j j k k k个与第 j j j个方块相同的方块时的最大值。

那么我们的状态转移有两种:

  1. j j j向前找到最后一个与第 j j j个方块不一样的方块,记其位置为 p p p,则有 f [ i ] [ j ] [ k ] = m a x ( f [ i ] [ j ] [ k ] , f [ i ] [ p − 1 ] [ 0 ] + ( j + k − p + 1 ) 2 ) f[i][j][k]=max(f[i][j][k],f[i][p-1][0]+(j+k-p+1)^2) f[i][j][k]=max(f[i][j][k],f[i][p1][0]+(j+kp+1)2)
  2. 找到上述p之后,我们从 i i i向后找,找到每一个q,有 a [ q ] ! = a [ q + 1 ] a[q]!=a[q+1] a[q]!=a[q+1]&& a [ q ] = = a [ j ] a[q]==a[j] a[q]==a[j],并满足 q + 1 < = p − 1 q+1<=p-1 q+1<=p1,维护状态 f [ i ] [ j ] [ k ] = m a x ( f [ i ] [ j ] [ k ] , f [ i ] [ q ] [ j + k − p + 1 ] + f [ q + 1 ] [ p − 1 ] [ 0 ] ) f[i][j][k]=max(f[i][j][k],f[i][q][j+k-p+1]+f[q+1][p-1][0]) f[i][j][k]=max(f[i][j][k],f[i][q][j+kp+1]+f[q+1][p1][0])

注意递归跳出的条件,并且使用递归实现。

代码:

#include
#define endl '\n'
#define null NULL
#define ls p<<1
#define rs p<<1|1
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define ll long long
//#define int long long
#define pii pair
#define ull unsigned long long
#define pdd pair
#define lowbit(x) x&-x
#define all(x) (x).begin(),(x).end()
#define sz(x) (int)(x).size()
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
char *fs,*ft,buf[1<<20];
#define gc() (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<20,stdin),fs==ft))?0:*fs++;
inline int read()
{
    int x=0,f=1;
    char ch=gc();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-')
            f=-1;
        ch=gc();
    }
    while(ch>='0'&&ch<='9')
    {
        x=x*10+ch-'0';
        ch=gc();
    }
    return x*f;
}
using namespace std;
const int N=2e3+55;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
const double eps=1e-6;
const double PI=acos(-1);

int a[N],f[222][222][222],cas=0;

int dfs(int x,int y,int z)
{
    if(f[x][y][z])
        return f[x][y][z];
    if(x==y)
        return f[x][y][z]=(1+z)*(1+z);
    if(x>y)
        return 0;
    int p=y;
    while(p>=x&&a[p]==a[y])
        p--;
    p++;
    f[x][y][z]=max(f[x][y][z],dfs(x,p-1,0)+(y+z-p+1)*(y+z-p+1));
    int q=x;
    while(q+1<=p-1)
    {
        if(a[q]==a[y]&&a[q+1]!=a[y])
            f[x][y][z]=max(f[x][y][z],dfs(x,q,z+y-p+1)+dfs(q+1,p-1,0));
        q++;
    }
    return f[x][y][z];
}
void solve()
{
    int n;
    cin>>n;
    memset(f,0,sizeof f);
    for(int i=1; i<=n; i++)
        cin>>a[i];
    cout<<"Case "<<++cas<<": "<<dfs(1,n,0)<<endl;
}
signed main()
{
    int t;
    cin>>t;
    while(t--)
        solve();
    return 0;
}

你可能感兴趣的:(题解)