传送门
给定一个长度为 n n n 的非负整数数组 a a a 和一个整数 k k k
求出 a a a 中有多少个非空子序列: a i , a i + 1 , . . . a m a_i,a_{i + 1},...a_m ai,ai+1,...am 满足:
答案模除 1 e 9 + 7 1e9 + 7 1e9+7
数据范围: 1 ≤ n ≤ 2 ⋅ 1 0 5 , 0 ≤ k ≤ 1 0 9 1 \leq n \leq 2\cdot 10^5, 0 \leq k \leq 10^9 1≤n≤2⋅105,0≤k≤109
t a g s : tags: tags: T r i e Trie Trie 维护异或和小于 k k k 的信息
先将数组 a a a 排序,可以发现:如果我们确定了 m i n = a j min = a_j min=aj 和 m a x = a i max = a_i max=ai 值之后,中间的元素可以任意选,共有 2 i − j − 1 2^{i - j - 1} 2i−j−1 种可能性 ( i < j ) (i < j) (i<j), i = j i = j i=j 时异或和是 0 0 0,所以也符合题目要求。
那么我们可以枚举最大值 m a x = a i max = a_i max=ai,将其前面所有的值都插入到 T r i e Trie Trie 中,对于一个 a j a_j aj,它贡献的答案为: 2 i − j − 1 = 2 i − 1 2 j 2^{i - j - 1} = \dfrac{2^{i - 1}}{2 ^ j} 2i−j−1=2j2i−1,所以我们只需要维护 1 2 j \dfrac{1}{2^j} 2j1 这个信息就可以。对于当前的 a i a_i ai,它作为 m a x max max 的答案贡献为: 1 + 2 i − 1 × ∑ a j ⨁ a i ≤ k 1 2 j 1 + 2 ^ {i - 1} \times \sum_{a_j \bigoplus a_i \leq k} \dfrac{1} {2^j} 1+2i−1×∑aj⨁ai≤k2j1,前面的 1 1 1 是 i = j i = j i=j 的情况,后面是所有符合条件的 j j j
如何在 T r i e Trie Trie 上得到所有 a j ⨁ a i ≤ k a_j \bigoplus a_i \leq k aj⨁ai≤k 的那些 a j a_j aj 的信息?
我们可以每一步按照 a i ⨁ k a_i \bigoplus k ai⨁k 游走,也就是当前 a j ⨁ a i a_j \bigoplus a_i aj⨁ai 的前缀与 k k k 的前缀相等,如果可以一直游走到最底层的话,说明这个叶子节点是所有 a j ⨁ a i = k a_j \bigoplus a_i = k aj⨁ai=k 的信息集合。那么小于的情况,对于当前 k k k 的这一位是 1 1 1 的话,我们可以假设 a j ⨁ a i a_j \bigoplus a_i aj⨁ai 与 k k k 出现的第一个不一样的位就是这一位,前面的前缀异或都是与 k k k 一样的。那么显然 a j ⨁ a i a_j \bigoplus a_i aj⨁ai 的这一位必须得是 0 0 0,因为它们异或要小于 k k k,只要 a j a_j aj 与 a i a_i ai 这一位相等,它们这一位异或小于 k k k,那么对于后面的位, a j a_j aj 如何变化都不会影响它们异或小于 k k k 这一事实,因为前面已经出现了一个小于 k k k 的位。
所以对于 k k k 为 1 1 1 的那些位,我们看看能不能找到与 a i a_i ai 这一位相等的那些 a j a_j aj,统计它们的信息即可。注意在 T r i e Trie Trie 上游走的时候要判断一下下一个节点是否存在
时间复杂度: O ( n log a i ) O(n \log a_i) O(nlogai)
#include
#define fore(i,l,r) for(int i=(int)(l);i<(int)(r);++i)
#define fi first
#define se second
#define endl '\n'
#define ull unsigned long long
#define ALL(v) v.begin(), v.end()
#define Debug(x, ed) std::cerr << #x << " = " << x << ed;
const int INF=0x3f3f3f3f;
const long long INFLL=1e18;
typedef long long ll;
template<class T>
constexpr T power(T a, ll b){
T res = 1;
while(b){
if(b&1) res = res * a;
a = a * a;
b >>= 1;
}
return res;
}
constexpr ll mul(ll a,ll b,ll mod){ //快速乘,避免两个long long相乘取模溢出
ll res = a * b - ll(1.L * a * b / mod) * mod;
res %= mod;
if(res < 0) res += mod; //误差
return res;
}
template<ll P>
struct MLL{
ll x;
constexpr MLL() = default;
constexpr MLL(ll x) : x(norm(x % getMod())) {}
static ll Mod;
constexpr static ll getMod(){
if(P > 0) return P;
return Mod;
}
constexpr static void setMod(int _Mod){
Mod = _Mod;
}
constexpr ll norm(ll x) const{
if(x < 0){
x += getMod();
}
if(x >= getMod()){
x -= getMod();
}
return x;
}
constexpr ll val() const{
return x;
}
explicit constexpr operator ll() const{
return x; //将结构体显示转换为ll类型: ll res = static_cast(OBJ)
}
constexpr MLL operator -() const{ //负号,等价于加上Mod
MLL res;
res.x = norm(getMod() - x);
return res;
}
constexpr MLL inv() const{
assert(x != 0);
return power(*this, getMod() - 2); //用费马小定理求逆
}
constexpr MLL& operator *= (MLL rhs) & { //& 表示“this”指针不能指向一个临时对象或const对象
x = mul(x, rhs.x, getMod()); //该函数只能被一个左值调用
return *this;
}
constexpr MLL& operator += (MLL rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MLL& operator -= (MLL rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MLL& operator /= (MLL rhs) & {
return *this *= rhs.inv();
}
friend constexpr MLL operator * (MLL lhs, MLL rhs){
MLL res = lhs;
res *= rhs;
return res;
}
friend constexpr MLL operator + (MLL lhs, MLL rhs){
MLL res = lhs;
res += rhs;
return res;
}
friend constexpr MLL operator - (MLL lhs, MLL rhs){
MLL res = lhs;
res -= rhs;
return res;
}
friend constexpr MLL operator / (MLL lhs, MLL rhs){
MLL res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream& operator >> (std::istream& is, MLL& a){
ll v;
is >> v;
a = MLL(v);
return is;
}
friend constexpr std::ostream& operator << (std::ostream& os, MLL& a){
return os << a.val();
}
friend constexpr bool operator == (MLL lhs, MLL rhs){
return lhs.val() == rhs.val();
}
friend constexpr bool operator != (MLL lhs, MLL rhs){
return lhs.val() != rhs.val();
}
};
const ll mod = 1e9 + 7;
using Z = MLL<mod>;
const int N = 200050;
int cnt;
struct node{
int son[2];
Z val;
}tree[N * 30];
void init(){
fore(i, 0, cnt + 1){
tree[i].son[0] = tree[i].son[1] = 0;
tree[i].val = 0;
}
cnt = 0;
}
void insert(int x, Z w){
int now = 0;
for(int i = 30; i >= 0; --i){
int nxt = x >> i & 1;
if(!tree[now].son[nxt])
tree[now].son[nxt] = ++cnt;
now = tree[now].son[nxt];
tree[now].val += w;
}
}
Z query(int x, int k){
Z res = 0;
int now = 0;
for(int i = 30; i >= 0; --i){ //now往下走并且前缀异或一定等于k
int nxt = x >> i & 1; //x当前的位
if(k >> i & 1){ //如果k这一位是1,我们假设这一位是mx和mn第一个不同的高位,也就是异或结果为0
res += tree[tree[now].son[nxt]].val; //那么要加上mx和mn这一位相等的贡献
now = tree[now].son[nxt ^ 1]; //为了保持前缀异或等于k,我们要往k_i ^ x ^ 1转移
}
else now = tree[now].son[nxt]; //k_i = 0,0 ^ nxt = nxt
if(!now) break; //下面已经没有有用的节点了
}
res += tree[now].val; //异或等于k的情况,搜到了叶子节点
return res;
}
int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int t;
std::cin >> t;
while(t--){
init();
int n, k;
std::cin >> n >> k;
std::vector<int> a(n + 1);
fore(i, 1, n + 1) std::cin >> a[i];
std::sort(a.begin() + 1, a.end());
Z ans = 0;
fore(i, 1, n + 1){
ans += 1 + query(a[i], k) * power(Z(2), i - 1); //1是只有它自己的情况
insert(a[i], power(Z(2), i).inv());
}
std::cout << ans << endl;
}
return 0;
}