Ayoub thinks that he is a very smart person, so he created a function f(s)
, where s is a binary string (a string which contains only symbols “0” and “1”). The function f(s) is equal to the number of substrings in the string s
that contains at least one symbol, that is equal to “1”.
More formally, f(s)
is equal to the number of pairs of integers (l,r), such that 1≤l≤r≤|s| (where |s| is equal to the length of string s), such that at least one of the symbols sl,sl+1,…,sr
is equal to “1”.
For example, if s=
“01010” then f(s)=12, because there are 12 such pairs (l,r): (1,2),(1,3),(1,4),(1,5),(2,2),(2,3),(2,4),(2,5),(3,4),(3,5),(4,4),(4,5)
Ayoub also thinks that he is smarter than Mahmoud so he gave him two integers n
and m and asked him this problem. For all binary strings s of length n which contains exactly m symbols equal to “1”, find the maximum value of f(s)
.
Mahmoud couldn’t solve the problem so he asked you for help. Can you help him?
Input
The input consists of multiple test cases. The first line contains a single integer t
(1≤t≤105
) — the number of test cases. The description of the test cases follows.
The only line for each test case contains two integers n
, m (1≤n≤109, 0≤m≤n
) — the length of the string and the number of symbols equal to “1” in it.
Output
For every test case print one integer number — the maximum value of f(s)
over all strings s of length n, which has exactly m symbols, equal to “1”.
题意:给你n个数字(只包含0,1),然后告诉有m个 1 ;问最多有多少个子串含有1;
这题的思维还是比较难想的,可以换个角度思考问题,不要只想 1 的位置该怎么摆放,可以思考一下 0 的位置该怎么放才能使整个字符串只含 0 的子串最少,这里想到的就是所有 0 尽量均分成(m+1)块,这样只含 0 的子串会最少;
举个例子:8个 0 分成2份,如果 4 4分,子串是20个;如果3 5分,子串是21个;
如果不能均分,也要尽量均分,比如 18 分成4份,5 5 4 4 就比 6 4 4 4好;
代码:
#include
#define ll long long
#define pa pair
#define lson k<<1
#define rson k<<1|1
//ios::sync_with_stdio(false);
using namespace std;
const int N=100100;
const int M=200100;
const ll mod=998244353;
ll solve(ll p){
return p*(p+1)/2;
}
int main(){
ios::sync_with_stdio(false);
int t;
cin>>t;
while(t--){
ll n,m;
cin>>n>>m;
if((n-m)%(m+1)==0){
ll ans=solve(n)-(m+1)*solve((n-m)/(m+1));
cout<<ans<<endl;
}
else{
ll a=(n-m)%(m+1);
ll b=m+1-a;
ll c=(n-m)/(m+1);
ll ans=solve(n)-a*solve(c+1)-b*solve(c);
cout<<ans<<endl;
}
}
return 0;
}