https://www.luogu.com.cn/problem/P9338
考虑暴力：设 $f_{i,k}$ 为前 $i$ 个分成 $k$ 段的最小代价，枚举最后一段 $[j,i]$，记 $g_{j,i}$ 为这一段的代价，则 $f_{i,k}=\min_{j\le i}\left(f_{j-1,k-1}+g_{j,i}\right)$，复杂度 $O(n^3)$。
然后处理段数这一维：段数显然越多越优，那么就上 wqs 二分（给每一段额外加上代价 $x$，再二分 $x$）去掉这一维，复杂度 $O(n^2\log n)$。
然后我们发现 $g_{j,i}$ 可以拆成只与 $j$ 有关的项、只与 $i$ 有关的项以及交叉项 $-i\cdot j$，拆完之后用斜率优化（单调队列维护下凸壳）做 DP 即可，复杂度 $O(n\log n)$。
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <numeric>
using namespace std;
#define int long long
// Fast reader: returns the next signed decimal integer from stdin.
// Any non-digit characters before the number are skipped; if a '-' appears
// among them the result is negated.
// Returns 0 if end of input is reached before any digit is found (previously
// the skip loop never terminated at EOF).
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();   // int, not char, so EOF stays distinguishable from data
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
#define Z(x) (x)*(x)
#define pb push_back
//mt19937 rand(time(0));
//mt19937_64 rand(time(0));
//srand(time(0));
#define N 2000010
//#define M
//#define mo
// Lowers a to b when b is strictly smaller (in-place "chmin"); a is unchanged otherwise.
void Mn(int &a, int b) {
    if (b < a) {
        a = b;
    }
}
// One DP state: `a` is the (penalised) cost, `t` is the number of segments used.
// Adding two states adds both fields; adding a plain integer shifts only the cost.
// States order lexicographically by (a, t).
struct F {
    int a, t;

    F operator + (const F &other) const {
        return F{a + other.a, t + other.t};
    }

    F operator + (const int &delta) const {
        return F{a + delta, t};
    }

    bool operator < (const F &other) const {
        return a != other.a ? a < other.a : t < other.t;
    }
} f[N];   // DP table, filled by calc()
// ---- Input ----
int n, m;                     // n: count of each letter (string length 2n); m: max number of segments
char s[N];                    // input string, 1-indexed: s[1..2n]

// ---- Shared loop counters / scalars (j, k, T, pos, ans are not read in the visible code) ----
int i, j, k, T;
int l, r, pos, ans, ja, jb;   // l, r: WQS binary-search bounds; ja, jb: number of 'A' / 'B' seen
int X1, X2, X3, Y1, Y2, Y3;   // scratch coordinates for a hull cross-product test

// ---- Precomputation (filled in main) ----
int a[N], b[N];               // a[t] / b[t]: position of the t-th 'A' / 'B'
int sa[N], sb[N], sum[N];     // prefix counts of 'A' / 'B'; sum: prefix sums of ida
int ida[N], idb[N], h[N];     // ida[t]: #'B' before the t-th 'A'; idb[t]: #'A' before the t-th 'B' (at least t-1);
                              // h[t] = idb[t]*(t-1) - sum[idb[t]]

// ---- Monotone convex-hull deque ----
int q[N];                     // q[t]: candidate segment start stored in hull slot t
int X[N], Y[N];               // hull point for slot t
int calc(int x) { //一段的代价为x
memset(f, 0x3f, sizeof(f));
auto suan = [&] (int j) -> int {
return Y[j] - X[j] * i;
};
int l, r;
l = r = 0;
f[0]={0, 0}; q[0]=1;
X[0]=1; Y[0]=h[1];
for(i=1; i<=n; ++i) {
while(r-l+1 >= 2 && suan(l) > suan(l+1)) ++l;
f[i].a = suan(l); f[i].t = f[q[l]-1].t;
f[i] = f[i] + sum[i] + i + x; f[i].t++;
while(r-l+1 >= 2) {
X1 = X[r-1]; Y1 = Y[r-1];
X2 = X[r]; Y2 = Y[r];
X3 = (i+1); Y3 = f[i].a + h[i+1];
if((Y2-Y1)*(X3-X2) > (Y3-Y2)*(X2-X1)) --r;
else break;
}
q[++r] = i+1; X[r] = (i+1); Y[r] = f[i].a + h[i+1];
}
return f[n].t;
}
// Entry point.
// 1. Reads n, m and the 2n-character 'A'/'B' string.
// 2. Precomputes the prefix data and h[] consumed by calc().
// 3. Binary-searches the per-segment penalty (WQS) for the smallest value whose
//    chosen solution uses at most m segments.
// 4. Prints the penalised cost with the penalty for m segments removed.
signed main()
{
    // Redirect to easy.in / easy.out only when easy.in exists. An unconditional
    // freopen on a missing file leaves stdin closed, and every later read is undefined.
    if (FILE *probe = fopen("easy.in", "r")) {
        fclose(probe);
        freopen("easy.in", "r", stdin);
        freopen("easy.out", "w", stdout);
    }

    n = read(); m = read(); ans = 1e18;
    scanf("%s", s + 1);

    // Record the position of each letter and mark it for prefix counting.
    for (i = 1; i <= 2 * n; ++i) {
        if (s[i] == 'A') a[++ja] = i, sa[i] = 1;
        else b[++jb] = i, sb[i] = 1;
    }
    partial_sum(sa + 1, sa + 2 * n + 1, sa + 1);   // sa[p] = #'A' in s[1..p]
    partial_sum(sb + 1, sb + 2 * n + 1, sb + 1);   // sb[p] = #'B' in s[1..p]

    for (i = 1; i <= n; ++i) ida[i] = sb[a[i]];    // number of 'B's before the i-th 'A'
    for (i = 1; i <= n; ++i) idb[i] = sa[b[i]];    // number of 'A's before the i-th 'B'
    for (i = 1; i <= n; ++i) idb[i] = max(idb[i], i - 1);
    partial_sum(ida + 1, ida + n + 1, sum + 1);
    for (i = 1; i <= n; ++i) h[i] = idb[i] * (i - 1) - sum[idb[i]];

    // WQS search over the per-segment penalty: smallest l with calc(l) <= m.
    l = 0; r = 1e12;
    while (l < r) {
        int mid = (l + r) >> 1;
        if (calc(mid) <= m) r = mid;
        else l = mid + 1;
    }
    calc(l);
    printf("%lld", f[n].a - m * l);
    return 0;
}