Good Luck Google-code-jam 2013-Round-1A

这个题目还是很有意思,它没有给出解决问题所需的全部信息,而是只给了部分信息,来猜测正确答案是什么。在一定的概率下会猜测到正确的结果。

1 问题描述

从2到M中随机选择N个数,允许重复。
再从这N个数中随机选择它的一个子集,计算出这个子集中元素的乘积$p_1$。
重复上一步,直到得到了K个乘积$p_1,p_2,…,p_k$。
要求:给出这K个乘积,推测原来的N个数
具体的输入输出要求,不详细介绍了,到codejam网站上看去吧

2 解决方案

从2到M中选择N个数,若考虑顺序那么就有$(M-1)^N$种,若不考虑顺序就大大减少了。

对于题目的第二个输入N=12, M=8, 不考虑顺序就只有18564种,不算太大。
那么,我们的任务就是根据这K个乘积,选出概率最大的一种组合
现在我们计算在乘积为$p_1,p_2,…,p_k$条件下,原组合为C的概率: $$ P(C|p_1,p_2,...,p_k) = \frac{P(C)*P(p_1,p_2,...,p_k|C)}{P(p_1,p_2,...,p_k)} $$ 这个就是贝叶斯概率公式,一般形式写作:$ P(A|B)=\frac{P(A)P(B|A)}{P(B)} $ 机器学习中有一类重要的分类方法就是基于这个贝叶斯公式。

由于我们的目标是找出概率最大的C,所以不必把这个公式中的所有的项都计算出来,例如分母为P(p1,p2,…,pk)对于特定的K个乘积在所有的组合下都是不变的。
只需要计算$P(C)$和\(P(p_1,p_2,\ldots,p_k|C)\), 对于由于组合C,由于不考虑顺序那么,每个C被取得概率就不等了,比如234会出现6次,而333只会出现一次,计算这个P(C): $$ P(C)= \frac{N!/(c_2!c_3!\ldots c_m!)}{(M-1)^N}$$ 其中$c_2,c_3,\ldots,c_i$表示组合中数字i出现的次数

$$ P(p_1,p_2,\ldots,p_k|C) = P(p_1|C)P(p_2|C)\ldots P(p_k|C)$$

其中$P(p|C) = \frac{元素积为p的子集个数}{2^N}$
枚举C的所有子集就可以计算出上述$P(p|C)$。
现在就知道怎么实现了:对于所有的组合计算条件概率,从中选出概率最大的一个组合。 由于会有很多组乘积,需要进行推测,所以这里将会有大量的重复计算,所以,最好将P和P(p|C)全部提前计算出来。

3 实现和结果分析

 

3.1 预计算

预计算:计算出某个组合的概率,和这种组合下子集积的概率

# coding: UTF-8

import cPickle as pickle

from array import array

from sys import stdout

from sys import stderr

from random import randint



def fact(n):

    if n == 0:

        return 1

    return n * fact(n-1)



def probability(nums):

    p = fact(len(nums))

    i = 1

    c = 1

    while i < len(nums):

        if (nums[i] == nums[i-1]):

            c += 1

        else:

            p /= fact(c)

            c = 1

        i += 1

    p /= fact(c)

    return p



gcount = 0

def pre_compute(N, M):

    fo = open('dump.dat', 'w')

    nums = array('i', xrange(N))

    pset = {}



    def subset(d, p):

        if d >= N:

            if p in pset:

                pset[p] += 1

            else:

                pset[p] = 1

        else:

            subset(d+1, p)

            subset(d+1, p*nums[d])



    def search(d, lt):

        global gcount

        if d >= N:

            gcount += 1

            if (gcount % 100) == 0:

                print "generated %d.." % gcount

            pickle.dump(nums, fo)

            pickle.dump(probability(nums), fo)

            pset.clear()

            subset(0, 1)

            pickle.dump(pset, fo)

        else:

            for i in xrange(lt, M+1):

                nums[d] = i

                search(d+1, i)



    search(0, 2)

    fo.close()

预计算大概需要1m多,而载入只需要2s左右,还是节省了很多时间的

载入的代码:

def load():

    fo = open("dump.dat")

    count = 0

    table = []

    while True:

        try:

            nums=pickle.load(fo)

            prob=pickle.load(fo)

            pset=pickle.load(fo)

            table.append([nums, prob, pset])

            count += 1

            if (count % 1000) == 0:

                stderr.write("loaded %d..\n" % count)

        except EOFError:

            break

    return table

3.2 处理过程,遍历所有的组合找到概率最大的组合

def process_case(case):

    table = load()



    (R, N, M, K) = map(int, raw_input().split())

    print "Case #%d:" % case

    for r in xrange(R):

        products = map(int, raw_input().split())

        max_prob = -1

        nums = []

        for item in table:

            prob = item[1]

            pset = item[2]

            for p in products:

                if p not in pset:

                    prob = 0

                    break

                prob *= pset[p]

            if prob > max_prob:

                max_prob = prob

                nums = item[0]



        for n in nums:

            stdout.write(str(n))

        stdout.write("\n")

3.3 自己构造输入输出

其实这个题是可以自己构造输入的,然后自己写个judge的程序, 对于第二种输入,阈值要求是1120,自己构造输入,然后用上面的程序得到猜测结果, 大概猜对了1300左右。

构造输入和judge代码如下:

# 生成样例数据

def generate(R, N, M, K):

    fi = open('input.txt', 'w')

    fa = open('right-answer.txt', 'w')

    fi.write("1\n%d %d %d %d\n" % (R, N, M, K))

    nums = array('i', xrange(N))

    for i in xrange(R):

        for j in xrange(N):

            nums[j] = randint(2, M)

            fa.write("%d" % nums[j])

        fa.write('\n')

        for j in xrange(K):

            p = 1

            for n in nums:

                if randint(0, 1) == 1:

                    p *= n

            fi.write("%d " % p)

        fi.write('\n')

    fi.close()

    fa.close()



def judge(submit_file, answer_file, R):

    fs = open(submit_file)

    fa = open(answer_file)



    fs.readline()

    count = 0

    for i in xrange(R):

        ls = list(fs.readline().strip())

        la = list(fa.readline().strip())

        ls.sort()

        la.sort()

        if ls == la:

            count += 1

    fs.close()

    fa.close()

    return count

4 总结

谷歌的出题思路跟技术的发展趋势是一致的,基于统计的机器学习方法有很广泛的应用。 这个题就是使用贝叶斯公式(条件概率公式)进行推断。


ps1: 上面的python代码在我的机器上运行3m,而题目要求是四分钟内提交答案,时间有点紧,勉强够用。使用C++会更快的。不过,我后来找到了一个叫pypy的python解释器,它的速度比python官方的解释器快好多,一分钟就算出结果了。
ps2: 吐嘈下python

def f():

    x = 1

    def g():

        x += 1

    g()

调用f()就会出错,直到python3才给出个nonlocal关键字解决这个问题。
关于这点,我只想说,设计语言也太不专业了。。。

Date: 2013-05-10 Fri

Author: liyongmou

Org version 7.9.2 with Emacs version 24

Validate XHTML 1.0

 

你可能感兴趣的:(Google)