题意:把n个数分成k部分,使得每部分价值之和最大。每部分的价值为不同数的个数。很容易的可以想到dp,用dp[i][k]表示把前i个数分成k部分所获的最大值。则 dp[i][k]=max(dp[j][k−1]+num[j+1][i],k−1≤j<i)
num[i][j]表示a[i..j]的不同数的个数,但是这样是 O(n2m) 的,会tle。我们考虑优化状态转移的时间。(比赛时一直在想怎么用莫队。。可能是莫队看多了。。)正解给出了分治的方法以及线段树方法。我采用了线段树。维护一棵最大值的线段树。每个叶子表示 dp[j][k−1]+num[j+1][i] 。那么每次状态转移其实就是区间(k-1,i-1)查询最大值。我们还需要考虑每次做dp[i][k]之前如何把线段树做正确,及如何更新num。我们做dp[i]时,其实只是要把num[j+1][i-1]更新为num[j+1][i]。显然只有区间[j+1][i-1]没有出现过a[i]这个数时才会对num有影响。我们对于每个位置i,记一个数组pre,表示a[i]这个数上次出现的位置。则在pre[i]+1到i-1都没有出现过a[i]。所以我们对线段树上做一次区间更新(pre[i],i-1),都加上1即可,用懒标记维护。这样的话每次状态转移只需 O(logn) 的时间,复杂度就可以达到 (nmlogn)
#include
#include
#include
#define N 350000
int n,m,a[N],pre[N],h[N],dp[N][51];//pre[i]表示a[i]上一次出现的位置
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return x*f;
}
inline int max(int x,int y){return x>y?x:y;}
std::set<int> s;
struct node{
int lazy,mx,l,r;
}tree[N<<2];
inline void pushup(int p){
tree[p].mx=max(tree[p<<1].mx,tree[p<<1|1].mx);
}
void build(int k,int p,int l,int r){
tree[p].l=l;tree[p].r=r;
tree[p].lazy=0;
if(l==r){
tree[p].mx=dp[l][k];tree[p].lazy=0;return;
}
int mid=(l+r)>>1;
build(k,p<<1,l,mid);build(k,p<<1|1,mid+1,r);
pushup(p);
}
inline void pushdown(int p){
if(!tree[p].lazy) return;
tree[p<<1].lazy+=tree[p].lazy;tree[p<<1|1].lazy+=tree[p].lazy;
tree[p<<1].mx+=tree[p].lazy;tree[p<<1|1].mx+=tree[p].lazy;
tree[p].lazy=0;
}
void update(int p,int l,int r,int x,int y){
if(x<=l&&r<=y){
tree[p].lazy++;tree[p].mx++;
return;
}
pushdown(p);//要下放啊啊啊。。
int mid=(l+r)>>1;
if(x<=mid) update(p<<1,l,mid,x,y);
if(y>mid) update(p<<1|1,mid+1,r,x,y);
pushup(p);
}
int getmax(int p,int l,int r,int x,int y){
if(x<=l&&r<=y) return tree[p].mx;
pushdown(p);
int mid=(l+r)>>1,res=0;
if(x<=mid) res=max(res,getmax(p<<1,l,mid,x,y));
if(y>mid) res=max(res,getmax(p<<1|1,mid+1,r,x,y));
return res;
}
void solve(int k){
build(k-1,1,1,n);
for(int i=k;i<=n;++i){
update(1,1,n,pre[i],i-1);
//pre[i]+1...i-1都没出现过a[i],num[pre[i]+1][i]++
dp[i][k]=max(dp[i][k-1],getmax(1,1,n,k-1,i-1));
//找一个j,使得dp[j][k-1]+num[j+1][i]最大
}
}
int main(){
// freopen("a.in","r",stdin);
n=read();m=read();
for(int i=1;i<=n;++i){
a[i]=read();pre[i]=h[a[i]];h[a[i]]=i;
s.insert(a[i]);dp[i][1]=s.size();
}
for(int i=2;i<=m;++i) solve(i);
printf("%d\n",dp[n][m]);
return 0;
}