CF1731F Function Sum
对于一个长度为 n n n的序列 a a a。
定义 l i l_i li表示 i i i左边比 a i a_i ai小的数的个数,即 l i = ∑ j = 1 i − 1 [ a j < a i ] l_i=\sum\limits_{j=1}^{i-1}[a_j
定义 r i r_i ri表示 i i i右边比 a i a_i ai大的数的个数,即 r i = ∑ j = i + 1 n [ a j > a i ] r_i=\sum\limits_{j=i+1}^{n}[a_j>a_i] ri=j=i+1∑n[aj>ai]
我们称一个位置是好的,当且仅当 l i < r i l_i
对于序列 a a a,定义函数 f ( a ) = ∑ i = 1 n a i [ a i i s g o o d ] f(a)=\sum\limits_{i=1}^na_i[a_i \ is \ good] f(a)=i=1∑nai[ai is good]。
现在给定两个整数 n , k n,k n,k,请你求出对于所有长度为 n n n且 1 ≤ a i ≤ k 1\leq a_i\leq k 1≤ai≤k的数列 a a a的 f ( a ) f(a) f(a)的和是多少。
设 f i , j f_{i,j} fi,j表示第 i i i个数为 j j j时该位置对答案的贡献。枚举在 i i i之前小于 a i a_i ai的数的个数 x x x和在 i i i之后小于 a i a_i ai的数的个数 y y y,那么
f i , j = j × ∑ x = 1 i − 1 C i − 1 x × ( j − 1 ) x ( k − j + 1 ) i − 1 − x ∑ y = x + 1 n − i C n − i y ( k − j ) y j n − i − y f_{i,j}=j\times \sum\limits_{x=1}^{i-1}C_{i-1}^x\times (j-1)^x(k-j+1)^{i-1-x}\sum\limits_{y=x+1}^{n-i}C_{n-i}^y(k-j)^yj^{n-i-y} fi,j=j×x=1∑i−1Ci−1x×(j−1)x(k−j+1)i−1−xy=x+1∑n−iCn−iy(k−j)yjn−i−y
那么答案就是 ∑ i = 1 n ∑ j = 1 k f i , j \sum\limits_{i=1}^n\sum\limits_{j=1}^{k}f_{i,j} i=1∑nj=1∑kfi,j。
但这样的时间复杂度实在太大,所以我们考虑优化。
观察上面 f i , j f_{i,j} fi,j的表达式,将其视作关于 j j j的一个多项式,我们发现这个多项式的最高次数为 n n n。
设 F j = ∑ i = 1 n f i , j F_j=\sum\limits_{i=1}^nf_{i,j} Fj=i=1∑nfi,j, G i = ∑ j = 1 i F j G_i=\sum\limits_{j=1}^iF_j Gi=j=1∑iFj。那么答案为 G k G_{k} Gk。
因为 f i , j f_{i,j} fi,j中 j j j的最高次项的次数为 n n n,所以 G i G_i Gi中 i i i的最高次项为 n + 1 n+1 n+1。我们可以用拉格朗日插值法求出 n + 2 n+2 n+2个 G G G的值,就能求出 G ( k ) G(k) G(k)的值了。
时间复杂度为 O ( n 4 log n ) O(n^4\log n) O(n4logn)。
#include
using namespace std;
const int N=55;
long long jc[105],ny[105],f[105];
long long mod=998244353;
long long mi(long long t,long long v){
if(!v) return 1;
long long re=mi(t,v/2);
re=re*re%mod;
if(v&1) re=re*t%mod;
return re;
}
long long C(int x,int y){
return jc[x]*ny[y]%mod*ny[x-y]%mod;
}
void init(){
jc[0]=1;
for(int i=1;i<=N;i++) jc[i]=jc[i-1]*i%mod;
ny[N]=mi(jc[N],mod-2);
for(int i=N-1;i>=0;i--) ny[i]=ny[i+1]*(i+1)%mod;
}
long long dd(long long n,long long k){
long long re=0;
for(int i=1;i<=n;i++){
long long p=f[i],q=1;
for(int j=1;j<=n;j++){
if(i!=j){
p=p*(k-j)%mod;q=q*(i-j+mod)%mod;
}
}
re=(re+p*mi(q,mod-2)%mod)%mod;
}
return re;
}
int main()
{
long long n,k;
init();
scanf("%lld%lld",&n,&k);
for(int t=1;t<=min(n+2,k);t++){
for(int i=1;i<=n;i++){
for(int x=0;x<=i-1;x++){
for(int y=x+1;y<=n-i;y++){
f[t]=(f[t]+C(i-1,x)*mi(t-1,x)%mod*mi(k-t+1,i-1-x)%mod
*C(n-i,y)%mod*mi(k-t,y)%mod*mi(t,n-i-y)%mod)%mod;
}
}
}
f[t]=t*f[t]%mod;
}
for(int i=1;i<=n+2;i++) f[i]=(f[i]+f[i-1])%mod;
if(k<=n+2) printf("%lld",f[k]);
else printf("%lld",dd(n+2,k));
return 0;
}