一个数列,有两种操作。一是修改数列中某个数,二是求数列中连续一段的积。
很明显的线段树单点更新模版题。
#include<stdio.h> #include<iostream> using namespace std; #define lch(x) ((x)<<1) #define rch(x) (((x)<<1)|1) #define ll long long const int mod = 1000000007; int a[50010]; struct node{ int l,r; ll val; }tree[200010]; void push_up(int nd){ if(tree[nd].l==tree[nd].r){ return ; } tree[nd].val=tree[lch(nd)].val*tree[rch(nd)].val; tree[nd].val%=mod; } void build_tree(int nd,int l,int r){ tree[nd].l=l; tree[nd].r=r; if(l==r){ tree[nd].val=a[l]; return; } int mid=(l+r)>>1; build_tree(lch(nd),l,mid); build_tree(rch(nd),mid+1,r); push_up(nd); } ll query(int nd,int l,int r){ if(tree[nd].l==l&&tree[nd].r==r){ return tree[nd].val; } int mid= (tree[nd].l+tree[nd].r)>>1; if(r<=mid){ return query(lch(nd),l,r); }else{ if(l>mid){ return query(rch(nd),l,r); }else{ ll a=query(lch(nd),l,mid); ll b=query(rch(nd),mid+1,r); return (a*b)%mod; } } } void update(int nd,int pos,int val){ if(tree[nd].l==tree[nd].r){ tree[nd].val=val; return; } int mid=(tree[nd].l+tree[nd].r)>>1; if(pos<=mid){ update(lch(nd),pos,val); }else{ update(rch(nd),pos,val); } push_up(nd); } int main(){ int t; cin>>t; while(t--){ int n; cin>>n; for(int i=1;i<=n;i++){ scanf("%d",&a[i]); } build_tree(1,1,n); int q; cin>>q; for(int i=1;i<=q;i++){ int op,a,b; scanf("%d%d%d",&op,&a,&b); if(op){ update(1,a,b); }else{ //query ll ans=query(1,a,b); printf("%I64d\n",ans); } } } return 0; }
天真的我曾经以为BIT只能维护和,不能维护积。但是比赛时一个学弟用BIT+逆元很好地教育了我。。膜拜一下Orz。。自己实现调了很久才调出来。
题目求的是数列一个区间的积,输出取模后的结果。当某个数改变后,我们要在BIT里先除以旧的数再乘上新的数。但是由于取模的原因,不能用除法,所以必须用乘上旧的数的逆元来代替。
#include <stdio.h> #include <iostream> using namespace std; #define ll long long const int maxn=50010; const int mod=1e9+7; int T,n; int a[maxn]; ll c[maxn]; //扩展欧几里德 void ExEuclid(int a,int b,ll &x,ll &y,ll &q){ if(b==0){ x=1;y=0;q=a; return; } ExEuclid(b,a%b,y,x,q); y-=x*(a/b); } //逆元 int inv(ll num){ ll x,y,q; ExEuclid(num,mod,x,y,q); if(q==1)return (x+mod)%mod; } int lowbit(int x){ return x&(-x); } void update(int pos,ll val){ int tmp=val; val*=inv(a[pos]); val%=mod; a[pos]=tmp; while(pos<=n){ c[pos]*=val; c[pos]%=mod; pos+=lowbit(pos); } } int product(int pos){ ll re=1; while(pos){ re*=c[pos]; re%=mod; pos-=lowbit(pos); } return re; } int main(){ cin>>T; while(T--){ cin>>n; for(int i=1;i<=n;i++)a[i]=c[i]=1; for(int i=1;i<=n;i++){ int num; scanf("%d",&num); update(i,num); } int q; cin>>q; for(int i=1;i<=q;i++){ int op; scanf("%d",&op); if(op){ int k,p; scanf("%d%d",&k,&p); update(k,p); }else{ int k1,k2; scanf("%d%d",&k1,&k2); ll part1=product(k1-1); ll part2=product(k2); ll ans=inv(part1)*part2; ans%=mod; printf("%I64d\n",ans); } } } return 0; }