// LOJ3158
// LuoguP5470
// 给定两个长度为 $n$ 的正整数序列 $\{a_i\}$ 与 $\{b_i\}$,序列的下标为 $1,2,\cdots,n$。现在你需要分别对两个序列各指定恰好 $K$ 个下标,要求至少有 $L$ 个下标在两个序列中都被指定,使得这 $2K$ 个下标在序列中对应的元素的总和最大。
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
#define MAXN 300000
using namespace std;
// Read the next integer from stdin into cn.
// Every non-digit character before the number is skipped; a '-' among the skipped characters
// makes the result negative. The character that terminates the number (if any) is consumed.
// If the input ends before any digit is found, cn is set to 0 instead of looping forever.
template<typename T>void Read(T &cn)
{
	int c = getchar();   // int, not char: must hold EOF, and isdigit() requires EOF or an unsigned-char value
	int sig = 1;
	while(c != EOF && !isdigit(c)){
		if(c == '-')sig = -1;
		c = getchar();
	}
	cn = 0;
	if(c == EOF)return;  // no number left in the input
	while(isdigit(c)){   // isdigit(EOF) is 0, so this also stops at end of input
		cn = cn*10 + (c-'0');
		c = getchar();
	}
	cn *= sig;
}
// Print the decimal representation of cn to stdout (no separator or newline is added).
// Digits are peeled off the original value rather than negating it first, because
// 0-cn overflows for the most negative value of a signed T (e.g. INT_MIN, LLONG_MIN).
template<typename T>void Write(T cn)
{
	char buf[48];                  // enough for any integer type up to 128 bits (39 digits)
	int len = 0;
	const bool neg = cn < 0;
	do {
		T r = cn % 10;             // in [-9, 0] when cn is negative (division truncates toward zero)
		buf[len++] = (char)('0' + (r < 0 ? -r : r));
		cn /= 10;
	} while(cn != 0);
	if(neg)putchar('-');
	while(len > 0)putchar(buf[--len]);
}
int n,k,l,cd,t;                        // n, K, L of the current test; cd = indices chosen in both sequences; t = number of tests
int a[MAXN+1],b[MAXN+1],tai[MAXN+1];   // tai[i] bitmask: 1 = index i chosen in A, 2 = chosen in B
int Va,Vb;                             // key weights read by Cmp; each heap loads its own weights before using Cmp
// Orders indices by the key a[i]*Va + b[i]*Vb (smaller key first), so std heap algorithms
// built on it keep the largest key on top.
// The key is computed in long long: each value fits in int, but with Va = Vb = 1 the sum
// a[i] + b[i] of two positive ints can exceed INT_MAX.
bool Cmp(const int cn,const int cm) {return (long long)a[cn]*Va + (long long)b[cn]*Vb < (long long)a[cm]*Va + (long long)b[cm]*Vb; }
// Max-heap of sequence indices with lazy deletion.
// The key of index i is a[i]*VA + b[i]*VB. Cmp reads the global weights Va/Vb, so every
// operation that reorders the heap first copies this heap's VA/VB into those globals.
struct Heap{
	int dui[MAXN+1],dlen;      // heap storage lives in dui[1..dlen]
	int vis[MAXN+1];           // vis[i] == 1 while index i is stored in dui, 0 otherwise
	int mu,VA,VB;              // mu: tag stored by set() (not read inside this struct); VA/VB: key weights

	// Empty the heap and record its tag and key weights.
	void set(int tag,int wa,int wb)
	{
		mu = tag;
		VA = wa;
		VB = wb;
		dlen = 0;
		memset(vis,0,sizeof(vis));
	}

	// Add index idx; ignored if idx is already stored.
	void insert(int idx)
	{
		if(vis[idx] != 0)return;
		dlen++;
		dui[dlen] = idx;
		Va = VA;
		Vb = VB;
		std::push_heap(dui+1,dui+dlen+1,Cmp);
		vis[idx] = 1;
	}

	// Index with the largest key without removing it, or 0 when the heap is empty.
	int get()
	{
		if(dlen == 0)return 0;
		return dui[1];
	}

	// Pop top entries until the top's tai value equals want; return that top (0 if none left).
	int get1(int want)
	{
		for(;dlen > 0;del())
			if(tai[dui[1]] == want)break;
		return get();
	}

	// Pop top entries while the top's tai value shares a bit with mask; return the new top (0 if none left).
	int get2(int mask)
	{
		while(dlen > 0 && (tai[dui[1]] & mask) != 0)del();
		return get();
	}

	// Remove the top entry; does nothing on an empty heap.
	void del()
	{
		if(dlen == 0)return;
		vis[dui[1]] = 0;
		Va = VA;
		Vb = VB;
		std::pop_heap(dui+1,dui+dlen+1,Cmp);
		--dlen;
	}
}D1,D2,D3,D4,D5;
LL ans;
int main()
{
freopen("sequence.in","r",stdin);
freopen("sequence.out","w",stdout);
Read(t);
while(t--)
{
Read(n); Read(k); Read(l);
for(int i = 1;i<=n;i++)Read(a[i]);
for(int i = 1;i<=n;i++)Read(b[i]);
ans = 0; D1.set(-1,1,0); D2.set(-1,0,1); D3.set(1,0,1); D4.set(2,1,0); D5.set(0,1,1);
memset(tai,0,sizeof(tai));
cd = 0;
for(int i = 1;i<=n;i++)D1.insert(i),D2.insert(i),D5.insert(i);
for(int i = 1;i<=k;i++)
{
if(i-1-cd < k-l){
int lin1,lin2;
lin1 = D1.get2(1); D1.del();
lin2 = D2.get2(2); D2.del();
if(tai[lin1] == 2)cd++; tai[lin1] |= 1; if(tai[lin1] == 1)D3.insert(lin1);
if(tai[lin2] == 1)cd++; tai[lin2] |= 2; if(tai[lin2] == 2)D4.insert(lin2);
ans = ans + a[lin1] + b[lin2];
}
else{
int lin1,lin2_1,lin2_2,lin3_1,lin3_2,lin1z,lin2z,lin3z;
lin1 = D5.get1(0); lin1z = lin1 ? a[lin1] + b[lin1] : 0;
lin2_1 = D3.get1(1); lin2_2 = D1.get2(1); lin2z = lin2_1*lin2_2 ? b[lin2_1] + a[lin2_2] : 0;
lin3_1 = D4.get1(2); lin3_2 = D2.get2(2); lin3z = lin3_1*lin3_2 ? a[lin3_1] + b[lin3_2] : 0;
if(lin1z >= lin2z && lin1z >= lin3z){
ans = ans + lin1z; tai[lin1] = 3; D5.del();
cd++; continue;
}
if(lin2z >= lin1z && lin2z >= lin3z){
ans = ans + lin2z; tai[lin2_1] |= 2; tai[lin2_2] |= 1;
if(tai[lin2_2] == 1)D3.insert(lin2_2); else cd++;
cd++; continue;
}
if(lin3z >= lin1z && lin3z >= lin2z){
ans = ans + lin3z; tai[lin3_1] |= 1; tai[lin3_2] |= 2;
if(tai[lin3_2] == 2)D4.insert(lin3_2); else cd++;
cd++; continue;
}
}
}
Write(ans); putchar('\n');
}
return 0;
}