HDU 5401(计数dp)

题意描述:

原先设定第0颗树只有一个节点0,现在要生成第i颗数,选  ai, bi, (ai < i, bi< i) 中两个节点(ci , di)相连接,构成一个新的树,且ai中节点的编号不变, bi中的所有节点编号都要在原来的基础上+ai树的大小,这样保证编号连续,对于每颗树T而言 ,, F(T)=n1i=0n1j=i+1d(vi,vj) ( d(vi,vj)
即任意两点之间距离总和。


这是多校题解:

考虑爆搜,树iii生成后,两两点对路径分成两部分,一部分不经过中间的边,那么就是aia_iaibib_ibi的答案,如果经过中间的边,首先计算中间这条边出现的次数,也就是ai,bia_i,b_iai,bi子树大小的乘积。对于aia_iai,对答案的贡献为所有点到cic_ici的距离和乘上bib_ibi的子树大小。bib_ibi同理。

那么转化为计算在树iii中,所有点到某个点jjj的距离和。假设jjjaia_iai内,那么就转化成了aia_iaijjj这个点的距离总和加上bib_ibi内所有点到did_idi的总和加上did_idijjj的距离乘上子树bib_ibi的大小,称作第一类询问。

这样就化成了在树iii中两个点jjjkkk的距离,如果在同一棵子树中,可以递归下去,否则假设jjjaia_iaikkkbib_ibi中,那么距离为jjjcic_ici的距离加上kkkdid_idi的距离加上lil_ili,称作第二类询问。

然后对两类询问全都记忆化搜索即可。

接着考虑计算一下复杂度。

对于第二类询问,可以考虑询问的过程类似于线段树,只会有两个分支,中间的部分已经记忆化下来,不用再搜,时间复杂度O(m)O(m)O(m)

我们分析一下复杂度,首先对于第一类询问,在bib_ibi中到did_idi的点距离和已经由前面的询问得到,那么就转化为一个第一类询问和一个第二类询问,最多会被转化成O(m)O(m)O(m)个第二类询问。

所以每个询问复杂度是O(m2)O(m^2)O(m2),总复杂度O(m3)O(m^3)O(m3)

复杂度计算思考:

对于第一类询问,只会例如sum(a[i], c[i])递归计算时,每个会分成两个第一类询问和一个第二类询问,而两个第一类询问必有一个已经被计算过(可以手动分解看看前后关系)

,所以每次分解成一个第一类和一个第二类,复杂度为m*m。

dis计算也同理。

被记忆的也不会很多,每次最多多记录m*m个。

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <map>
#include <set>
#include <vector>
#include <cctype>
#include <cmath>
#include <queue>
#define ls rt<<1
#define rs rt<<1|1
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mem(a,n) memset(a,n,sizeof(a))
#define rep(i,n) for(int i=0;i<(int)n;i++)
#define rep1(i,x,y) for(int i=x;i<=(int)y;i++)
using namespace std;
#pragma comment(linker, "/STACK:102400000,102400000")
typedef pair<int,int> pii;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const ll oo = 1e12;
typedef pair<ll,ll> pll;
const int N = 65;
const int mod = 1e9+7;

map<pll,ll> M[N];
map<ll,ll> M2[N];
int n;
ll a[N],b[N],c[N],d[N],siz[N],ms[N],l[N],ans[N];
void init(){
  for(int i = 0; i < N;i++)
    M[i].clear(),M2[i].clear();
  M[0][pll(0,0)]=0;
  M2[0][0] = 0;
  siz[0] = ms[0] = 1;
}
ll dis(int i,ll j,ll k){
   if(j > k) swap(j,k);
   if(M[i].count(pll(j,k))) return M[i][pll(j,k)];
   if(k < siz[a[i]]) return M[i][pll(j,k)] = dis(a[i],j,k);
   if(j >= siz[a[i]]) return M[i][pll(j,k)] = dis(b[i],j-siz[a[i]],k-siz[a[i]]);
   return  M[i][pll(j,k)] = (dis(a[i],j,c[i])+l[i]+dis(b[i],d[i],k-siz[a[i]]))%mod;
}
ll sum(int i,ll j){
   if(M2[i].count(j)) return M2[i][j];
   if(j<siz[a[i]]) return  M2[i][j]=(sum(a[i],j)+(l[i]+dis(a[i],j,c[i]))*ms[b[i]]+sum(b[i],d[i]))%mod;
   if(j>=siz[a[i]]) return M2[i][j]=(sum(a[i],c[i])+(l[i]+dis(b[i],j-siz[a[i]],d[i]))*ms[a[i]]+sum(b[i],j-siz[a[i]]))%mod;
}
ll cal(int i){
   siz[i] = siz[a[i]]+siz[b[i]];
   ms[i] = siz[i]%mod;
   ans[i] = ans[a[i]]+ans[b[i]]+ms[a[i]]*ms[b[i]]%mod*l[i]%mod+ms[b[i]]*sum(a[i],c[i])+ms[a[i]]*sum(b[i],d[i]);
   ans[i]=ans[i]%mod;
   return ans[i];
}
int main()
{
   while(scanf("%d",&n)==1){
      init();
      for(int i=1;i<=n;i++){
         scanf("%I64d %I64d %I64d %I64d %I64d",&a[i],&b[i],&c[i],&d[i],&l[i]);
         printf("%I64d\n",cal(i));
      }
   }
   return 0;
}





你可能感兴趣的:(HDU 5401(计数dp))