感谢grass8sheep提供的思路。
首先,我们可以用 D P DP DP解决这个问题。
设 f i , j f_{i,j} fi,j表示前 i i i个数中有 j j j个为 1 1 1的位置为红色的最大价值。则转移如下:
初始 f 0 , j = 0 f_{0,j}=0 f0,j=0。
考虑差分序列,记作 { d i } \{d_i\} {di}。则 s i = 1 s_i=1 si=1的转移等价于,对于一段连续的满足 < r i − b i
但是打表可以发现,答案不是凸的,也就是说 d i d_i di不具有单调性。事实上有一个结论:每次结束后,将 d i d_i di按从大到小排序,这并不会影响答案。因此用平衡树维护即可,操作一对应区间平移;操作二对应前缀减 1 1 1,然后将差值为一的两个连续段交换。
复杂度 O ( n log n ) O(n\log n) O(nlogn)。
关于结论的证明:设 d p j dp_j dpj表示考虑完前 i i i个数后选了 j j j个 1 1 1的最大价值, d p j = a dp_{j}=a dpj=a, d p j + 1 = a + b dp_{j+1}=a+b dpj+1=a+b, d p j + 2 = a + 2 b + 1 dp_{j+2}=a+2b+1 dpj+2=a+2b+1。设之后的方案中选了 x x x个 0 0 0,那么我们要让 d p i − i x dp_i-ix dpi−ix最大。发现交换了 d j + 1 d_{j+1} dj+1和 d j + 2 d_{j+2} dj+2后 j + 1 j+1 j+1仍然不可能成为答案。(考虑是一条直线来截每个点使得截矩最大,因为斜率是整数,而相邻两点间斜率之差又不超过 1 1 1,因此不可能截到中间那个点)
因为每次操作是前缀减 1 1 1,所以交换的两个段之差不会超过 1 1 1,因此结论是正确的。
remark \text{remark} remark 注意到 D P DP DP只要不漏就好了,因此在不影响正确性的情况下我们可以修正 D P DP DP值。
类似的 D P DP DP思路:[USACO21DEC] Paired Up P(做法不一样,但是都有对 D P DP DP最优性的一些思考)
#include
#define ll long long
#define pb push_back
#define inf 0x3f3f3f3f3f3f3f3f
#define fi first
#define se second
using namespace std;
const int N=4e5+5;
int T,n,tot,rt;
ll r[N],b[N];
string str;
mt19937 gen(time(0));
struct node{
int fix,l,r,sz;
ll tag,val;
}t[N];
void pushup(int p){
t[p].sz=t[t[p].l].sz+t[t[p].r].sz+1;
}
int newnode(ll val){
tot++;
t[tot].fix=gen(),t[tot].l=t[tot].r=t[tot].tag=0,t[tot].sz=1,t[tot].val=val;
return tot;
}
void add(int p,ll x){
if(!p)return;
t[p].val+=x,t[p].tag+=x;
}
void pushdown(int p){
if(t[p].tag)add(t[p].l,t[p].tag),add(t[p].r,t[p].tag),t[p].tag=0;
}
int merge(int x,int y){
if(!x||!y)return x+y;
if(t[x].fix>t[y].fix){
pushdown(x);
t[x].r=merge(t[x].r,y);
pushup(x);
return x;
}
else{
pushdown(y);
t[y].l=merge(x,t[y].l);
pushup(y);
return y;
}
}
void split0(int rt,int &x,int &y,ll val){
if(!rt){
x=y=0;
return;
}
pushdown(rt);
if(t[rt].val>=val){
x=rt;
split0(t[x].r,t[x].r,y,val);
pushup(x);
}
else{
y=rt;
split0(t[y].l,x,t[y].l,val);
pushup(y);
}
}
void split1(int rt,int &x,int &y,int val){
if(!rt){
x=y=0;
return;
}
pushdown(rt);
if(t[t[rt].l].sz+1<=val){
x=rt;
split1(t[x].r,t[x].r,y,val-t[t[rt].l].sz-1);
pushup(x);
}
else{
y=rt;
split1(t[y].l,x,t[y].l,val);
pushup(y);
}
}
int rs(int x){
while(t[x].r)x=t[x].r;
return x;
}
int ls(int x){
while(t[x].l)x=t[x].l;
return x;
}
int cnt;
ll c[N];
void dfs(int x){
pushdown(x);
if(t[x].l)dfs(t[x].l);
c[++cnt]=t[x].val;
if(t[x].r)dfs(t[x].r);
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>T;
while(T--){
cin>>n>>str;
for(int i=1;i<=n;i++)cin>>r[i];
for(int i=1;i<=n;i++)cin>>b[i];
rt=tot=0;ll sm=0;int c1=0;
for(int i=1;i<=n;i++){
if(r[i]<=b[i]){
sm+=b[i];
continue;
}
else if(str[i-1]=='1'){
c1++,sm+=b[i];
int x,y;
split0(rt,x,y,r[i]-b[i]);
rt=merge(x,merge(newnode(r[i]-b[i]),y));
}
else{
sm+=r[i];
int x,y;
split1(rt,x,y,min(1ll*c1,r[i]-b[i]));
if(!x||!y){
add(x,-1);
rt=x+y;
}
else{
int _x=rs(x),_y=ls(y);
if(t[_x].val==t[_y].val){
ll val=t[_x].val;
int a,b,c,d;
split0(x,a,b,val+1);
split0(y,c,d,val);
add(a,-1),add(b,-1);
rt=merge(merge(a,c),merge(b,d));
}
else{
add(x,-1);
rt=merge(x,y);
}
}
}
}
cnt=0,dfs(rt);
ll res=sm;
for(int i=1;i<=c1;i++){
sm+=c[i],res=max(res,sm);
}
cout<<res<<"\n";
}
}