http://acm.hdu.edu.cn/showproblem.php?pid=3333
题意:
给你一个有N个数的序列,并且给你一些区间[ L , R ],要你查询给定区间内不同整数的和(相同的整数只需要算一次)。
思路:
想了好久都没有比较好的办法来处理这个题目,因为线段树的左右两个孩子都是相互独立的,因此无法处理父结点要包含两个孩子结点信息的情况(也就是父结点必须要包含左右孩子都有什么数的信息)。这样我们就不能直接用常规的线段树。到网上搜了解题报告才不禁感叹本题的神奇。本题的正确解法应该是:离线+离散化+线段树。 具体应该是这样的, 首先我们将所有的整数的值进行离散化,然后读入所有的query ,并且按照每个query的结束点进行排序,然后就是依次枚举原序列,当枚举到第i和元素的时候,如果第i个元素在前面的某个位置出现过(这个很容易实现, 我们只需要用一个pre数组记录即可),那个就把前面那个位置的值删除(删除就是在那个位置加上一个-val),然后在第i个位置将val插入。 这样做是正确的,这是因为对于结束点在第i个点之后的query ,一定是不在需要它之前的那个值了, 因为第i个位置又要插入了,而为什么要不插入第i位置的值,却要删除它前面出现的那个位置的值,这是因为如果不插入,就又可能导致query的时候该值不包括在query的区间内,因此需要删除前面的,再插入后面的。这样当每次枚举的点i与query的结束点相同时,这时候该query就可以得出解了, 我们只需要询问线段树中[ query.l , query.r ]内的和就可以了。
代码:
#include<stdio.h> #include<string.h> #include<algorithm> typedef __int64 LL ; const int MAXN = 30010 ; int N , M ; LL val[MAXN] , v[MAXN] ; int hash[MAXN] ; int find(LL vv, int l , int r){ int mid ; while( l<r ){ mid = (l + r) >> 1 ; if( v[mid] < vv ) l = mid + 1 ; else r = mid ; } return l ; } struct Node{ int s , e , i ; LL ans ; }q[100010] ; bool comp1(Node n1 , Node n2){ return n1.e < n2.e ; } bool comp2(Node n1 , Node n2){ return n1.i < n2.i ; } int pre[MAXN] ; LL sum[MAXN<<2] ; void update(int l ,int r , int idx, int pos , LL v ){ //printf("LR:%d %d\n" ,l ,r); if(l == r) { sum[idx] += v; return ; } int mid = (l + r) >> 1 , ls = idx<<1 , rs = idx<<1|1 ; if( pos<=mid ) update(l , mid , ls , pos , v); else update(mid+1, r, rs , pos , v) ; sum[idx] = sum[ls] + sum[rs] ; //printf("LRS:%d %d %d\n",l ,r ,sum[idx] ); } LL query(int l ,int r, int idx, int a, int b){ if(l==a && r==b) return sum[idx] ; int mid = (l + r) >> 1 , ls = idx<<1 ,rs = idx<<1|1 ; if( b<=mid ) return query(l , mid , ls , a ,b ); else if( mid<a ) return query(mid+1, r , rs ,a ,b ) ; else{ return query(l , mid , ls , a , mid) + query( mid+1, r , rs, mid+1, b) ; } } void solve( int nn ){ scanf("%d",&M); for(int i=1;i<=M;i++){ scanf("%d %d",&q[i].s , &q[i].e ); q[i].i = i ; } std::sort(q+1,q+1+M,comp1); memset( pre, -1 , sizeof(pre) ); memset( sum , 0 , sizeof(sum) ); int pos = 1 ; for(int i=1;i<=N;i++){ if( pre[ hash[i] ] != -1 ){ //元素已经存在 update(1 , N , 1 ,pre[ hash[i] ] , -val[i] ); } pre[ hash[i] ] = i ; update(1 , N , 1 , i , val[i] ) ; while( pos<=M && q[pos].e==i ){ q[pos].ans = query(1, N ,1 , q[pos].s , q[pos].e ); pos++ ; } } std::sort(q+1,q+1+M, comp2); for(int i=1;i<=M;i++){ printf("%I64d\n",q[i].ans); } } int main(){ int T ; scanf("%d",&T) ; for(int cas=1;cas<=T;cas++){ scanf("%d",&N); for(int i=1;i<=N;i++){ scanf("%I64d",&val[i]); v[i] = val[i] ; } std::sort(v+1, v+1+N) ; int n = 2 ; for(int i=2;i<=N;i++){ if( v[i] != v[i-1] ) v[n++] = v[i] ; } n-- ; for(int i=1;i<=N;i++){ hash[i] = find( val[i] , 1 , n ) ; } solve( n ); } return 0 ; }