UVALive 3675 Sorted bit sequence
将区间 [L,R] 内的所有整数按照其二进制表示中1 的数量从小到大排序。如果 1 的数量 相同,则按照数的大小排序。求这个序列中的第 K 个数。其中,负数使用补码来表示:一个负数的二进制表示与其相反数的二进制之和恰好等于 232 。
数据规模: L×R≥0,−231≤L≤R≤231−1,1≤K≤min(R−L+1,2147473547) 。
参考论文中的分析方法。
首先注意到一个条件 L∗R≥0 。
我们先考虑 m和n 同为正数的情况。
由于排序的第一关键字是1的数量,第二关键字是数的大小,因此我们很容易确定答案中1的个数:依次统计区间[m,n]内二进制表示中含 1的数量为 0,1,2,…的数,直到累加的答案超过 K ,则当前值就是答案含 1 的个数,假设是 s 。这个数位dp可以解决,枚举需要统计的1的个数,dfs
。
同时,我们也求出了答案是第几个[m,n]中含 s个 1 的数。因此,只需二分答案,求出 [L,ans] 中含 s 个 1 的数的个数进行判断即可。
对于 L<0,R<0 的情况,我是把区间变为正的区间,根据一个负数的二进制表示与其相反数的二进制之和恰好等于 232 ,最后再把结果变回来就好了。
需要特殊处理 L=0或者R=0 的情况。
因为 L∗R≥0 了,所以实际上最麻烦的讨论: L<0,R>0 的情况已经避免掉了。
//https://icpcarchive.ecs.baylor.edu/index.php?option=com_onlinejudge&Itemid=8&category=245&page=show_problem&problem=1676
#include
#include
#include
#include
using namespace std;
typedef long long ll;
const ll base = 1ll << 32;
int T, digit[35];
ll L, R, K, LL, RR;
ll dp[35][35][35], cnt[4][35];
ll dfs(int pos, int pre, int limit, int sum)
{
if (pos == -1) return pre == sum;
if (pre > sum) return 0;
if (!limit && dp[pos][pre][sum] != -1) return dp[pos][pre][sum];
int last = limit ? digit[pos] : 1;
ll ret = 0;
for (int i = 0; i <= last; ++i) {
ret += dfs(pos - 1, pre + i, limit && (i == last), sum);
}
if (!limit) dp[pos][pre][sum] = ret;
return ret;
}
ll solve(ll x, int id, int flag)
{
memset(digit, 0, sizeof (digit));
int len = 0;
// printf("x = %lld\n", x);
while (x) {
digit[len++] = x % 2;
x /= 2;
}
if (flag != -1) {
return dfs(len - 1, 0, 1, flag);
}
memset(cnt[id], 0, sizeof(cnt[id]));
for (int i = 1; i <= 32; ++i) { // 别忘了32!
cnt[id][i] = dfs(len - 1, 0, 1, i); // 枚举有i个1
// printf("cnt[%d][%d] = %lld\n", id, i, cnt[id][i]);
}
return 0;
}
void work(int flag)
{
solve(LL, 0, -1);
solve(RR, 1, -1);
ll prefix = 0;
int goal;
for (int i = 1; i <= 32; ++i) { // 别忘了32!
prefix += (cnt[1][i] - cnt[0][i]);
if (prefix >= K) {
prefix -= (cnt[1][i] - cnt[0][i]);
goal = i;
break;
}
}
// printf("goal = %d\n", goal);
ll left = K - prefix + cnt[0][goal];
ll high = RR, low = LL, mid;
while (low < high) {
mid = (1ll * low + high) / 2;
ll tmp = solve(mid, -1, goal);
if (tmp < left) low = mid + 1;
else high = mid;
}
ll ans = high;
if (flag) ans = -(base - high);
printf("%lld\n", ans);
}
int main()
{
memset(dp, -1, sizeof(dp));
scanf("%d", &T);
while (T--) { // 注意n * m >= 0
scanf("%lld%lld%d", &L, &R, &K);
if (L < 0 && R < 0) {
LL = base + L - 1, RR = base + R;
work(1);
} else if (L < 0 && R == 0) {
if (K == 1) printf("0\n");
else {
K--;
LL = base + L - 1, RR = base - 1;
work(1);
}
} else if (L == 0 && R == 0) { // K = 0
printf("0\n");
} else if (L == 0 && R > 0) {
if (K == 1) printf("0\n");
else {
K--;
LL = 0, RR = R;
work(0);
}
} else { // L > 0 && R > 0
LL = L - 1, RR = R;
work(0);
}
}
return 0;
}