For the given sequence with n different elements find the number of increasing subsequences with k + 1 elements. It is guaranteed that the answer is not greater than 8·1018.
Input
First line contain two integer values n and k (1 ≤ n ≤ 10^5, 0 ≤ k ≤ 10) — the length of sequence and the number of elements in increasing subsequences.
Next n lines contains one integer ai (1 ≤ ai ≤ n) each — elements of sequence. All values ai are different.
Output
Print one integer — the answer to the problem.
Examples
input
5 2
1
2
3
5
4
output
7
大致题意:告诉你n个不同的数 范围1到n,让你统计有多少个长度为k+1的递增数列.
思路:很容易可以想到dp方程:dp[ num[i] ][ len ]=sum(dp[ num[j] ][ len-1 ])(j < i&&num[i] >num[j],1<=len<=k+1),但是这样的转移时间是O(n),总的时间复杂度为O(k*n^2),肯定gg。所以我们考虑如何去优化转移。这里很巧妙的用到了线段树(当然用树状数组也可以)。我们以这n个数为下标建立k+1棵线段树,按照输入的先后顺序,当我们输入一个数num时,我们就可以在logn的时间内求出dp{num][len]的值,即在第len-1棵树中,查询区间1到num-1的和,然后用所求得的值去更新第len棵树上下标为num位置上的值。最后查询一下第k+1棵树中区间1到n的和即可。这样时间复杂度就降到了O(k*nlogn)
代码如下
代码1.
/*
线段树
*/
#include
#include
#define ll long long int
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
using namespace std;
const int M=1e5+5;
ll dp[M<<2][15];
int f;
inline void PushPlus(int rt)
{
dp[rt][f] = dp[rt<<1][f]+dp[rt<<1|1][f];
}
void Updata(int p,ll change, int l, int r, int rt)//单点更新,p位置上的数改变为change
{
if( l == r )
{
dp[rt][f]=change;
return ;
}
int m = ( l + r ) >> 1;
if(p <= m)
Updata(p, change, lson);
else
Updata(p, change, rson);
PushPlus(rt);
}
ll Query(int L,int R,int l,int r,int rt)
{
if(L>R)
return 0;
if( L <= l && r <= R )
{
return dp[rt][f];
}
int m = ( l + r ) >> 1;
ll ans=0;
if(L<=m )
ans+=Query(L,R,lson);
if(R>m)
ans+=Query(L,R,rson);
return ans;
}
int main()
{
int n,k;
int A;
scanf("%d%d",&n,&k);
k++;
for(int i=1;i<=n;i++)
{
scanf("%d",&A);
for(int j=1;j<=k;j++)
{
f=j;
f--;
ll ans;
if(f==0)
ans=1;
else
ans=Query(1,A-1,1,n,1);
f++;
Updata(A,ans,1,n,1);//单点更新,A位置上值变为ans
}
}
f=k;
printf("%I64d\n",Query(1,n,1,n,1));
return 0;
}
代码2.
/*
树状数组
emmm速度快了很多,写起来也方便不少,以后能用树状数组就尽量用吧
*/
#include
#include
#include
#include
#define ll long long int
using namespace std;
const int N=100005;
ll dp[N][15];
int f;
int n,k;
int lowbit(int x)
{
return x&-x;
}
ll sum(int x)
{
ll s=0;
while(x>0)
{
s+=dp[x][f];
x=x-lowbit(x);
}
return s;
}
void add(int x,ll date)
{
while(x<=n)
{
dp[x][f]+=date;
x=x+lowbit(x);
}
}
int main()
{
int A;
scanf("%d%d",&n,&k);
k++;
for(int i=1;i<=n;i++)
{
scanf("%d",&A);
for(int j=1;j<=k;j++)
{
ll ans;
f=j;
f--;
if(f==0)
ans=1;
else
ans=sum(A-1);
f++;
add(A,ans);
}
}
f=k;
printf("%I64d\n",sum(n));
return 0;
}