双向树形DP(民科专用名词)
感觉这年SH的题很水啊
先用半个小时发呆(明明是想不出来怎么搞,还以为是解方程吧)
然后半小时想出来DP方程。
然后就码码码1A了,话说本地的Lemon有几个点炸了不知道怎么回事,linux下测没问题
首先肯定是一颗树了。
对于每个节点,考虑它只有子节点时的概率。
用f[u][0]和f[u][1]分别表示节点u充不充电的概率。
f[u][0]=(1-q[u])*(f[v][0]+(1-p(u->v))*f[v][1])*……
f[u][1]=1-f[u][0]
这个可以从下往上DP得出
然后从上往下DP,即用父节点的结果更新子节点
当更新某个子节点时,假设父节点的答案为最终答案(即已经计算完了)
我们不妨把树的形态改变一下,把v的兄弟都转上去
然后变成了一个沙漏形
观察可知v的答案由上三角,下三角和v自身构成。
v的下三角和自身的贡献已经求出,那么就差u所代表的上三角对v的贡献了
我们会发现f[u][0]中有一部分是来自v的,显然不可能有电从v传到u再传回来
于是我们令f[u][0]'=f[u][0]/(f[v][0]+(1-p(u->v))*f[v][1])
再把f[u][0]'的贡献乘到f[v][0]上,于是v也计算完了
递归往下走就好了。
好像很简单的样子啊
#include<iostream> #include<cstdio> #include<cstring> #include<vector> using namespace std; const int N=500000+5; struct Edge{int to,next;double p;}e[N<<1]; int head[N],cnt; void ins(int u,int v,double p){ e[++cnt]=(Edge){v,head[u],p};head[u]=cnt; } void insert(int u,int v,double p){ ins(u,v,p);ins(v,u,p); } double f[N][2],q[N]; bool vis[N]; void dpup(int u,int fa){ vis[u]=1; for(int i=head[u];i;i=e[i].next){ int v=e[i].to;if(v==fa)continue; dpup(v,u);f[u][0]*=(f[v][0]+(1.0-e[i].p)*f[v][1]); } f[u][1]=1.0-f[u][0]; } void dpdown(int u,int fa){ vis[u]=1; for(int i=head[u];i;i=e[i].next){ int v=e[i].to;if(v==fa)continue; double fu0=f[u][0]/(f[v][0]+(1.0-e[i].p)*f[v][1]); f[v][0]*=(fu0+(1.0-fu0)*(1.0-e[i].p));f[v][1]=1.0-f[v][0]; dpdown(v,u); } } int pa[N]; int find(int x){ return pa[x]==x?x:pa[x]=find(pa[x]); } void merge(int u,int v){ u=find(u);v=find(v); pa[u]=v; } struct edge{ int u,v,p; }; vector<edge>g; int main(){ //freopen("a.in","r",stdin); int n;scanf("%d",&n); for(int i=1;i<=n;i++)pa[i]=i; for(int i=1;i<n;i++){ int u,v,p;scanf("%d%d%d",&u,&v,&p); if(p==100)merge(u,v); else if(p)g.push_back((edge){u,v,p}); } for(int i=1;i<=n;i++)scanf("%lf",&q[i]),q[i]/=100.0; for(int i=1;i<=n;i++)f[i][0]=1.0; for(int i=1;i<=n;i++){ int t=find(i); f[t][0]*=(1.0-q[i]); } for(int i=1;i<=n;i++)if(pa[i]==i)f[i][1]=1.0-f[i][0]; for(int i=0;i<g.size();i++){ int u=g[i].u,v=g[i].v; u=find(u);v=find(v); if(u==v)continue; insert(u,v,g[i].p/100.0); } for(int i=1;i<=n;i++)if(pa[i]==i&&!vis[i])dpup(i,-1); memset(vis,0,sizeof(vis)); for(int i=1;i<=n;i++)if(pa[i]==i&&!vis[i])dpdown(i,-1); double ans=0; for(int i=1;i<=n;i++)ans+=f[find(i)][1]; printf("%.6lf\n",ans); return 0; }