题面简洁明了,一看就懂
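做了这道题之后,才知道怎么用线段树维护递推式。递推式的每一步递推都可以看作一次矩阵乘法:设矩阵 $A$ 为初始值矩阵,矩阵 $B$ 为转移矩阵,则求第 $n$ 项相当于在 $A$ 上右乘 $n-1$ 次 $B$,即计算 $A B^{n-1}$。代码中取 $A = \begin{pmatrix}0 & 1\end{pmatrix}$,$B = \begin{pmatrix}0 & 1 \\ 1 & 1\end{pmatrix}$,于是 $A B^{n-1} = \begin{pmatrix}F_{n-1} & F_n\end{pmatrix}$。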
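那么线段树的每个节点维护其区间内各位置对应矩阵之和 sum。由于矩阵乘法对加法满足分配律,把区间内每个数都加上 $x$,等价于把该区间的 sum 右乘 $B^x$。懒标记 flag 保存尚未下放的转移矩阵(初始为单位矩阵):修改时先用快速幂求出 $B^x$,再乘到被完全覆盖节点的 sum 和 flag 上;下放时把 flag 乘给两个子节点,然后把它重置为单位矩阵。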
#include <cstdio>
#include <cstring>
typedef long long LL;
// Left / right child of segment-tree node x in the 1-based heap layout.
inline int ls(int x) { return x << 1; }
inline int rs(int x) { return (x << 1) | 1; }
using namespace std;
// Modulus applied to every matrix entry and every query answer.
const LL mod = 1000000007;
// Capacity bound: a[] is indexed 1..n, so n must stay below maxn.
const int maxn = 100010;
struct Matrix {
static const int len = 2;
LL x[len][len];
void init() {
memset(x, 0, sizeof(x));
for (int i = 0; i < len; i++)
x[i][i] = 1;
}
void zero() {
memset(x, 0, sizeof(x));
}
Matrix operator * (const Matrix& m) const {
Matrix ans;
ans.zero();
for (int i = 0; i < len; i++)
for (int j = 0; j < len; j++)
for (int k = 0; k < len; k++)
ans.x[i][j] = (ans.x[i][j] + x[i][k] * m.x[k][j]) % mod;
return ans;
}
Matrix operator + (const Matrix& m) const {
Matrix ans;
ans.zero();
for (int i = 0; i < len; i++)
for (int j = 0; j < len; j++)
ans.x[i][j] = (x[i][j] + m.x[i][j]) % mod;
return ans;
}
Matrix operator ^ (int b) const {
Matrix ans, a;
ans.init();
memcpy(a.x, x, sizeof(x));
for (; b; b >>= 1) {
if(b & 1) ans = ans * a;
a = a * a;
}
return ans;
}
};
// Transition matrix; main() fills it in before use.
Matrix mul;
// Scratch matrix holding mul^x for the range update currently being applied.
Matrix tmp;
// Start-vector matrix; main() stores the initial row in its row 0.
Matrix trans;
// Input values, indexed 1..n.
int a[maxn];

// One segment-tree node.
//   lz   - nonzero when `flag` holds a multiplier not yet pushed to the children
//   sum  - sum of the per-position matrices over this node's range
//   flag - pending right-multiplier (identity when nothing is pending)
struct SegementTree {
    int lz;
    Matrix sum;
    Matrix flag;
};

// Node storage in 1-based heap layout (children of o are 2o and 2o+1).
SegementTree tr[maxn * 4];
void maintain(int o) {
tr[o].sum = tr[ls(o)].sum + tr[rs(o)].sum;
}
void pushdown(int o) {
if(tr[o].lz) {
tr[ls(o)].sum = tr[ls(o)].sum * tr[o].flag;
tr[rs(o)].sum = tr[rs(o)].sum * tr[o].flag;
tr[ls(o)].flag = tr[ls(o)].flag * tr[o].flag;
tr[rs(o)].flag = tr[rs(o)].flag * tr[o].flag;
tr[o].lz = 0;
tr[ls(o)].lz = 1;
tr[rs(o)].lz = 1;
tr[o].flag.init();
}
}
void build(int o, int l, int r) {
tr[o].sum.zero();
tr[o].lz = 0;
tr[o].flag.init();
if(l == r) {
tr[o].sum = trans * ( mul ^ (a[l] - 1));
return;
}
int mid = (l + r) >> 1;
build(ls(o), l, mid);
build(rs(o), mid + 1, r);
maintain(o);
}
void update(int o, int l, int r, int ql, int qr, Matrix now) {
if(l >= ql && r <= qr) {
tr[o].sum = tr[o].sum * now;
tr[o].flag = tr[o].flag * now;
tr[o].lz = 1;
return;
}
pushdown(o);
int mid = (l + r) >> 1;
if(ql <= mid) update(ls(o), l, mid, ql, qr, now);
if(qr > mid) update(rs(o), mid + 1, r, ql, qr, now);
maintain(o);
}
// Return, modulo mod, the sum of entry [0][1] of the per-position matrices
// over [ql, qr] intersected with node o's range [l, r]. Pending tags are pushed
// down on the way so the children's sums are current.
LL query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr)
        return tr[o].sum.x[0][1];
    pushdown(o);
    const int mid = (l + r) >> 1;
    const LL fromLeft = (ql <= mid) ? query(ls(o), l, mid, ql, qr) : 0;
    const LL fromRight = (qr > mid) ? query(rs(o), mid + 1, r, ql, qr) : 0;
    return (fromLeft % mod + fromRight) % mod;
}
int main() {
int n, m, op, l, r;
LL x;
trans.zero();
trans.x[0][1] = 1;
mul.x[0][1] = mul.x[1][0] = mul.x[1][1] = 1;
mul.x[0][0] = 0;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
build(1, 1, n);
for (int i = 1; i <= m; i++) {
scanf("%d%d%d", &op, &l, &r);
if(op == 1) {
scanf("%lld", &x);
tmp = (mul ^ x);
update(1, 1, n, l, r, tmp);
} else {
printf("%lld\n", query(1, 1, n, l, r));
}
}
}