这个题目还是很有意思,它没有给出解决问题所需的全部信息,而是只给了部分信息,来猜测正确答案是什么。在一定的概率下会猜测到正确的结果。
从2到M中随机选择N个数,允许重复。
再从这N个数中随机选择它的一个子集,计算出这个子集中元素的乘积$p_1$。
重复上一步,直到得到了K个乘积$p_1,p_2,…,p_k$。
要求:给出这K个乘积,推测原来的N个数
具体的输入输出要求,不详细介绍了,到codejam网站上看去吧
从2到M中选择N个数,若考虑顺序那么就有$(M-1)^N$种,若不考虑顺序就大大减少了。
对于题目的第二个输入N=12, M=8, 不考虑顺序就只有18564种,不算太大。
那么,我们的任务就是根据这K个乘积,选出概率最大的一种组合
现在我们计算在乘积为$p_1,p_2,…,p_k$条件下,原组合为C的概率: $$ P(C|p_1,p_2,...,p_k) = \frac{P(C)*P(p_1,p_2,...,p_k|C)}{P(p_1,p_2,...,p_k)} $$ 这个就是贝叶斯概率公式,一般形式写作:$ P(A|B)=\frac{P(A)P(B|A)}{P(B)} $ 机器学习中有一类重要的分类方法就是基于这个贝叶斯公式。
由于我们的目标是找出概率最大的C,所以不必把这个公式中的所有的项都计算出来,例如分母为P(p1,p2,…,pk)对于特定的K个乘积在所有的组合下都是不变的。
只需要计算$P(C)$和\(P(p_1,p_2,\ldots,p_k|C)\), 对于由于组合C,由于不考虑顺序那么,每个C被取得概率就不等了,比如234会出现6次,而333只会出现一次,计算这个P(C): $$ P(C)= \frac{N!/(c_2!c_3!\ldots c_m!)}{(M-1)^N}$$ 其中$c_2,c_3,\ldots,c_i$表示组合中数字i出现的次数
其中$P(p|C) = \frac{元素积为p的子集个数}{2^N}$
枚举C的所有子集就可以计算出上述$P(p|C)$。
现在就知道怎么实现了:对于所有的组合计算条件概率,从中选出概率最大的一个组合。 由于会有很多组乘积,需要进行推测,所以这里将会有大量的重复计算,所以,最好将P和P(p|C)全部提前计算出来。
预计算:计算出某个组合的概率,和这种组合下子集积的概率
# coding: UTF-8 import cPickle as pickle from array import array from sys import stdout from sys import stderr from random import randint def fact(n): if n == 0: return 1 return n * fact(n-1) def probability(nums): p = fact(len(nums)) i = 1 c = 1 while i < len(nums): if (nums[i] == nums[i-1]): c += 1 else: p /= fact(c) c = 1 i += 1 p /= fact(c) return p gcount = 0 def pre_compute(N, M): fo = open('dump.dat', 'w') nums = array('i', xrange(N)) pset = {} def subset(d, p): if d >= N: if p in pset: pset[p] += 1 else: pset[p] = 1 else: subset(d+1, p) subset(d+1, p*nums[d]) def search(d, lt): global gcount if d >= N: gcount += 1 if (gcount % 100) == 0: print "generated %d.." % gcount pickle.dump(nums, fo) pickle.dump(probability(nums), fo) pset.clear() subset(0, 1) pickle.dump(pset, fo) else: for i in xrange(lt, M+1): nums[d] = i search(d+1, i) search(0, 2) fo.close()
预计算大概需要1m多,而载入只需要2s左右,还是节省了很多时间的
载入的代码:
def load(): fo = open("dump.dat") count = 0 table = [] while True: try: nums=pickle.load(fo) prob=pickle.load(fo) pset=pickle.load(fo) table.append([nums, prob, pset]) count += 1 if (count % 1000) == 0: stderr.write("loaded %d..\n" % count) except EOFError: break return table
def process_case(case): table = load() (R, N, M, K) = map(int, raw_input().split()) print "Case #%d:" % case for r in xrange(R): products = map(int, raw_input().split()) max_prob = -1 nums = [] for item in table: prob = item[1] pset = item[2] for p in products: if p not in pset: prob = 0 break prob *= pset[p] if prob > max_prob: max_prob = prob nums = item[0] for n in nums: stdout.write(str(n)) stdout.write("\n")
其实这个题是可以自己构造输入的,然后自己写个judge的程序, 对于第二种输入,阈值要求是1120,自己构造输入,然后用上面的程序得到猜测结果, 大概猜对了1300左右。
构造输入和judge代码如下:
# 生成样例数据 def generate(R, N, M, K): fi = open('input.txt', 'w') fa = open('right-answer.txt', 'w') fi.write("1\n%d %d %d %d\n" % (R, N, M, K)) nums = array('i', xrange(N)) for i in xrange(R): for j in xrange(N): nums[j] = randint(2, M) fa.write("%d" % nums[j]) fa.write('\n') for j in xrange(K): p = 1 for n in nums: if randint(0, 1) == 1: p *= n fi.write("%d " % p) fi.write('\n') fi.close() fa.close() def judge(submit_file, answer_file, R): fs = open(submit_file) fa = open(answer_file) fs.readline() count = 0 for i in xrange(R): ls = list(fs.readline().strip()) la = list(fa.readline().strip()) ls.sort() la.sort() if ls == la: count += 1 fs.close() fa.close() return count
谷歌的出题思路跟技术的发展趋势是一致的,基于统计的机器学习方法有很广泛的应用。 这个题就是使用贝叶斯公式(条件概率公式)进行推断。
ps1: 上面的python代码在我的机器上运行3m,而题目要求是四分钟内提交答案,时间有点紧,勉强够用。使用C++会更快的。不过,我后来找到了一个叫pypy的python解释器,它的速度比python官方的解释器快好多,一分钟就算出结果了。
ps2: 吐嘈下python
def f(): x = 1 def g(): x += 1 g()
调用f()就会出错,直到python3才给出个nonlocal关键字解决这个问题。
关于这点,我只想说,设计语言也太不专业了。。。