老年选手搞了半个月的文化课开始康复训练。
很显然要求的就是个独立集形式的树上背包。
写成卷积的形式链分治+带权二分即可。
很好写,拿下LOJ rk1。
复杂度 O ( n log 2 n ) O(n\log^2n) O(nlog2n),分析方式类似全局平衡二叉树。
看了下AC代码,除了我和rk2,剩下的分治部分似乎都是普通二分而不是带权二分,可以卡到 O ( n log 3 n ) O(n\log^3 n) O(nlog3n)
rk2可能是慢在取模优化得不够极限,其实都差不多。
代码:
#include
#define ll long long
#define re register
#define cs const
namespace IO{
inline char gc(){
static cs int Rlen=1<<22|1;static char buf[Rlen],*p1,*p2;
return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
}template<typename T>T get_integer(){
char c;bool f=false;while(!isdigit(c=gc()))f=c=='-';T x=c^48;
while(isdigit(c=gc()))x=((x+(x<<2))<<1)+(c^48);return f?-x:x;
}inline int gi(){return get_integer<int>();}
}using namespace IO;
using std::cerr;
using std::cout;
cs int mod=998244353;
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline int dec(int a,int b){return a-b<0?a-b+mod:a-b;}
inline int mul(int a,int b){ll r=(ll)a*b;return r>=mod?r%mod:r;}
inline void Inc(int &a,int b){a+=b-mod;a+=a>>31&mod;}
inline void Dec(int &a,int b){a-=b;a+=a>>31&mod;}
inline void Mul(int &a,int b){a=mul(a,b);}
inline int po(int a,int b){int r=1;for(;b;b>>=1,Mul(a,a))if(b&1)Mul(r,a);return r;}
inline void ex_gcd(int a,int b,int &x,int &y){
if(!b){x=1,y=0;return;}ex_gcd(b,a%b,y,x);y-=a/b*x;
}inline int inv(int a){int x,y;ex_gcd(mod,a,y,x);return x+(x>>31&mod);}
template<class ...Args>
inline int add(int a,cs Args& ...args){return add(a,add(args...));}
template<class ...Args>
inline int mul(int a,cs Args& ...args){return mul(a,mul(args...));}
cs int bit=18,SIZE=1<<bit|7;
int r[SIZE],*w[bit+1],Log[SIZE],len,inv_len;
inline void init_omega(){
for(int re i=1;i<=bit;++i)
w[i]=new int[1<<(i-1)];
int wn=po(3,(mod-1)>>bit);
for(int re i=w[bit][0]=1;i<(1<<(bit-1));++i)
w[bit][i]=mul(w[bit][i-1],wn);
for(int re j=bit-1;j;--j)
for(int re i=0;i<(1<<(j-1));++i)
w[j][i]=w[j+1][i<<1];
for(int re i=2;i<SIZE;++i)
Log[i]=Log[i-1]+((1<<Log[i-1])<i);
}void DFT(int *A){
for(int re i=1;i<len;++i)
if(i<r[i])std::swap(A[i],A[r[i]]);
for(int re i=1,d=1;i<len;i<<=1,++d)
for(int re j=0;j<len;j+=i<<1)
if(i<8){
for(int re k=0;k<i;++k){
int &t1=A[j+k],&t2=A[j+k+i];
int t=mul(t2,w[d][k]);
t2=dec(t1,t);Inc(t1,t);
}
}else {
for(int re k=0;k<i;k+=8){
#define work(p) \
{ \
int &t1=A[j+k+p],&t2=A[j+k+i+p]; \
int t=mul(t2,w[d][k+p]); \
t2=dec(t1,t);Inc(t1,t); \
}
work(0);work(1);work(2);work(3);
work(4);work(5);work(6);work(7);
}
}
}void IDFT(int *A){
DFT(A);std::reverse(A+1,A+len);
for(int re i=0;i<len;++i)Mul(A[i],inv_len);
}void init_len(int l){
len=l,inv_len=inv(l);
for(int re i=1;i<l;++i)r[i]=r[i>>1]>>1|((i&1)?l>>1:0);
}
using Poly=std::vector<int>;
void DFT(Poly &A){DFT(&A[0]);}
void IDFT(Poly &A){IDFT(&A[0]);}
cs int N=8e4+7;
int n,m;
struct matrix{
Poly a[2][2];matrix(){}
inline int deg()cs{
return std::max(
std::max(a[0][0].size(),a[0][1].size()),
std::max(a[1][0].size(),a[1][1].size())
);
}
inline void clear(){
a[0][0].clear(),a[0][1].clear();
a[1][0].clear(),a[1][1].clear();
}
inline cs Poly* operator[](int o)cs{return a[o];}
inline Poly* operator[](int o){return a[o];}
};
matrix& operator*=(matrix &A,cs matrix &B){
static int a[2][2][SIZE],b[2][2][SIZE],c[2][2][SIZE];
init_len(1<<Log[A.deg()+B.deg()-1]);
for(int re i=0;i<2;++i)for(int re j=0;j<2;++j){
cs Poly &va=A[i][j],&vb=B[i][j];
memcpy(a[i][j],&va[0],va.size()<<2);
memset(a[i][j]+va.size(),0,(len-va.size())<<2);
memcpy(b[i][j],&vb[0],vb.size()<<2);
memset(b[i][j]+vb.size(),0,(len-vb.size())<<2);
DFT(a[i][j]);DFT(b[i][j]);
}
for(int re i=0;i<2;++i)for(int re j=0;j<2;++j){
for(int re p=0;p<len;++p)
c[i][j][p]=add(mul(a[i][1][p],b[0][j][p]),
mul(a[i][0][p],add(b[0][j][p],b[1][j][p])));
}
for(int re i=0;i<2;++i)for(int re j=0;j<2;++j){
IDFT(c[i][j]);int d=std::min(len-1,m);
while(~d&&!c[i][j][d])--d;
if(~d)A[i][j]=Poly(c[i][j],c[i][j]+d+1);
else A[i][j].clear();
}return A;
}
Poly operator*(cs Poly &a,cs Poly &b){
static int A[SIZE],B[SIZE];
int deg=a.size()+b.size()-1;init_len(1<<Log[deg]);
memcpy(A,&a[0],a.size()<<2);
memset(A+a.size(),0,(len-a.size())<<2);
memcpy(B,&b[0],b.size()<<2);
memset(B+b.size(),0,(len-b.size())<<2);
DFT(A);DFT(B);for(int re i=0;i<len;++i)Mul(A[i],B[i]);
IDFT(A);deg=std::min(deg-1,m)+1;
while(deg>0&&!A[deg-1])--deg;return Poly(A,A+deg);
}
Poly& operator+=(Poly &A,cs Poly &B){
if(A.size()<B.size())A.resize(B.size());
for(int re i=0;i<(int)B.size();++i)Inc(A[i],B[i]);
return A;
}
std::vector<int> G[N];
void adde(int u,int v){
G[u].push_back(v);
G[v].push_back(u);
}
matrix dp[N];
int vl[N];
int fa[N],sz[N],son[N];
void pre_dfs(int u,int p){
fa[u]=p;sz[u]=1;
for(int re v:G[u])
if(v!=p){
pre_dfs(v,u);sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
}
int st[N],tp;
Poly p[N];
int pr[N];
Poly merge_son(int l,int r){
if(l==r)return p[l];
int m=std::lower_bound(pr+l,pr+r+1,(pr[l]+pr[r])>>1)-pr;
if(m==r)--m;return merge_son(l,m)*merge_son(m+1,r);
}
void merge_chain(int l,int r){
if(l==r)return ;
int m=std::lower_bound(pr+l,pr+r+1,(pr[l]+pr[r])>>1)-pr;
if(m==r)--m;merge_chain(l,m),merge_chain(m+1,r);
dp[st[l]]*=dp[st[m+1]];
}
void dfs_solve(int u){
for(int re v:G[u])
if(v!=fa[u])dfs_solve(v);
tp=0;
for(int re v:G[u])
if(v!=fa[u]&&v!=son[u]){
p[++tp].clear();
p[tp].resize(dp[v].deg());
p[tp]+=dp[v][0][0];
p[tp]+=dp[v][0][1];
p[tp]+=dp[v][1][0];
p[tp]+=dp[v][1][1];
pr[tp]=pr[tp-1]+p[tp].size();
}
if(tp)dp[u][0][0]=merge_son(1,tp);else dp[u][0][0].push_back(1);
tp=0;
for(int re v:G[u])
if(v!=fa[u]&&v!=son[u]){
p[++tp].clear();
p[tp]+=dp[v][0][0];
p[tp]+=dp[v][0][1];
pr[tp]=pr[tp-1]+p[tp].size();
}
if(tp)dp[u][1][1]=merge_son(1,tp);else dp[u][1][1].push_back(1);
for(int &v:dp[u][1][1])Mul(v,vl[u]);dp[u][1][1].insert(dp[u][1][1].begin(),0);
if(u!=son[fa[u]]){tp=0;
for(int re p=u;p;p=son[p])
st[++tp]=p,pr[tp]=pr[tp-1]+dp[p].deg();
merge_chain(1,tp);
}
}
void Main(){
n=gi(),m=gi();init_omega();
for(int re i=1;i<=n;++i)vl[i]=gi();
for(int re i=1;i<n;++i)adde(gi(),gi());
pre_dfs(1,0);dfs_solve(1);int ans=0;
for(int re i=0;i<2;++i)for(int re j=0;j<2;++j)
if((int)dp[1][i][j].size()>m)Inc(ans,dp[1][i][j][m]);
cout<<ans<<"\n";
}
inline void file(){
#ifdef zxyoi
freopen("flower.in","r",stdin);
#endif
}
signed main(){file();Main();return 0;}