2024牛客寒假算法基础集训营2

C Tokitsukaze and Min-Max XOR

题目大意

  • 给定一个数组a
  • a任取数构成序列b
  • 序列b满足b_i=min\ b,b_j=max\ b,b_i\ xor\ b_j \leq k,(可以只取一个数)
  • 问能构造出多少个b
  • mod=1e9+7

解题思路

  • maxmin
  • 双枚举时间复杂度到O(n^2),考虑利用01Trie加速统计b_i \ xor\ b_j< k的方案
  • 01Trie,即将数字按二进制位拆分挂在树上
  • 对于一个数,它在树上经过的点,均加上它对答案的贡献
  • 所以树上的某一点存的信息为,以这个点的数位为分界,在它之前(包括它)均为某固定值,而在它后均为任意值的数对答案的贡献
  • 若当前值为a_j,令其为max,则小于a_j的有j-1个,不考虑限制的话,共有2^{j-1}种选取min的方式
  • 若选取了a_i作为min,则1\rightarrow i均确定了是否选取
  • 所以min=a_i,max=a_j,方案数为2^{j-1}/2^i
  • 若在01Trie上,当前为第t位,则前t-1位均与a_j\ xor\ k对应数位相同
  • k的第t位为1,且01Trie存在着与a_j的第t位相同的点
  • a_j与该点往下的所有值异或后,第t位均为0
  • 答案加上该点值,即ans+=\sum_{i} Val(a_i),a_i\ xor\ a_j< k
  • 不必在从该点往下走挨个统计,实现了加速
  • 若能在01Trie上走到最后,则存在a_i\ xor\ a_j=k,答案加上该点值
  • 最后处理完了a_j,在将其加入01Trie
  • 由于b数组长度可以为1,所以初始ans=n
  • 最终复杂度为O(nlogn)
  • 注意Trie上为31\rightarrow 0

import java.io.*;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Random;
import java.util.StringTokenizer;
import java.util.Vector;








public class Main{
	static long md=(long)1e9+7L;
	
	static long qm(long a,long b) {
		long res=1;
		while(b>0) {
			if((b&1)==1)res=(res*a)%md;
			a=a*a%md;
			b>>=1;
		}
		return res;
	}
	
	static
	class Trie{
		long[] cnt;//一个数分成32位,共32层,最后一层最多为n
    	int[][] tr;//01trie
    	int t;
    	public Trie(int n) {
    		cnt=new long[n*32];
    		tr=new int[n*32][2];
    		t=0;
    	}
    	void insert(int x,long y) {
    		int p=0;
    		for (int i = 31; i >= 0; i--) {
    			int j=(x>>i)&1;
    			if(tr[p][j]==0) {
    				t++;
    				tr[p][j]=t;
    			}
    			p=tr[p][j];
    			cnt[p]=(cnt[p]+y)%md;//在沿途放上,便于统计小于k的情况
    			
    		}
    	}
    	long query(int x,int k) {
    		long res=0;
    		int p=0;
    		for (int i = 31; i >= 0; i--) {
    			int xt=(x>>i)&1;
    			int kt=(k>>i)&1;
    			if(kt==1&&tr[p][xt]!=0) {
    				res=(res+cnt[tr[p][xt]])%md;
    				//统计i^j0) {
	    	int n=input.nextInt();
	    	int k=input.nextInt();
	    	int[] a=new int[n+1];
	    	for(int i=1;i<=n;++i)a[i]=input.nextInt();
	    	Arrays.sort(a,1,n+1);
	    	
	    	
	    	Trie Tp=new Trie(n+1);
	    	long ans=n;//min=i,max=j,i=j的情况
	    	for(int i=1;i<=n;++i) {
	    		//此时trie上的数均小于a[i]
	    		int x=a[i];
	    		ans=(ans+Tp.query(x, k)*qm(2, i-1)%md)%md;
	    		long inv=qm(qm(2L, i), md-2);
	    		Tp.insert(x, inv);
	    	}
	    	out.println(ans);
	    	T--;
	    }
 	    out.flush();
	    out.close();
	}
	static
	class AReader{
	    BufferedReader bf;
	    StringTokenizer st;
	    BufferedWriter bw;

	    public AReader(){
	        bf=new BufferedReader(new InputStreamReader(System.in));
	        st=new StringTokenizer("");
	        bw=new BufferedWriter(new OutputStreamWriter(System.out));
	    }
	    public String nextLine() throws IOException{
	        return bf.readLine();
	    }
	    public String next() throws IOException{
	        while(!st.hasMoreTokens()){
	            st=new StringTokenizer(bf.readLine());
	        }
	        return st.nextToken();
	    }
	    public char nextChar() throws IOException{
	        //确定下一个token只有一个字符的时候再用
	        return next().charAt(0);
	    }
	    public int nextInt() throws IOException{
	        return Integer.parseInt(next());
	    }
	    public long nextLong() throws IOException{
	        return Long.parseLong(next());
	    }
	    public double nextDouble() throws IOException{
	        return Double.parseDouble(next());
	    }
	    public float nextFloat() throws IOException{
	        return Float.parseFloat(next());
	    }
	    public byte nextByte() throws IOException{
	        return Byte.parseByte(next());
	    }
	    public short nextShort() throws IOException{
	        return Short.parseShort(next());
	    }
	    public BigInteger nextBigInteger() throws IOException{
	        return new BigInteger(next());
	    }
	    public void println() throws IOException {
	        bw.newLine();
	    }
	    public void println(int[] arr) throws IOException{
	        for (int value : arr) {
	            bw.write(value + " ");
	        }
	        println();
	    }
	    public void println(int l, int r, int[] arr) throws IOException{
	        for (int i = l; i <= r; i ++) {
	            bw.write(arr[i] + " ");
	        }
	        println();
	    }
	    public void println(int a) throws IOException{
	        bw.write(String.valueOf(a));
	        bw.newLine();
	    }
	    public void print(int a) throws IOException{
	        bw.write(String.valueOf(a));
	    }
	    public void println(String a) throws IOException{
	        bw.write(a);
	        bw.newLine();
	    }
	    public void print(String a) throws IOException{
	        bw.write(a);
	    }
	    public void println(long a) throws IOException{
	        bw.write(String.valueOf(a));
	        bw.newLine();
	    }
	    public void print(long a) throws IOException{
	        bw.write(String.valueOf(a));
	    }
	    public void println(double a) throws IOException{
	        bw.write(String.valueOf(a));
	        bw.newLine();
	    }
	    public void print(double a) throws IOException{
	        bw.write(String.valueOf(a));
	    }
	    public void print(char a) throws IOException{
	        bw.write(String.valueOf(a));
	    }
	    public void println(char a) throws IOException{
	        bw.write(String.valueOf(a));
	        bw.newLine();
	    }
	}
}

你可能感兴趣的:(算法)