HDU 4913 (思路题目)

题目意思是给定序列长度同为n的数列a,b;

要求出所有的子集元素的 2 ^ (max(a) ) * 3 ^ (max(b))的和。

分析:

首先对只有一个数组a,我们的通常思路是先排序 , 递推算法求解。 ans[ i ]  = ans[ i-1 ] + 2^(i-1) * 2^a[ i ];

那么多了一个b数组,可以先按b排序 。

记走到 i , 前面 a的值比ai小的有x个, 比 ai 大的 位置为 p1 , p2 , ... pk.

那么 走到i,只需统计 i 必选的和。

ans[ i ] = ans[ i-1 ] + (2^x * 2^ai  +  2^x*2^ap1 + 2^(x+1)*2^ap2 + .... + 2^(x+k-1)*2^apk) * 3^bi;

这样先给每个a[ i ] 一个rank(ai越小rank越小) ,加入线段树,维护2^(x)*2^ai, 其中x代表走到i时,比ai小的aj 有几个。

#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <map>
#define ls rt<<1
#define rs rt<<1|1
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
using namespace std;
typedef long long ll;
#define rep(i,n) for(int i=0;i<(int) n; i++)
#define rep1(i,x,y) for(int i=x;i<=(int)y; i++)

const int maxn = 1e5 + 100;
const ll mod = 1e9 + 7;
struct node{
   ll sum, cnt;
}a[maxn<<2];
int n; ll col[maxn<<2];
ll pow_(int nn,int b){
   ll ans = 1 , te = nn;
   while(b){
       if(b&1) ans=ans*te%mod;
       b>>=1;
       te=te*te%mod;
   }
   return ans;
}
void push_up(int rt){
  a[rt].cnt = a[ls].cnt+a[rs].cnt;
  a[rt].sum = (a[ls].sum+a[rs].sum)%mod;
}
void push_down(int rt){
   if(col[rt] > 1){
      col[ls] = (col[ls]*col[rt])%mod;
      col[rs] = (col[rs]*col[rt])%mod;
      a[ls].sum = a[ls].sum*col[rt]%mod;
      a[rs].sum = a[rs].sum*col[rt]%mod;
      col[rt] = 1;
   }
}
void build(int l,int r,int rt){
    a[rt].sum = a[rt].cnt = 0; col[rt] = 1;
    if(l == r) return ;
    int m=(l+r)>>1;
    build(lson);
    build(rson);
}
void mul(int l,int r,int rt,int L,int R){
   if(L<=l&&r<=R){
      a[rt].sum = a[rt].sum*2%mod;
      col[rt]=col[rt]*2%mod;
      return ;
   }
   push_down(rt);
   int m=(l+r)>>1;
   if(L<=m) mul(lson,L,R);
   if(R> m) mul(rson,L,R);
   push_up(rt);
}
int query_cnt(int l,int r,int rt,int L,int R){
   if(L<=l&&r<=R)
      return a[rt].cnt;
   push_down(rt);
   int m=(l+r)>>1 , all = 0;
   if(L<=m) all += query_cnt(lson,L,R);
   if(R> m) all += query_cnt(rson,L,R);
   return all;
}
ll query_sum(int l,int r,int rt,int L,int R){
   if(L<=l&&r<=R)
      return a[rt].sum;
   push_down(rt);
   int m=(l+r)>>1 ; ll all = 0;
   if(L<=m) all += query_sum(lson,L,R);
   if(R> m) all = (all+query_sum(rson,L,R))%mod;
   return all;
}
void update(int l,int r,int rt,int p,int aa){
   if(l==r){
      a[rt].cnt = 1;
      ll x = (l==1 ? 0 : query_cnt(1,n,1,1,l-1));
      a[rt].sum = pow_(2,x+aa);
      return ;
   }
   push_down(rt);
   int m=(l+r)>>1;
   if(p <= m) update(lson,p,aa);
   if(p >  m) update(rson,p,aa);
   push_up(rt);
}
struct node2{
   int a,b,id;
}st[maxn];
int cmp1(node2 A, node2 B){ return  A.a < B.a;}
map<int,int> Rank;
int cmp2(node2 A, node2 B){ return A.b < B.b; }
int main()
{
   while(scanf("%d",&n)==1){
      build(1,n,1);
      rep1(i,1,n) scanf("%d %d",&st[i].a,&st[i].b),st[i].id = i;
      sort(st+1,st+1+n,cmp1);
      Rank.clear();
      rep1(i,1,n) Rank[st[i].id] = i ;
      sort(st+1,st+1+n,cmp2);
      ll ans = 0;
      rep1(i,1,n){
         update(1,n,1,Rank[st[i].id],st[i].a);
         ans = (ans + query_sum(1,n,1,Rank[st[i].id],n)*pow_(3,st[i].b)%mod)%mod;
         if(Rank[st[i].id]+1 <= n) mul(1,n,1,Rank[st[i].id]+1,n);
      }
      printf("%I64d\n",ans);
   }
   return 0;
}


你可能感兴趣的:(HDU 4913 (思路题目))