对于一个长度为 n n n 的 01
字符串 S S S,请你求出将其分为至少 k k k 段,将每一段看为二进制数求和后的最大值以及取到这个最大值的划分方案的数量。
n ≤ 2 × 1 0 6 n\le2\times10^6 n≤2×106
字符串编号从 1 1 1 开始。
若 n = k n=k n=k,最大值很好求,方案数就是 1 1 1。
若前 k k k 个都没有 1 1 1,设第一个 1 1 1 出现的位置为 m m m,最大值是选 m m m 开始的后缀,划分方案是在前 m m m 个空位插大于等于 k − 1 k-1 k−1 个隔板,用组合数可以轻松求出。
否则,最大值一定是选长度为 n − k + 1 n-k+1 n−k+1 的子串 a a a,剩下每个数单独为一个段。
现在考虑怎样选这样的子串使答案最大。显然选二进制最大的。
证明:设二进制最大的子串为 a a a,另一个不是最大的子串为 b b b,设 1 1 1 的总数为 x x x。由于 1 1 1 的总数始终不变,即证明: x − popcount(a) + a ≥ x − popcount ( b ) + b x-\operatorname{popcount(a)}+a\ge x-\operatorname{popcount}(b)+b x−popcount(a)+a≥x−popcount(b)+b。化简得 a − popcount(a) ≥ b − popcount ( b ) a-\operatorname{popcount(a)}\ge b-\operatorname{popcount}(b) a−popcount(a)≥b−popcount(b)
设 f ( n ) = n − popcount ( n ) f(n)=n-\operatorname{popcount}(n) f(n)=n−popcount(n),即证它是不减函数。对其差分得 popcount ( n ) − popcount ( n + 1 ) + 1 \operatorname{popcount}(n)-\operatorname{popcount}(n+1)+1 popcount(n)−popcount(n+1)+1,显然若 n + 1 n+1 n+1, popcount ( n ) \operatorname{popcount}(n) popcount(n) 要么增加 1 1 1(此时 n n n 为偶数时取等),要么减少,得证。
由上面结论和取等条件,我们要做的是:先找最大的子串,然后统计前 n − k n-k n−k 个相同的子串个数,即为方案数。
问题转换为求长度为 n − k + 1 n-k+1 n−k+1 的字典序最大的子串。
这个问题可以用二分+哈希解决。
我们有 k k k 个子串待比较,令第一个子串是当前的最大子串,后面更新。
对于当前最大子串 s s s 和左端点为 i i i 的子串 t t t,若 t t t 可以更新 s s s,即 t > s t>s t>s,则 s s s 和 t t t 会有一段公共前缀(可能没有),接下来一个数字 t t t 为 1 1 1, s s s 为 0 0 0,我们想要快速求出公共前缀的长度,这里就二分长度 m i d mid mid,如果二者长度为 m i d mid mid 的子串不相等,就把 m i d mid mid 变小,否则变大;快速判断子串是否相等可以用哈希。
时间复杂度为 O ( k log ( n − k ) ) O(k\log(n-k)) O(klog(n−k))。
具体实现看代码
#include
using namespace std;
typedef long long ll;
const ll mod=998244353,mod1=1e9+7;
const int N=2e6+1;
int n,k,sum[N];
char a[N];
ll f[N],inv[N],pw1[N],pw2[N],a1[N],a2[N];
ll ksm(ll a,ll b)
{
ll ans=1;
while(b){
if(b&1) ans=ans*a%mod;
b>>=1;
a=a*a%mod;
}
return ans;
}
struct node
{
int fl,son[2];
}tr[2000001];
pair<ll,ll> geth(int l,int r)
{
return make_pair((a1[r]-a1[l-1]*pw1[r-l+1]%mod+mod)%mod,(a2[r]-a2[l-1]*pw2[r-l+1]%mod1+mod1)%mod1);
}
int main()
{
freopen("divide.in","r",stdin);
freopen("divide.out","w",stdout);
scanf("%d%d%s",&n,&k,a+1);
pw1[0]=pw2[0]=1;
for(int i=1;i<=n;i++) pw1[i]=pw1[i-1]*2%mod,pw2[i]=pw2[i-1]*2%mod1;
for(int i=1;i<=n;i++) a1[i]=(a1[i-1]*2+a[i]-48)%mod,a2[i]=(a2[i-1]*2+a[i]-48)%mod1,sum[i]=sum[i-1]+a[i]-48;
int fl=0;
for(int i=1;i<=k;i++) if(a[i]==49){fl=1;break;}
if(n==k){
int x=0;
for(int i=1;i<=n;i++) x+=a[i]-48;
printf("%d 1",x);
}
else if(!fl){
int m=0;
while(m<n&&a[m+1]=='0') m++;
if(m==n) m--;
f[0]=1;
for(int i=1;i<=n;i++) f[i]=f[i-1]*i%mod;
inv[n]=ksm(f[n],mod-2);
for(int i=n-1;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
ll x=0,ans=0;
for(int i=m+1;i<=n;i++) x=(x*2+a[i]-48)%mod;
for(int i=k-1;i<=m;i++) ans=(ans+f[m]*inv[i]%mod*inv[m-i])%mod;
printf("%lld %lld",x,ans);
}
else{
int maxi=1,ans=0;
for(int i=2;i<=k;i++){
int l=0,r=n-k+1,ans=l;
while(l<=r){
int mid=l+r>>1;
if(geth(i,i+mid-1)!=geth(maxi,maxi+mid-1)) r=mid-1;
else l=mid+1,ans=mid;
}
if(a[maxi+ans]<a[i+ans]) maxi=i;
}
for(int i=1;i<=k;i++) if(geth(i,i+n-k-1)==geth(maxi,maxi+n-k-1)) ans++;
printf("%lld %d",(geth(maxi,maxi+n-k).first+sum[maxi-1]+sum[n]-sum[maxi+n-k])%mod,ans);
}
}