POJ 3156 Interconnect 图论+数论
题目的意思是:
给一个有n个点,m条边的无向图
两点之间可以存在多条边
现在每次随机增加一条边
问使得全部点都连通需要增加多少次(期望值)
首先,求出所有连通分量。用并查集。
每次随机增加一条边的时候一共有两种情况:
1)这条边连接了两个不同的连通分量,它的概率是p
2)这条边在一个连通分量里,它的概率是q = 1 - p
前者可以改变连通分量的数量,后者不能
如果把当前图的状态视为一个子问题
那么就可以用动态规划解决问题了
图的状态可以表示为:有多少个连通分量,每个连通分量包含多少个点
比如说图的状态 (2, 3, 3) 表示有三个连通分量,每个连通分量包含的点的个数分别为 2, 3, 3
动态规划的转移方程为:
f = p*(1+r) + p*q*(2+r) + p*q^2*(3+r) ....
其中r为p发生后,新状态的期望值
这个东西高中的时候学过,呵呵。
而1)中也包含多种情况,需要两两枚举
最大的问题是,f的值是一个无限数列,它的极值很难求。但无论如何,有高手求出来了。。在这里:http://archive.cnblogs.com/a/1325929/
它的极值是 f = p * (1 / (1 - q) + r) / (1 - q)
我对照了一下标程,确实是这个。
后来我自己推导了一下,发现它可以化成多个等比数列相加的形式,求和后,发现当n趋向于无穷大的时候,它的极限就是上面这个公式。
(注意:i*q^i, 当0<q<1,i趋向于无穷大的时候等于0)
这样程序就可以写了。动态规划保存每个图的状态。
如果用python写,只要建立一个tuple到float的映射就可以了。非常方便。
java中也有List<int>到Double的映射。
c里面估计就得用hash了。
py代码,参照标程写的。
fi = open( ' in ' )
fo = open( ' out ' )
dp = {():0}
ti = 0
def get(s):
if s in dp:
return dp[s]
q = sum([i * (i - 1 ) for i in s]) * 1.0 / 2 / nn
res = 0
for i in range(len(s)):
for j in range(len(s)):
if i < j:
l = list(s)
del l[max(i,j)]
del l[min(i,j)]
l.append(s[i] + s[j])
l.sort()
r = get(tuple(l))
p = s[i] * s[j] * 1.0 / nn
res += p * ( 1 + r - r * q) / pow( 1 - q, 2 )
dp[s] = res
return res
while 1 :
a = fi.readline().split()
if a == None or len(a) != 2 :
break
N, M = int(a[0]), int(a[ 1 ])
nn = N * (N - 1 ) / 2
s = [ i for i in range(N) ]
for i in range(M):
u, v = [ int(i) for i in fi.readline().split() ]
u -= 1
v -= 1
k = s[u]
for j in range(N):
if s[j] == k:
s[j] = s[v]
ss = [ s.count(i) for i in set(s) ]
ss.sort()
print ' ---- ' , ti
mine = get(tuple(ss))
ans = float(fo.readline().strip())
print ' mine ' , mine, ' ans ' , ans
print len(dp)
ti += 1
标程
用很简洁的代码写了并查集,值得借鉴!
给一个有n个点,m条边的无向图
两点之间可以存在多条边
现在每次随机增加一条边
问使得全部点都连通需要增加多少次(期望值)
首先,求出所有连通分量。用并查集。
每次随机增加一条边的时候一共有两种情况:
1)这条边连接了两个不同的连通分量,它的概率是p
2)这条边在一个连通分量里,它的概率是q = 1 - p
前者可以改变连通分量的数量,后者不能
如果把当前图的状态视为一个子问题
那么就可以用动态规划解决问题了
图的状态可以表示为:有多少个连通分量,每个连通分量包含多少个点
比如说图的状态 (2, 3, 3) 表示有三个连通分量,每个连通分量包含的点的个数分别为 2, 3, 3
动态规划的转移方程为:
f = p*(1+r) + p*q*(2+r) + p*q^2*(3+r) ....
其中r为p发生后,新状态的期望值
这个东西高中的时候学过,呵呵。
而1)中也包含多种情况,需要两两枚举
最大的问题是,f的值是一个无限数列,它的极值很难求。但无论如何,有高手求出来了。。在这里:http://archive.cnblogs.com/a/1325929/
它的极值是 f = p * (1 / (1 - q) + r) / (1 - q)
我对照了一下标程,确实是这个。
后来我自己推导了一下,发现它可以化成多个等比数列相加的形式,求和后,发现当n趋向于无穷大的时候,它的极限就是上面这个公式。
(注意:i*q^i, 当0<q<1,i趋向于无穷大的时候等于0)
这样程序就可以写了。动态规划保存每个图的状态。
如果用python写,只要建立一个tuple到float的映射就可以了。非常方便。
java中也有List<int>到Double的映射。
c里面估计就得用hash了。
py代码,参照标程写的。
fi = open( ' in ' )
fo = open( ' out ' )
dp = {():0}
ti = 0
def get(s):
if s in dp:
return dp[s]
q = sum([i * (i - 1 ) for i in s]) * 1.0 / 2 / nn
res = 0
for i in range(len(s)):
for j in range(len(s)):
if i < j:
l = list(s)
del l[max(i,j)]
del l[min(i,j)]
l.append(s[i] + s[j])
l.sort()
r = get(tuple(l))
p = s[i] * s[j] * 1.0 / nn
res += p * ( 1 + r - r * q) / pow( 1 - q, 2 )
dp[s] = res
return res
while 1 :
a = fi.readline().split()
if a == None or len(a) != 2 :
break
N, M = int(a[0]), int(a[ 1 ])
nn = N * (N - 1 ) / 2
s = [ i for i in range(N) ]
for i in range(M):
u, v = [ int(i) for i in fi.readline().split() ]
u -= 1
v -= 1
k = s[u]
for j in range(N):
if s[j] == k:
s[j] = s[v]
ss = [ s.count(i) for i in set(s) ]
ss.sort()
print ' ---- ' , ti
mine = get(tuple(ss))
ans = float(fo.readline().strip())
print ' mine ' , mine, ' ans ' , ans
print len(dp)
ti += 1
标程
用很简洁的代码写了并查集,值得借鉴!
import
java.util.
*
;
import java.io.File;
import java.io.PrintWriter;
import java.io.FileNotFoundException;
public class interconnect_pm {
private static int nn;
public static void main(String[] args) throws FileNotFoundException {
Scanner in = new Scanner( new File( " in " ));
PrintWriter out = new PrintWriter( " ans.out " );
int n = in.nextInt();
nn = (n * (n - 1 )) / 2 ;
int m = in.nextInt();
int [] p = new int [n];
for ( int i = 0 ; i < n; i ++ ) p[i] = i;
for ( int i = 0 ; i < m; i ++ ) {
int u = in.nextInt();
int v = in.nextInt();
u -- ;
v -- ;
int k = p[u];
for ( int j = 0 ; j < n; j ++ ) {
if (p[j] == k) {
p[j] = p[v];
}
}
}
List < Integer > st = new ArrayList < Integer > ();
for ( int i = 0 ; i < n; i ++ ) {
int s = 0 ;
for ( int j = 0 ; j < n; j ++ ) {
if (p[j] == i) s ++ ;
}
if (s > 0 ) {
st.add(s);
}
}
Collections.sort(st);
List < Integer > fn = new ArrayList < Integer > ();
fn.add(n);
mem.put(fn, 0.0 );
out.println(get(st));
System.out.println(mem.size());
out.close();
}
static Map < List < Integer > , Double > mem = new HashMap < List < Integer > , Double > ();
private static double get(List < Integer > st) {
Double ret = mem.get(st);
if (ret != null ) return ret;
int m = st.size();
int [][] a = new int [m][m];
for ( int i = 0 ; i < m; i ++ ) {
for ( int j = i + 1 ; j < m; j ++ ) {
a[i][j] = st.get(i) * st.get(j);
}
}
int s = 0 ;
for ( int i = 0 ; i < m; i ++ ) {
s += st.get(i) * (st.get(i) - 1 ) / 2 ;
}
double res = 0 ;
for ( int i = 0 ; i < m; i ++ ) {
for ( int j = i + 1 ; j < m; j ++ ) {
List < Integer > ss = new ArrayList < Integer > (st.size() - 1 );
boolean q = true ;
int z = st.get(i) + st.get(j);
for ( int k = 0 ; k < m; k ++ ) {
if (k != i && k != j) {
int zz = st.get(k);
if (q && zz >= z) {
q = false ;
ss.add(z);
}
ss.add(zz);
}
}
if (q)
ss.add(z);
double p = a[i][j] * 1.0 / (nn - s);
double e = a[i][j] * 1.0 / (( 1 - s * 1.0 / nn) * ( 1 - s * 1.0 / nn) * nn);
e = e + get(ss) * p;
res += e;
}
}
System.out.println(st);
mem.put(st, res);
return res;
}
}
import java.io.File;
import java.io.PrintWriter;
import java.io.FileNotFoundException;
public class interconnect_pm {
private static int nn;
public static void main(String[] args) throws FileNotFoundException {
Scanner in = new Scanner( new File( " in " ));
PrintWriter out = new PrintWriter( " ans.out " );
int n = in.nextInt();
nn = (n * (n - 1 )) / 2 ;
int m = in.nextInt();
int [] p = new int [n];
for ( int i = 0 ; i < n; i ++ ) p[i] = i;
for ( int i = 0 ; i < m; i ++ ) {
int u = in.nextInt();
int v = in.nextInt();
u -- ;
v -- ;
int k = p[u];
for ( int j = 0 ; j < n; j ++ ) {
if (p[j] == k) {
p[j] = p[v];
}
}
}
List < Integer > st = new ArrayList < Integer > ();
for ( int i = 0 ; i < n; i ++ ) {
int s = 0 ;
for ( int j = 0 ; j < n; j ++ ) {
if (p[j] == i) s ++ ;
}
if (s > 0 ) {
st.add(s);
}
}
Collections.sort(st);
List < Integer > fn = new ArrayList < Integer > ();
fn.add(n);
mem.put(fn, 0.0 );
out.println(get(st));
System.out.println(mem.size());
out.close();
}
static Map < List < Integer > , Double > mem = new HashMap < List < Integer > , Double > ();
private static double get(List < Integer > st) {
Double ret = mem.get(st);
if (ret != null ) return ret;
int m = st.size();
int [][] a = new int [m][m];
for ( int i = 0 ; i < m; i ++ ) {
for ( int j = i + 1 ; j < m; j ++ ) {
a[i][j] = st.get(i) * st.get(j);
}
}
int s = 0 ;
for ( int i = 0 ; i < m; i ++ ) {
s += st.get(i) * (st.get(i) - 1 ) / 2 ;
}
double res = 0 ;
for ( int i = 0 ; i < m; i ++ ) {
for ( int j = i + 1 ; j < m; j ++ ) {
List < Integer > ss = new ArrayList < Integer > (st.size() - 1 );
boolean q = true ;
int z = st.get(i) + st.get(j);
for ( int k = 0 ; k < m; k ++ ) {
if (k != i && k != j) {
int zz = st.get(k);
if (q && zz >= z) {
q = false ;
ss.add(z);
}
ss.add(zz);
}
}
if (q)
ss.add(z);
double p = a[i][j] * 1.0 / (nn - s);
double e = a[i][j] * 1.0 / (( 1 - s * 1.0 / nn) * ( 1 - s * 1.0 / nn) * nn);
e = e + get(ss) * p;
res += e;
}
}
System.out.println(st);
mem.put(st, res);
return res;
}
}