2020牛客寒假算法基础集训营2 J-求函数 (线段树维护矩阵乘法)

题目链接:https://ac.nowcoder.com/acm/contest/3003/J

思路:

方法①

f1(1)=k1+b1=(k1)+(b1)

f2(f1(1))=k2(f1(1))+b2=k2k1+k2b1+b2=(k2k1)+(k2b1+b2)

f3(f2(f1(1)))=(k3k2k1)+(k3k2b1+k3b2+b3)

通过上面的展开,我们可以发现一个式子可以分成两部分:∏Ki  与  ∑ri=l(bi*∏rj=i+1Kj)

分别用线段树维护这两部分即可,现在考虑如果合并区[l, r] 与 [r1+1 ,r]

假设左区间的第一部分为 N第二部分为 M1

  右区间的第一部分为 N2 第二部分为 M2

合并后区间的第一部分为N1*N2,第二部分为N2 * M1 + M2

#include
#include
#include
 using namespace std;
 typedef long long ll;
 const int mod=1e9+7;
 const int maxn=2e5+10;
 struct node{
     ll l,r,k,b;
 }tree[maxn<<2];
 ll k[maxn],b[maxn],n,m,op,l1,r1,po,k1,b1;
 void pushup(int rt)
 {
     tree[rt].k=(tree[rt<<1].k*tree[rt<<1|1].k)%mod;
     tree[rt].b=((tree[rt<<1|1].k*tree[rt<<1].b)%mod+tree[rt<<1|1].b)%mod;
 }
 void build(ll rt,ll l,ll r)
 {
     tree[rt].l=l;
     tree[rt].r=r;
     if(l==r){
         tree[rt].k=k[l],tree[rt].b=b[l];
         return;
     }
    ll mid=(l+r)>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
    pushup(rt);
 }
 void update(ll rt,ll pos)
 {
     if(tree[rt].l==pos&&tree[rt].r==pos){
         tree[rt].k=k[pos],tree[rt].b=b[pos];
         return;
     }
    ll mid=(tree[rt].l+tree[rt].r)>>1;
    if(pos<=mid)    update(rt<<1,pos);
    else update(rt<<1|1,pos);
    pushup(rt);
  }
  typedef pair p;
  p query(int rt,int l,int r,int ll,int rr)
{
    if(ll>r || rrreturn p(-1,-1);
    if(l>=ll && r<=rr) return p(tree[rt].k,tree[rt].b);
    int mid=(l+r)>>1;
    p p1=query(rt<<1,l,mid,ll,rr);
    p p2=query(rt<<1|1,mid+1,r,ll,rr);
    if(p1.first==-1) return p2;
    if(p2.first==-1) return p1;
    int k1=p1.first,b1=p1.second;
    int k2=p2.first,b2=p2.second;
    return p(1ll*k1*k2%mod,(1ll*b1*k2+b2)%mod);
}
 int main()
 {
     scanf("%lld%lld",&n,&m);
     for(int i=1;i<=n;i++) scanf("%d",&k[i]);
     for(int i=1;i<=n;i++) scanf("%d",&b[i]);
     build(1,1,n);
     for(int i=1;i<=m;i++){
         scanf("%lld",&op);
         if(op==1){
             scanf("%lld%lld%lld",&po,&k1,&b1);
            k[po]=k1;
            b[po]=b1;
            update(1,po);
         } 
         else{
             scanf("%lld%lld",&l1,&r1);
             p p1=query(1,1,n,l1,r1);
             int k=p1.first,b=p1.second;
            printf("%d\n",((k+b)%mod+mod)%mod);
         } 
     }
    return 0;
 }

 方法②矩阵乘法

这篇博客讲的不错:https://www.cnblogs.com/BakaCirno/p/12270838.html

 

#include
#include
#include
#define mid ((l+r)>>1)
 typedef long long ll;
 using namespace std;
 const int maxn=2e5+10;
 const int mod=1e9+7;
 int n,m;
 ll k[maxn],b[maxn];
 struct MX{
     ll m[2][2];
     MX(){memset(m,0,sizeof(m));}
    friend MX operator *(const MX&a,const MX&b){
        MX res;
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++){
                for(int k=0;k<2;k++)
                    res.m[i][j]+=a.m[i][k]*b.m[k][j];
                res.m[i][j]%=mod;
            }
        return res;
    }
 }mx[maxn<<2];
 void update(int rt,int l,int r,int pos)
 {
     if(l==r){mx[rt].m[0][0]=k[l],mx[rt].m[1][0]=b[l],mx[rt].m[1][1]=1;return;}
     if(pos<=mid)    update(rt<<1,l,mid,pos);
     else update(rt<<1|1,mid+1,r,pos);
     mx[rt]=mx[rt<<1]*mx[rt<<1|1];
 }
 MX query(int rt,int l,int r,int L,int R)
 {
     if(L<=l&&r<=R)    return mx[rt];
     MX res;
     res.m[0][0]=res.m[1][1]=1;
     if(L<=mid)    res=res*query(rt<<1,l,mid,L,R);
     if(R>mid)    res=res*query(rt<<1|1,mid+1,r,L,R);
     return res;
 }
 int main()
 {
     cin>>n>>m;
     for(int i=1;i<=n;i++)    scanf("%lld",&k[i]);
     for(int i=1;i<=n;i++)    scanf("%lld",&b[i]);
     for(int i=1;i<=n;i++)    update(1,1,n,i);
     for(int i=1,opt,l,r;i<=m;i++){
         cin>>opt;
         if(opt==1){
             scanf("%d",&l);
             scanf("%lld%lld",&k[l],&b[l]);
             update(1,1,n,l);
         }
         else{
             scanf("%d%d",&l,&r);
             MX res;
             res.m[0][0]=res.m[0][1]=1;
             res=res*query(1,1,n,l,r);
             printf("%lld\n",res.m[0][0]%mod);
         }
     }
  } 

 

你可能感兴趣的:(2020牛客寒假算法基础集训营2 J-求函数 (线段树维护矩阵乘法))