X - Calculate the Function(线段树+矩阵)

http://acm.sdut.edu.cn:8080/vjudge/contest/view.action?cid=216#problem/X


给出一个序列 a,对于每个区间 [l,r] 定义函数 f(x):f(l) = a[l],f(l+1) = a[l+1],当 x >= l+2 时 f(x) = f(x-1) + f(x-2)*a[x]。

对于每一个询问区间[l,r],输出f(r)。


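刚看到这个题时没有什么思路,看到别人的博客标题写着“线段树+矩阵”,才想到 f(x) 的递推式与斐波那契数列类似,可以构造矩阵:令 M(x) = [[1, a[x]], [1, 0]],则 [f(x), f(x-1)]^T = M(x) · [f(x-1), f(x-2)]^T。因此对于区间 [l,r],有 [f(r), f(r-1)]^T = M(r) · M(r-1) · … · M(l+2) · [f(l+1), f(l)]^T,即 f(r) 可由矩阵乘法求出,所以可以用线段树维护区间 [l+2,r] 上矩阵的乘积(矩阵乘法不满足交换律,合并时右子区间的矩阵要乘在左边)。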


#include <stdio.h>
#include <iostream>
#include <map>
#include <set>
#include <list>
#include <stack>
#include <vector>
#include <math.h>
#include <string.h>
#include <queue>
#include <string>
#include <stdlib.h>
#include <algorithm>
// Shorthand types and global constants used throughout the solution.
typedef long long LL;            // 64-bit integer type for all modular arithmetic
using namespace std;
typedef pair<LL, LL> PP;         // pair shorthand (not referenced in the code shown)

const double eps = 1e-12;        // floating-point tolerance (not referenced in the code shown)
const double PI = acos(-1.0);    // pi (not referenced in the code shown)

const int INF  = 0x3f3f3f3f;
const int maxn = 100010;         // capacity of the 1-indexed input array a[]
const int mod  = 1000000007;     // every answer is reported modulo this prime

// 2x2 integer matrix. A leaf holds [[1, a[x]], [1, 0]], which advances the
// recurrence by one step: (f(x-1), f(x-2)) -> (f(x), f(x-1)).
// Both the tag `matrix` and the alias `Matrix` remain usable.
struct matrix
{
	LL mat[2][2];
};
typedef matrix Matrix;

// Multiply two 2x2 matrices modulo `mod` and return a * b.
// Operands are taken by const reference so the two matrices are not copied on
// every call. Multiplication runs once per internal node in build and
// O(log n) times per query.
// Precondition: every entry satisfies |entry| < mod, so a single product
// a.mat[i][k] * b.mat[k][j] fits in a long long without overflow.
Matrix operator * (const Matrix &a, const Matrix &b)
{
	Matrix res;
	for(int i = 0; i < 2; i++)
	{
		for(int j = 0; j < 2; j++)
		{
			LL sum = 0;
			for(int k = 0; k < 2; k++)
				sum = (sum + a.mat[i][k] * b.mat[k][j] % mod) % mod;
			res.mat[i][j] = sum;
		}
	}
	return res;
}

// Segment-tree node covering positions [l, r]; `matrix` holds the ordered
// product M_r * M_{r-1} * ... * M_l of the per-position matrices.
struct node
{
	int l, r;        // inclusive range of positions covered by this node
	Matrix matrix;   // product of leaf matrices, later positions on the left
};

// Heap-ordered storage: the children of node v are 2v and 2v+1.
node tree[maxn * 4];

// Input sequence, stored 1-indexed as a[1..n].
LL a[maxn];

// Recompute node v's product from its two children. Matrix multiplication is
// not commutative: the right child (later positions) is the left factor.
void push_up(int v)
{
	const int left = v << 1;
	const int right = left | 1;
	tree[v].matrix = tree[right].matrix * tree[left].matrix;
}

void build(int v, int l, int r)
{
	tree[v].l = l;
	tree[v].r = r;
	if(l == r)
	{
		Matrix tmp;
		tmp.mat[0][0] = (LL)1;
		tmp.mat[0][1] = a[l];
		tmp.mat[1][0] = (LL)1;
		tmp.mat[1][1] = 0;
		tree[v].matrix = tmp;
		return;
	}
	int mid = (tree[v].l + tree[v].r)>>1;
	build(v*2,l,mid);
	build(v*2+1,mid+1,r);
	push_up(v);
}

// Return the ordered product M_r * M_{r-1} * ... * M_l over positions [l, r].
// The range [l, r] must lie within the range covered by node v.
Matrix query(int v, int l, int r)
{
	const node &cur = tree[v];
	if(cur.l == l && cur.r == r)
		return cur.matrix;

	const int mid = (cur.l + cur.r) >> 1;
	if(r <= mid)
		return query(v*2, l, r);
	if(l > mid)
		return query(v*2+1, l, r);

	// The range straddles mid: the later half must be the left factor.
	Matrix later = query(v*2+1, mid+1, r);
	Matrix earlier = query(v*2, l, mid);
	return later * earlier;
}

int main()
{
	int test;
	int n,m,l,r;
	LL ans;
	scanf("%d",&test);
	while(test--)
	{
		scanf("%d %d",&n,&m);
		for(int i = 1; i <= n; i++)
			scanf("%lld",&a[i]);
		build(1,1,n);
		while(m--)
		{
			scanf("%d %d",&l,&r);
			if(l == r)
				printf("%lld\n",a[l]%mod);
			else if(r == l+1)
				printf("%lld\n",a[r]%mod);
			else
			{
				Matrix tmp = query(1,l+2,r);
				ans = (tmp.mat[0][0]*a[l+1]%mod + tmp.mat[0][1]*a[l]%mod)%mod;
				printf("%lld\n",ans);
			}
		}
	}
	return 0;
}



你可能感兴趣的:(线段树)