原题意挺复杂的,我就尽我能力写简化一点吧……
给定一个有 m 个点的树形结构( 1 为根),其中保证 1 到 n 按照编号顺序形成一条链。
然后你要在这棵树上推Gal从 1 号点走到 n 号点,你走动的规则是从当前点等概率随机选择一个儿子走下去。如果你走进了错误的子树肯定走不到点 n 嘛,因此我们可以设置最多 p 个存档点,每当经过一个设置的存档点,你的当前存档点就更新为它。如果走到了一个不是 n 的叶子,你下一步就可以回到当前存档点。存档点必须设置在 1 到 n 的链上(否则你就会无限循环)。点 1 和点 n 必须设置存档点。
其实推过Gal或者类似游戏或者大概了解这类游戏的很容易理解啦~
现在由你来选择设置存档点的位置,最小化 1 走到 n 的期望步数。
本题 T 组数据。
50≤p≤n≤700,m≤1500,T≤5
保证每个编号属于 [1,n) 的节点至少有两个儿子,至多有三个儿子。
栋栋看完Re0之后出题好题。
按照惯例来一波部分分: 50≤p≤n≤500
我们将链 1...n 称为主链。
令 back(x) 表示非主链节点 x 向下走走回当前存档点的期望步数:
观察 goj,k ,如果我们将 j 左移一位,它的增量是多少呢?
出题人为什么要限制 p 的下界和儿子个数呢?
如果不做任何限制,稍有常识的人都可以看出来这里的 go 是一个指数级增长的函数,会导致我们的答案精度不够。
那么出题人在制造了这些限制之后,答案上界显然就会减少。
具体是多少呢?出题人通过构造一个可行解来估计:一种很平均的状况就是主链上每隔 ⌊np⌋ 个设置一个存档点,然后将主链上儿子个数都取上界,估计最坏情况。然后就是一个等比数列求和,因为 p 有下界,所以值不会很大,是个12位数,令其为 L ,而这还不是最优答案,只是一个可行解。因为我没有打这种算法,这里证明过程从略。
这意味着什么呢?显然 go 函数的增长是比 2 的幂要大的,因为儿子数下界就是 2 ,因此如果两个主链上的点距离超过 log2L ,那么它显然超过了最优解上界 L 。所以 f 枚举转移的时候只需要枚举 log2L 以内的点,大约 40 多个就好了。
时间复杂度 O(nplog2L) 。
更加神奇的算法3,由fanvree在考试时想出来。%%%
这个算法太神奇,我还不是很懂,就在这里口胡一下。听说思路来自2012中国国家集训队命题答辩能量棒。
首先显然我们应该尽量用完所有存档点。
我们二分一个神奇的东西 c 。
这个 c 有什么用呢?我们用它来限制设置存档点个数。
具体怎么限制呢?就是我们检测当前答案是否合法,依然是对主链做 dp ,但是不在状态上对存档点个数做出限制,即取消第一维。但是同时,我们每次选取一个存档点,就将 f 加上 c ,即希望通过选取存档点需要更多的代价来限制我们尽可能少地选择存档点(不然会超过 p )。
dp 完后,我们检查最优解使用了多少个存档点,设为 tot ,如果 tot>p 那么说明我们限制力度还不够,需要扩大 c ;如果 tot<p ,那么说明我们限制过紧,需要减小 c 。否则 tot=p ,说明这就是最优解的方案,将答案减去 tot×c 即可。
但是这样我们有可能找不到会使得 tot=p 的 c ,这时候我们就选取一个最小的使得 tot>p 的 c 然后将答案将去 tot×c 。
二分 c 的上界也是 L 左右即可。这个太神了,正确性我不会证明,有兴趣的自行查阅资料吧~
时间复杂度 O(n2log2L) 。
#include <iostream>
#include <cstring>
#include <cfloat>
#include <cstdio>
using namespace std;
typedef long double db;
const db INF=DBL_MAX/3;
const int N=700;
const int M=1500;
const int P=700;
db f[P+5][N+5],go[N+5][N+5],back[M+5];
int last[M+5],tov[M+5],next[M+5];
bool vis[M+5];
int n,m,tot,p,T;
inline void insert(int x,int y){tov[++tot]=y,next[tot]=last[x],last[x]=tot;}
db dfs(int x)
{
if (vis[x]) return back[x];
vis[x]=1,back[x]=0;
int cnt=0;
for (int i=last[x];i;i=next[i]) cnt++,back[x]+=dfs(tov[i]);
if (cnt) back[x]/=cnt;
return ++back[x];
}
void clearall()
{
memset(vis,0,sizeof vis);
for (;tot;tot--) tov[tot]=next[tot]=0;
for (int i=1;i<=m;i++) last[i]=0;
memset(f,0,sizeof f);
memset(go,0,sizeof go);
memset(back,0,sizeof back);
}
int main()
{
freopen("memory.in","r",stdin),freopen("memory_brute.out","w",stdout);
for (scanf("%d",&T);T--;clearall())
{
scanf("%d%d%d",&n,&m,&p);
for (int i=2;i<=n;i++) insert(i-1,i);
for (int i=1,x,y;i<=m-n;i++) scanf("%d%d",&x,&y),insert(x,y);
for (int i=n+1;i<=m;i++) if (!vis[i]) dfs(i);
for (int x=1;x<=n;x++)
for (int y=x;y<=n;y++)
if (x==y) go[x][y]=0;
else
{
int cnt=0;
go[x][y]=0;
for (int i=last[y-1],v;i;i=next[i])
if ((v=tov[i])>n) cnt++,go[x][y]+=back[v]+1;
go[x][y]+=go[x][y-1]*(cnt+1)+1;
}
for (int i=1;i<=p;i++)
for (int j=1;j<=n;j++)
f[i][j]=INF;
for (int i=2;i<=p;i++) f[i][n]=0;
for (int i=p-1;i>=1;i--)
for (int j=n-1;j>=i;j--)
for (int k=j+1;k<=n;k++) f[i][j]=min(f[i][j],f[i+1][k]+go[j][k]);
printf("%.4lf\n",(double)f[1][1]);
}
fclose(stdin),fclose(stdout);
return 0;
}
单调队列,最快的算法。
#include <iostream>
#include <cstring>
#include <cfloat>
#include <cstdio>
using namespace std;
typedef long double db;
const db INF=DBL_MAX/3;
const int N=700;
const int M=1500;
const int P=700;
db f[P+5][N+5],go[N+5][N+5],back[M+5];
int last[M+5],tov[M+5],next[M+5];
int que[N+5],pt[N+5],head,tail;
bool vis[M+5];
int n,m,tot,p,T;
inline void insert(int x,int y){tov[++tot]=y,next[tot]=last[x],last[x]=tot;}
db dfs(int x)
{
if (vis[x]) return back[x];
vis[x]=1,back[x]=0;
int cnt=0;
for (int i=last[x];i;i=next[i]) cnt++,back[x]+=dfs(tov[i]);
if (cnt) back[x]/=cnt;
return ++back[x];
}
void clearall()
{
memset(vis,0,sizeof vis);
for (;tot;tot--) tov[tot]=next[tot]=0;
for (int i=1;i<=m;i++) last[i]=0;
memset(f,0,sizeof f);
memset(go,0,sizeof go);
memset(back,0,sizeof back);
}
db calc(int x,int y,int z){return f[x][z]+go[y][z];}
int getp(int x,int p,int q)
{
int ret=0,l=1,r=q,mid;
while (l<=r)
{
mid=l+r>>1;
if (calc(x,mid,p)>=calc(x,mid,q)) l=(ret=mid)+1;
else r=mid-1;
}
return ret;
}
int main()
{
freopen("memory.in","r",stdin),freopen("memory.out","w",stdout);
for (scanf("%d",&T);T--;clearall())
{
scanf("%d%d%d",&n,&m,&p);
for (int i=2;i<=n;i++) insert(i-1,i);
for (int i=1,x,y;i<=m-n;i++) scanf("%d%d",&x,&y),insert(x,y);
for (int i=n+1;i<=m;i++) if (!vis[i]) dfs(i);
for (int x=1;x<=n;x++)
for (int y=x;y<=n;y++)
if (x==y) go[x][y]=0;
else
{
int cnt=0;
go[x][y]=0;
for (int i=last[y-1],v;i;i=next[i])
if ((v=tov[i])>n) cnt++,go[x][y]+=back[v]+1;
go[x][y]+=go[x][y-1]*(cnt+1)+1;
}
f[0][n]=f[1][n]=INF;
for (int i=2;i<=p;i++) f[i][n]=0;
for (int i=1;i<n;i++) f[p][i]=INF;
for (int i=p-1;i>=1;i--)
{
head=1,tail=0;
que[++tail]=n;
for (int j=n-1;j>=i;j--)
{
while (head!=tail&&pt[head+1]>=j) head++;
f[i][j]=calc(i+1,j,que[head]);
while (head!=tail&&pt[tail]<=getp(i+1,que[tail],j)) tail--;
que[++tail]=j,pt[tail]=getp(i+1,que[tail-1],que[tail]);
}
}
printf("%.4lf\n",(double)f[1][1]);
}
fclose(stdin),fclose(stdout);
return 0;
}
上界优化法。随便找个人(@jasonvictoryan)的贴上来。
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define maxn 705
#define maxm 1505
#define db double
#define mem(a,b) memset(a,b,sizeof(a))
#define min(a,b) (((a) < (b)) ? a : b)
#define max(a,b) (((a) > (b)) ? a : b)
using namespace std;
db f[maxn][maxn],g[maxm],s[maxm];
int n,m,p,T;
int head[maxm],t[maxm],next[maxm],sum;
int d[maxm];
int tot[maxm];
db a[maxn][maxn];
void insert(int x,int y){
t[++sum]=y;
next[sum]=head[x];
head[x]=sum;
}
void dfs(int x){
d[++d[0]]=x;
for(int tmp=head[x];tmp;tmp=next[tmp]) dfs(t[tmp]);
}
int main(){
freopen("memory.in","r",stdin);
freopen("memory.out","w",stdout);
scanf("%d",&T);
while (T--) {
mem(f,80);
mem(head,0);
sum=0;
d[0]=0;
mem(tot,0);
mem(g,0);
mem(s,0);
///
scanf("%d%d%d",&n,&m,&p);
fo(i,1,n-1) insert(i,i+1),tot[i]++;
fo(i,1,m-n) {
int x,y;
scanf("%d%d",&x,&y);
insert(x,y);
tot[x]++;
}
dfs(1);
fd(i,m,1) {
int w=d[i];
if (w>n) {
if (head[w]==0) g[w]=1;
else {
for(int tmp=head[w];tmp;tmp=next[tmp])
g[w]=g[w]+g[t[tmp]]/tot[w];
g[w]++;
}
}
else {
if (w==n) continue;
for(int tmp=head[w];tmp;tmp=next[tmp]) {
if (t[tmp]==w+1) continue;
s[w]+=g[t[tmp]];
}
}
}
fo(i,1,n) {
a[i][i]=0;
fo(j,i+1,n)
a[i][j]=a[i][j-1] * tot[j-1]+tot[j-1]+s[j-1];
}
f[1][1]=0;
fo(i,1,n) {
fo(j,1,i) {
if (f[i][j]>1e12) continue;
fo(k,i+1,min(i+40,n)) f[k][j+1]=min(f[k][j+1],f[i][j]+a[i][k]);
}
}
printf("%.4lf\n",f[n][p]);
}
return 0;
}
fanvree的方法&程序。
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
const int N=1600;
const long double eps=0.0000001;
const long long inf=10000000000000000ll;
int son[N],n,m,p,fa[N];
long double f[N],w[N][N],_k[N],_b[N];
int pre[N],tmp[N];
int check(long double mid)
{
fo(i,1,n) f[i]=inf,pre[i]=0;
f[1]=0;
fo(i,1,n-1)
fo(j,i+1,n)
if (f[i]+w[i][j]+mid<f[j]) f[j]=f[i]+w[i][j]+mid,pre[j]=i;
int num=0;
for(int now=n;now;now=pre[now]) num++;
return num;
}
int main()
{
freopen("memory.in","r",stdin);freopen("memory.out","w",stdout);
int T;
scanf("%d",&T);
while (T--)
{
scanf("%d%d%d",&n,&m,&p);
memset(son,0,sizeof son);
memset(_k,0,sizeof _k);
memset(_b,0,sizeof _b);
fo(i,1,n) son[i]++;
fo(i,1,m-n)
{
int x,y;
scanf("%d%d",&x,&y);
son[x]++;
fa[y]=x;
}
fd(i,m,n+1)
{
if (son[i]==0) _b[i]=0;
_b[i]++;
_b[fa[i]]+=_b[i]*1.0/son[fa[i]];
}
fd(i,n,1)
{
long double b=0,k=0;
fd(j,i-1,1)
{
long double kk=0,bb=0,p=1.0/son[j];
kk=k*p+(1-p);
bb=p*b+_b[j]+1;
k=kk;
b=bb;
w[j][i]=b/(1-k);
if (w[j][i]>inf) w[j][i]=inf;
}
}
long double l=0,r=0;
r=10000000ll;
long double ans=inf,ans1=inf;
while (l+eps<=r)
{
long double mid=(l+r)/2;
int num=check(mid);
if (num<=p)
{
long double sum=0;
for(int now=n;now!=1;now=pre[now]) sum+=w[pre[now]][now];
if (ans==inf || ans<sum+(num-p)*mid) ans=sum+(num-p)*mid;
if (num==p)
{
ans=sum;
break;
};
r=mid-eps;
} else l=mid+eps;
}
printf("%.4f\n",(double)ans);
}
}