统计的力量

Orz zkw!!!

最近看完了《统计的力量》……觉得这实在是太神了……原来线段树可以这么写……

zkw线段树的思想:先将线段长度N变为2的整数次方,使线段树成为满二叉树,然后就可以通过各种位运算直接链接到某个结点,不必递归了,因此大大减小了常数……

本沙茶利用zkw线段树在BZOJ1756和1798上都刷到了rank3……
代码:
<1>BZOJ1756:
#include  < iostream >
#include 
< stdio.h >
#include 
< stdlib.h >
#include 
< string .h >
using   namespace  std;
#define  re(i, n) for (int i=0; i<n; i++)
#define  re1(i, n) for (int i=1; i<=n; i++)
#define  re2(i, l, r) for (int i=l; i<r; i++)
#define  re3(i, l, r) for (int i=l; i<=r; i++)
#define  rre(i, n) for (int i=n-1; i>=0; i--)
#define  rre1(i, n) for (int i=n; i>0; i--)
#define  rre2(i, r, l) for (int i=r-1; i>=l; i--)
#define  rre3(i, r, l) for (int i=r; i>=l; i--)
#define  ll long long
const   int  MAXN  =  ( 1   <<   19 +   10 , INF  =   ~ 0U   >>   2 ;
struct  node {
    
int  sum, lv, rv, v;
} T[MAXN 
<<   1 ];
int  n, N, A[MAXN], Z[MAXN], res;
inline 
int  get_int()
{
    
int  x;  char  ch;  bool  FF;
    
while  ((ch  =  getchar())  <   48   &&  ch  !=   ' - ' ) ;
    
if  (ch  ==   ' - ' ) {FF  =   1 ; x  =   0 ;}  else  {FF  =   0 ; x  =  ch  -   48 ;}
    
while  ((ch  =  getchar())  >=   48 ) x  =  x  *   10   +  ch  -   48 ;
    
if  (FF) x  =   - x;  return  x;
}
void  prepare()
{
    N 
=  n  <<   1 ;
    re2(i, n, N) T[i].sum 
=  T[i].lv  =  T[i].rv  =  T[i].v  =  A[i  -  n];
    
for  ( int  i = 0 ; ( 1 << i) <= n; i ++ ) Z[ 1   <<  i]  =  i;
    
int  lch, rch, _;
    rre2(i, n, 
1 ) {
        lch 
=  i  <<   1 ; rch  =  lch  ^   1 ;
        T[i].sum 
=  T[lch].sum  +  T[rch].sum;
        T[i].lv 
=  (_  =  T[lch].sum  +  T[rch].lv)  >=  T[lch].lv  ?  _ : T[lch].lv;
        T[i].rv 
=  (_  =  T[rch].sum  +  T[lch].rv)  >=  T[rch].rv  ?  _ : T[rch].rv;
        T[i].v 
=  T[lch].v  >=  T[rch].v  ?  T[lch].v : T[rch].v;
        
if  ((_  =  T[lch].rv  +  T[rch].lv)  >=  T[i].v) T[i].v  =  _;
    }
}
void  opr0( int  pos,  int  x)
{
    
int  i  =  pos  +  n; T[i].sum  =  T[i].lv  =  T[i].rv  =  T[i].v  =  x;  int  _, __  =  x  -  A[pos], lch, rch; A[pos]  =  x;
    
for  (i >>= 1 ; i; i >>= 1 ) {
        lch 
=  i  <<   1 ; rch  =  lch  ^   1 ;
        T[i].sum 
+=  __;
        T[i].lv 
=  (_  =  T[lch].sum  +  T[rch].lv)  >=  T[lch].lv  ?  _ : T[lch].lv;
        T[i].rv 
=  (_  =  T[rch].sum  +  T[lch].rv)  >=  T[rch].rv  ?  _ : T[rch].rv;
        T[i].v 
=  T[lch].v  >=  T[rch].v  ?  T[lch].v : T[rch].v;
        
if  ((_  =  T[lch].rv  +  T[rch].lv)  >=  T[i].v) T[i].v  =  _;
    }
}
void  opr1( int  l,  int  r)
{
    
int  sum0  =   0 , l0, i, _; l  |=  n; r  |=  n; r ++ ; res  =   - INF;
    
for  (; l0 = l, (l += l &- l) <= r; ) {
        i 
=  l0  /  (l0  &   - l0);
        
if  (T[i].v  >  res) res  =  T[i].v;
        
if  ((_  =  sum0  +  T[i].lv)  >  res) res  =  _;
        sum0 
+=  T[i].sum;  if  (T[i].rv  >  sum0) sum0  =  T[i].rv;
    }
    
int  s  =  (l0  &   - l0)  >>   1 , z  =  Z[s];
    
for  (; l0 < r; s >>= 1 , z -- if  ((l  =  l0  +  s)  <=  r) {
        i 
=  l0  >>  z;
        
if  (T[i].v  >  res) res  =  T[i].v;
        
if  ((_  =  sum0  +  T[i].lv)  >  res) res  =  _;
        sum0 
+=  T[i].sum;  if  (T[i].rv  >  sum0) sum0  =  T[i].rv;
        l0 
=  l;
    }
}
int  main()
{
    
int  n0, m, x, y, z;
    n0 
=  get_int(); m  =  get_int(); re(i, n0) A[i]  =  get_int();  for  (n = 1 ; n < n0; n <<= 1 ) ;
    prepare();
    re(i, m) {
        x 
=  get_int(); y  =  get_int(); z  =  get_int();
        
if  (x  ==   1 ) { if  (y  >  z) {x  =  y; y  =  z; z  =  x;} opr1( -- y,  -- z); printf( " %d\n " , res);}  else  opr0( -- y, z);
    }
    
return   0 ;
}


<2>BZOJ1798:
#include  < iostream >
#include 
< stdio.h >
#include 
< stdlib.h >
#include 
< string .h >
using   namespace  std;
#define  re(i, n) for (int i=0; i<n; i++)
#define  re1(i, n) for (int i=1; i<=n; i++)
#define  re2(i, l, r) for (int i=l; i<r; i++)
#define  re3(i, l, r) for (int i=l; i<=r; i++)
#define  rre(i, n) for (int i=n-1; i>=0; i--)
#define  rre1(i, n) for (int i=n; i>0; i--)
#define  rre2(i, r, l) for (int i=r-1; i>=l; i--)
#define  rre3(i, r, l) for (int i=r; i>=l; i--)
#define  ll long long
const   int  MAXN  =  ( 1   <<   17 +   10 ;
struct  node {
    ll mr0, mr1, sum;
    
int  len;
} T[MAXN 
<<   1 ];
int  n, s, N, A[MAXN];
ll MOD, res;
inline 
int  get_int()
{
    
char  ch;  int  x;
    
while  ((ch  =  getchar())  <   48 ) ;
    x 
=  ch  -   48 while  ((ch  =  getchar())  >   47 ) x  =  x  *   10   +  ch  -   48 ;
    
return  x;
}
void  prepare()
{
    N 
=  n  <<   1 int  lch, rch;
    re2(i, n, N) {T[i].mr0 
=   1 ; T[i].sum  =  A[i  -  n]  %  MOD; T[i].len  =   0 ;}
    rre2(i, n, 
1 ) {
        lch 
=  i  <<   1 ; rch  =  lch  ^   1 ; T[i].len  =  T[lch].len  +   1 ;
        T[i].mr0 
=   1 ; T[i].sum  =  T[lch].sum  +  T[rch].sum;  if  (T[i].sum  >=  MOD) T[i].sum  -=  MOD;
    }
}
inline 
void  dm( int  i)
{
    
int  lch  =  i  <<   1 , rch  =  lch  ^   1 ; ll c0;
    
if  ((c0  =  T[i].mr0)  ^   1 ) {
        T[i].mr0 
=   1 ;
        T[lch].mr0 
=  T[lch].mr0  *  c0  %  MOD; T[lch].mr1  =  T[lch].mr1  *  c0  %  MOD; T[lch].sum  =  T[lch].sum  *  c0  %  MOD;
        T[rch].mr0 
=  T[rch].mr0  *  c0  %  MOD; T[rch].mr1  =  T[rch].mr1  *  c0  %  MOD; T[rch].sum  =  T[rch].sum  *  c0  %  MOD;
    }
    
if  (c0  =  T[i].mr1) {
        T[i].mr1 
=   0 ;
        T[lch].mr1 
+=  c0;  if  (T[lch].mr1  >=  MOD) T[lch].mr1  -=  MOD; T[lch].sum  =  (T[lch].sum  +  (c0  <<  T[lch].len))  %  MOD;
        T[rch].mr1 
+=  c0;  if  (T[rch].mr1  >=  MOD) T[rch].mr1  -=  MOD; T[rch].sum  =  (T[rch].sum  +  (c0  <<  T[rch].len))  %  MOD;
    }
}
void  opr0( int  l,  int  r, ll c)
{
    
int  k, l0  =  l  |  n, r0  =  r  |  n, i, j, lch, rch; ll c0;
    
for  (k = s - 1 ; k && (i = l0 >> k) == r0 >> k; k -- ) dm(i);
    
for  ( int  k0 = k; ((i = l0 >> k0) << k0) ^ l0; k0 -- ) dm(i);
    
for  ( int  k0 = k; (((i = r0 >> k0) << k0) | (( 1 << k0) - 1 )) ^ r0; k0 -- ) dm(i);
    r0
++ int  l1;
    
for  (; l1 = l0, (l0 += l0 &- l0) <= r0; ) {
        i 
=  l1  /  (l1  &   - l1);
        T[i].mr0 
=  T[i].mr0  *  c  %  MOD; T[i].mr1  =  T[i].mr1  *  c  %  MOD; T[i].sum  =  T[i].sum  *  c  %  MOD;
    }
    
int  _  =  (l1  &   - l1)  >>   1 , z  =  __builtin_ctz(_);
    
for  (; l1 < r0; _ >>= 1 , z -- if  (l1  +  _  <=  r0) {
        i 
=  l1  >>  z;
        T[i].mr0 
=  T[i].mr0  *  c  %  MOD; T[i].mr1  =  T[i].mr1  *  c  %  MOD; T[i].sum  =  T[i].sum  *  c  %  MOD;
        l1 
+=  _;
    }
    l0 
=  l  |  n; r0  =  r  |  n;
    
for  (k = 0 ; (i = l0 >> k) ^ (j = r0 >> k); k ++ ) {
        
if  ((i  <<  k)  ^  l0) {T[i].sum  =  T[i  <<   1 ].sum  +  T[(i  <<   1 ^   1 ].sum;  if  (T[i].sum  >=  MOD) T[i].sum  -=  MOD;}
        
if  (((j  <<  k)  |  (( 1   <<  k)  -   1 ))  ^  r0) {T[j].sum  =  T[j  <<   1 ].sum  +  T[(j  <<   1 ^   1 ].sum;  if  (T[j].sum  >=  MOD) T[j].sum  -=  MOD;}
    }
    
for  (; k < s; k ++ ) {
        i 
=  l0  >>  k;
        
if  ((i  <<  k)  ^  l0  ||  ((j  <<  k)  |  (( 1   <<  k)  -   1 ))  ^  r0) {T[i].sum  =  T[i  <<   1 ].sum  +  T[(i  <<   1 ^   1 ].sum;  if  (T[i].sum  >=  MOD) T[i].sum  -=  MOD;}
    }
}
void  opr1( int  l,  int  r, ll c)
{
    
int  k, l0  =  l  |  n, r0  =  r  |  n, i, j, lch, rch; ll c0;
    
for  (k = s - 1 ; k && (i = l0 >> k) == r0 >> k; k -- ) dm(i);
    
for  ( int  k0 = k; ((i = l0 >> k0) << k0) ^ l0; k0 -- ) dm(i);
    
for  ( int  k0 = k; (((i = r0 >> k0) << k0) | (( 1 << k0) - 1 )) ^ r0; k0 -- ) dm(i);
    r0
++ int  l1;
    
for  (; l1 = l0, (l0 += l0 &- l0) <= r0; ) {
        i 
=  l1  /  (l1  &   - l1);
        T[i].mr1 
+=  c;  if  (T[i].mr1  >=  MOD) T[i].mr1  -=  MOD; T[i].sum  =  (T[i].sum  +  (c  <<  T[i].len))  %  MOD;
    }
    
int  _  =  (l1  &   - l1)  >>   1 , z  =  __builtin_ctz(_);
    
for  (; l1 < r0; _ >>= 1 , z -- if  (l1  +  _  <=  r0) {
        i 
=  l1  >>  z;
        T[i].mr1 
+=  c;  if  (T[i].mr1  >=  MOD) T[i].mr1  -=  MOD; T[i].sum  =  (T[i].sum  +  (c  <<  T[i].len))  %  MOD;
        l1 
+=  _;
    }
    l0 
=  l  |  n; r0  =  r  |  n;
    
for  (k = 0 ; (i = l0 >> k) ^ (j = r0 >> k); k ++ ) {
        
if  ((i  <<  k)  ^  l0) {T[i].sum  =  T[i  <<   1 ].sum  +  T[(i  <<   1 ^   1 ].sum;  if  (T[i].sum  >=  MOD) T[i].sum  -=  MOD;}
        
if  (((j  <<  k)  |  (( 1   <<  k)  -   1 ))  ^  r0) {T[j].sum  =  T[j  <<   1 ].sum  +  T[(j  <<   1 ^   1 ].sum;  if  (T[j].sum  >=  MOD) T[j].sum  -=  MOD;}
    }
    
for  (; k < s; k ++ ) {
        i 
=  l0  >>  k;
        
if  ((i  <<  k)  ^  l0  ||  ((j  <<  k)  |  (( 1   <<  k)  -   1 ))  ^  r0) {T[i].sum  =  T[i  <<   1 ].sum  +  T[(i  <<   1 ^   1 ].sum;  if  (T[i].sum  >=  MOD) T[i].sum  -=  MOD;}
    }
}
void  opr2( int  l,  int  r)
{
    res 
=   0 ;
    
int  k, l0  =  l  |  n, r0  =  r  |  n, i, j, lch, rch; ll c0;
    
for  (k = s - 1 ; k && (i = l0 >> k) == r0 >> k; k -- ) dm(i);
    
for  ( int  k0 = k; ((i = l0 >> k0) << k0) ^ l0; k0 -- ) dm(i);
    
for  ( int  k0 = k; (((i = r0 >> k0) << k0) | (( 1 << k0) - 1 )) ^ r0; k0 -- ) dm(i);
    r0
++ int  l1, r1;
    
for  (; l1 = l0, (l0 += l0 &- l0) <= r0; ) {
        i 
=  l1  /  (l1  &   - l1); res  +=  T[i].sum;
    }
    
int  _  =  (l1  &   - l1)  >>   1 , z  =  __builtin_ctz(_);
    
for  (; l1 < r0; _ >>= 1 , z -- if  (l1  +  _  <=  r0) {
        i 
=  l1  >>  z; res  +=  T[i].sum;
        l1 
+=  _;
    }
    res 
%=  MOD;
}
int  main()
{
    
int  n0  =  get_int(); MOD  =  get_int(); re(i, n0) A[i]  =  get_int();  for  (n = 1 , s = 0 ; n < n0; n <<= 1 , s ++ ) ; s ++ ;
    prepare(); 
int  M, _, l, r, c; M  =  get_int();
    re(i, M) {
        _ 
=  get_int(); l  =  get_int(); r  =  get_int(); l -- ; r -- ;
        
if  (_  ==   1 ) {
            c 
=  get_int()  %  MOD; opr0(l, r, c);
        } 
else   if  (_  ==   2 ) {
            c 
=  get_int()  %  MOD; opr1(l, r, c);
        } 
else  {
            opr2(l, r); printf(
" %d\n " , ( int ) res);
        }
    }
    
return   0 ;
}

你可能感兴趣的:(统计的力量)