玩玩24点(中)

《玩玩24点》系列:

  • 上篇

  • 中篇

在上篇中,我用上位机程序遍历了4个1~13的数的1820种组合,通过递归穷举计算出其中1362组的24点接法,并转换为二进制形式,放到单片机程序中,减少了单片机24点游戏程序的计算量,获得了不错的游戏体验。

上篇的最后留了一个疯狂暗示,但时至如今我也没有实现出来,因为写完上篇过后一直在准备各种比赛和考试,这两天也在写AVR单片机教程,一直都没有空去管它。

写这篇中篇的原因,是几个没有作业写甚至不需要高考的同学在玩一种24点游戏的升级版——用计算器按出5个1~20的随机整数,通过四则运算获得不超过50的最大有理数。经过一整个晚自修的手算后,他们想起我之前写的24点,来问我他们算出的是不是上界。

我写算法注重可复用性,毕竟不是std::都不写的OI。于是我很快就在上次程序的基础上写成了他们要的算法。

这个程序,以及人机计算能力的对比,虽然毫无悬念,但是先放一边。我对上篇所写的内容有一些更深的思考。

算式的可读性

实际上这个24点程序还远不完美。单片机经常在屏幕上输出诡异的解法,比如10 * 12 = 120, 120 / 5 = 24,这些是不符合人类计算逻辑的,正常人想到的都是10 / 5 = 2, 2 * 12 = 24。一个可行的方法是把递归搜索的顺序换一下,先减再加,先除后乘,在除法中优先用最大的数除以最小的数。但还是会出现12 / 5 = 12/5, 12/5 * 10 = 24这样的式子,最根本的算法还是根据表达式建立树,在树上调整顺序。也许4个数算24点的情况不需要这么复杂,但这是万能的、具有可扩展性的做法(也有可能是我想多了)。

这是上篇中提出的问题与解决方案,现在我认为需要修改。

首先,对于5, 10, 12的例子,我已经找到简单方法来使程序输出符合人类逻辑的算式了:搜索顺序改为减法、加法、结果为整数的除法、乘法、结果为分数的除法(代码可以在后面的程序中找到,这里就不单独放了)。在更新算法后我试玩了几十组,发现程序给出的结果都是比较正常的,因此这个问题至少在4数24点的问题中算是解决了。

其次,作为看似更好的算法,即使我能克服学数据结构时对树的恐惧,成功地用二叉树表达了算式,“在树上调整顺序”的概念也是模糊的。用什么规则来调整呢?如果是整数优先,那么10 / 5可以保证,但是在新的游戏规则中,如果运算数是2, 3, 33,最优结果是99/2,程序会先计算33 * 3,再计算99 / 2,而我的思路会是33 * 1.5。那么这算什么规则呢?其他的情况呢?理不清。

所以,调整一下搜索顺序,见好就收吧。

4数24点的优化

一位对计算机程序一无所知的数学竞赛同学对求解24点的算法十分感兴趣。在我绞尽脑汁跟他解释通这个程序后,他认为这个算法不好,因为有大量的重复计算。

有道理。比方说1, 2, 3,原来的算法会先算1 + 2,替换为3,用3, 3递归调用,得到6,这是1 + 2 + 3,然后还有1 + 3 + 22 + 3 + 11, 2, 3, 4就更多了。

他提出“分治”的策略:24一定是由两个中间结果加减乘除得到的,而每个中间结果也都是由两个运算数得到的。在为他凭空想出分治而震惊之余,我指出这是错的,这很显然。

但这个想法还是有一定启发性的。为了优化4数24点的求解算法,我想还不如枚举出所有可能的运算结构算了:

  1. a * b * c * d

  2. a + b + c + d

  3. a * b + c + d

  4. a * b * (c + d)

  5. a * b * c + d

  6. a * (b + c + d)

  7. a * b + c * d

  8. (a + b) * (c + d)

  9. (a * b + c) * d

  10. (a + b) * c + d

其中+代表加或减,*代表乘或除。偶数序号的结构都是前一个奇数序号结构的对偶,指把加减与乘除互换,加括号保证原有的优先级。

inline bool read_bit(int c, int b)
{
    return c & (1 << b);
}

class fast_vector
{
public:
    void push_back(const Rational& r)
    {
        data[size++] = r;
    }
    Rational* begin()
    {
        return data;
    }
    Rational* end()
    {
        return data + size;
    }
private:
    Rational data[1 << max_count];
    int size = 0;
};

using vector_type = fast_vector;

void all_sum(const std::vector& data, vector_type& result)
{
    auto end = (1 << data.size()) - 1;
    for (int c = 0; c != end; ++c)
    {
        Rational sum = 0;
        bool valid = true;
        for (int b = 0; b != data.size(); ++b)
            if (!read_bit(c, b))
                sum += data[b];
        for (int b = 0; b != data.size(); ++b)
            if (read_bit(c, b))
            {
                if (sum < data[b])
                {
                    valid = false;
                    break;
                }
                sum -= data[b];
            }
        if (valid)
            result.push_back(sum);
    }
}

void all_pro(const std::vector& data, vector_type& result)
{
    auto end = (1 << data.size()) - 1;
    for (int c = 0; c != end; ++c)
    {
        Rational pro = 1;
        bool valid = true;
        for (int b = 0; b != data.size(); ++b)
        {
            if (read_bit(c, b))
            {
                if (data[b] == 0)
                {
                    valid = false;
                    break;
                }
                pro /= data[b];
            }
            else
                pro *= data[b];
        }
        if (valid)
            result.push_back(pro);
    }
}

bool test_sum(const Rational& lhs, const Rational& rhs)
{
    if (lhs + rhs == target)
        return true;
    if (lhs < rhs && rhs - lhs == target)
        return true;
    if (rhs < lhs && lhs - rhs == target)
        return true;
    return false;
}

bool test_pro(const Rational& lhs, const Rational& rhs)
{
    if (lhs * rhs == target)
        return true;
    if (rhs != 0 && rhs / lhs == target)
        return true;
    if (lhs != 0 && lhs / rhs == target)
        return true;
    return false;
}

bool solve(int a, int b, int c, int d)
{
    std::vector data(4);
    data[0] = a;
    data[1] = b;
    data[2] = c;
    data[3] = d;

    // a * b * c * d
    {
        vector_type pro;
        all_pro(data, pro);
        for (const auto& r : pro)
            if (r == target)
                return true;
    }

    // a + b + c + d
    {
        vector_type sum;
        all_sum(data, sum);
        for (const auto& r : sum)
            if (r == target)
                return true;
    }

    // a * b + c + d
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto pm = data;
            pm.erase(pm.begin() + j);
            pm.erase(pm.begin() + i);
            std::vector md{ data[i], data[j] };
            vector_type pro;
            all_pro(md, pro);
            for (const auto& r : pro)
            {
                pm.push_back(r);
                vector_type sum;
                all_sum(pm, sum);
                for (const auto& r : sum)
                    if (r == target)
                        return true;
                pm.pop_back();
            }
        }

    // a * b * (c + d)
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto md = data;
            md.erase(md.begin() + j);
            md.erase(md.begin() + i);
            std::vector pm{ data[i], data[j] };
            vector_type sum;
            all_sum(pm, sum);
            for (const auto& r : sum)
            {
                md.push_back(r);
                vector_type pro;
                all_pro(md, pro);
                for (const auto& r : pro)
                    if (r == target)
                        return true;
                md.pop_back();
            }
        }

    // a * b * c + d
    for (int i = 0; i != 4; ++i)
    {
        auto md = data;
        md.erase(md.begin() + i);
        vector_type pro;
        all_pro(md, pro);
        for (const auto& r : pro)
            if (test_sum(data[i], r))
                return true;
    }

    // a * (b + c + d)
    for (int i = 0; i != 4; ++i)
    {
        auto pm = data;
        pm.erase(pm.begin() + i);
        vector_type sum;
        all_sum(pm, sum);
        for (const auto& r : sum)
            if (test_pro(data[i], r))
                return true;
    }

    // a * b + c * d
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto md2 = data;
            md2.erase(md2.begin() + j);
            md2.erase(md2.begin() + i);
            decltype(md2) md1{ data[i], data[j] };
            vector_type pro1, pro2;
            all_pro(md1, pro1);
            all_pro(md2, pro2);
            for (const auto& r1 : pro1)
                for (const auto& r2 : pro2)
                    if (test_sum(r1, r2))
                        return true;
        }

    // (a + b) * (c + d)
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto pm2 = data;
            pm2.erase(pm2.begin() + j);
            pm2.erase(pm2.begin() + i);
            decltype(pm2) pm1{ data[i], data[j] };
            vector_type sum1, sum2;
            all_sum(pm1, sum1);
            all_sum(pm2, sum2);
            for (const auto& r1 : sum1)
                for (const auto& r2 : sum2)
                    if (test_pro(r1, r2))
                        return true;
        }

    // (a * b + c) * d
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto rest = data;
            rest.erase(rest.begin() + j);
            rest.erase(rest.begin() + i);
            std::vector md{ data[i], data[j] };
            vector_type pro;
            all_pro(md, pro);
            for (const auto& r : pro)
            {
                for (int k = 0; k != 2; ++k)
                {
                    std::vector pm{ r, rest[k] };
                    vector_type sum;
                    all_sum(pm, sum);
                    for (const auto& r : sum)
                        if (test_pro(r, rest[1 - k]))
                            return true;
                }
            }
        }

    // (a + b) * c + d
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto rest = data;
            rest.erase(rest.begin() + j);
            rest.erase(rest.begin() + i);
            std::vector pm{ data[i], data[j] };
            vector_type sum;
            all_sum(pm, sum);
            for (const auto& r : sum)
            {
                for (int k = 0; k != 2; ++k)
                {
                    std::vector md{ r, rest[k] };
                    vector_type pro;
                    all_pro(md, pro);
                    for (const auto& r : pro)
                        if (test_sum(r, rest[1 - k]))
                            return true;
                }
            }
        }

    return false;
}

int main()
{
    auto start_time = std::clock();
    int count = 0;
    for (int a = 1; a <= max_num; ++a)
        for (int b = a; b <= max_num; ++b)
            for (int c = b; c <= max_num; ++c)
                for (int d = c; d <= max_num; ++d)
                    if (solve(a, b, c, d))
                        ++count;
    std::cout << count << std::endl;
    std::cout << (static_cast(std::clock()) - start_time) * 1000
        / CLOCKS_PER_SEC << "ms" << std::endl;
    return 0;
}

IntegerintRationalExpression的定义见上篇。

原算法没有使用std::vector数据结构,由于STL的糟糕性能,我写了个不涉及动态内存分配的fast_vector来替换存储运算结果的std::vector;运算数的懒得改了。

算法的核心在于all_sum函数,用于求出data数组中的元素通过加减法可以得到的所有结果:

void all_sum(const std::vector& data, vector_type& result)
{
    auto end = (1 << data.size()) - 1;
    for (int c = 0; c != end; ++c)
    {
        Rational sum = 0;
        bool valid = true;
        for (int b = 0; b != data.size(); ++b)
            if (!read_bit(c, b))
                sum += data[b];
        for (int b = 0; b != data.size(); ++b)
            if (read_bit(c, b))
            {
                if (sum < data[b])
                {
                    valid = false;
                    break;
                }
                sum -= data[b];
            }
        if (valid)
            result.push_back(sum);
    }
}

函数用一个整数c表示data数组中各元素取加号还是减号,当二进制c的第b位为0时(最低位为第0位),下标为b的元素取加号,否则取减号;c取不到0b11...1data.size()1),是因为不能所有元素都取减号。对于每个c,如果算出来的值是有效的,就把它追加到结果的数组中去。我把返回值写成了引用参数,虽然编译器很可能RVO(返回值优化),我还是手动写出来以明确我提升性能的意图。

all_pro函数类似,只不过计算的是积与商。

程序在VS2019中编译,配置为Release、x86,在没插电的最节能配置下的i7-7700HQ上测试,从命令行调用,优化算法的平均运行时间为55ms,而原算法为82ms,是有明显提升的。

概率问题

在一篇研究24点游戏的文章中,有这样一句话:

其实还有一个原因,就是有解的概率太小了。4个数字的话也就大约80%的题能算,如果算上人头牌,可解的题就只有75%了。

没错,在1820种可能的4数组合中,有1362种有解,比例为74.8%。

但是注意,我说的是“比例”而不是“概率”,这两者是有区别的。要计算“有解的概率”,必须先确定出题的方式。

如果是从1820道题目的题库中等概率地选择一道,类似与上篇中提到的单片机程序一样,这样每一道题被选中都是古典概型中的基本事件,有解概率就是74.8%。

如果是从52张扑克牌中等概率地选择4张,那么概率就不是74.8%,因为每一种题目出现的概率是不相等的。比如,6, 6, 6, 6出现的概率为\(1 / C_{52}^{4}\),而1, 2, 3, 4出现的概率为\(4! / C_{52}^{4}\),两者相差24倍。每一种4数的有序排列都是古典概型中的基本事件,有解概率需要重新计算。

std::set> solution;
int solved = 0;
int total = 0;
int card[4];
std::vector comb(4);
for (card[0] = 0; card[0] != 49; ++card[0])
    for (card[1] = card[0] + 1; card[1] != 50; ++card[1])
        for (card[2] = card[1] + 1; card[2] != 51; ++card[2])
            for (card[3] = card[2] + 1; card[3] != 52; ++card[3])
            {
                ++total;
                for (int i = 0; i != 4; ++i)
                    comb[i] = card[i] / 4 + 1;
                if (solution.find(comb) != solution.end())
                    ++solved;
            }
std::cout << solved << " / " << total << std::endl;

其中,solution已经保存了有解的4数组合。程序的输出为:

217817 / 270725

这个比例为80.5%,也是这种模型下有解的概率。

新款50点游戏

50点游戏的规则是,用5个1~20的整数通过四则运算得到不超过50的最大有理数。

为什么是50呢?如果是48的话,我想你也会问为什么是48的。唯一的一点道理,他们说,是这样比较考验一个人对数字的感觉。

上回4数的算法并不局限于4数,参数都可以通过全局变量来调整。50点相对于24点还改变了输出结果的规则,但只需要修改递归出口的条件和操作。在那个程序的基础上,50点很快就写好了。

// return whether the branch has found a better solution
bool solve(Integer count, const Rational* data, const Rational target, Rational* max, Expression* expr)
{
    // assume data is ordered
    if (count == 1)
    {
        if (*data <= target && *data > *max)
        {
            *max = *data;
            return true;
        }
        else
            return false;
    }
    auto end = data + count;
    auto before_end = end - 1;
    --count;
    Rational new_data[max_count - 1];
    auto new_end = new_data + count;
    bool optimize = false;

    // -
    for (auto lhs = data + 1; lhs != end; ++lhs)
        for (auto rhs = data; rhs != lhs; ++rhs)
        {
            auto dst = new_data;
            for (auto src = data; src != end; ++src)
                if (src != lhs && src != rhs)
                    *dst++ = *src;
            *dst = *lhs - *rhs;
            Expression temp(*lhs, '-', *rhs, *dst);
            if (temp.rhs == 0)
            {
                std::swap(temp.lhs, temp.rhs);
                temp.op = '+';
            }
            std::sort(new_data, new_end);
            if (solve(count, new_data, target, max, expr + 1))
            {
                optimize = true;
                *expr = temp;
            }
        }

    // +
    for (auto lhs = data; lhs != before_end; ++lhs)
        for (auto rhs = lhs + 1; rhs != end; ++rhs)
        {
            auto dst = new_data;
            for (auto src = data; src != end; ++src)
                if (src != lhs && src != rhs)
                    *dst++ = *src;
            *dst = *lhs + *rhs;
            Expression temp(*lhs, '+', *rhs, *dst);
            std::sort(new_data, new_end);
            if (solve(count, new_data, target, max, expr + 1))
            {
                optimize = true;
                *expr = temp;
            }
        }

    // / integer
    struct
    {
        const Rational* lhs;
        const Rational* rhs;
        Rational res;
    } div_frac[max_count * (max_count - 1)];
    Integer frac_size = 0;
    for (auto lhs = before_end; lhs != data - 1; --lhs)
        for (auto rhs = data; rhs != end; ++rhs)
        {
            if (lhs == rhs || *rhs == Rational(0))
                continue;
            auto res = *lhs / *rhs;
            if (res.den != 1)
            {
                div_frac[frac_size].lhs = lhs;
                div_frac[frac_size].rhs = rhs;
                div_frac[frac_size++].res = res;
                continue;
            }
            auto dst = new_data;
            for (auto src = data; src != end; ++src)
                if (src != lhs && src != rhs)
                    *dst++ = *src;
            *dst = res;
            Expression temp(*lhs, '/', *rhs, *dst);
            if (temp.rhs == 1)
            {
                if (Rational(1) < temp.lhs)
                    std::swap(temp.lhs, temp.rhs);
                temp.op = '*';
            }
            std::sort(new_data, new_end);
            if (solve(count, new_data, target, max, expr + 1))
            {
                optimize = true;
                *expr = temp;
            }
        }

    // *
    for (auto lhs = data; lhs != before_end; ++lhs)
        for (auto rhs = lhs + 1; rhs != end; ++rhs)
        {
            auto dst = new_data;
            for (auto src = data; src != end; ++src)
                if (src != lhs && src != rhs)
                    *dst++ = *src;
            *dst = *lhs * *rhs;
            Expression temp(*lhs, '*', *rhs, *dst);
            std::sort(new_data, new_end);
            if (solve(count, new_data, target, max, expr + 1))
            {
                optimize = true;
                *expr = temp;
            }
        }

    // / fraction
    for (Integer i = 0; i != frac_size; ++i)
    {
        auto dst = new_data;
        for (auto src = data; src != end; ++src)
            if (src != div_frac[i].lhs && src != div_frac[i].rhs)
                *dst++ = *src;
        *dst = div_frac[i].res;
        Expression temp(*div_frac[i].lhs, '/', *div_frac[i].rhs, *dst);
        std::sort(new_data, new_end);
        if (solve(count, new_data, target, max, expr + 1))
        {
            optimize = true;
            *expr = temp;
        }
    }

    return optimize;
}

Rational test(Rational* operand, const Rational target, std::ostream& os = std::cout)
{
    Expression expr[4];
    Rational result = 0;
    solve(5, operand, target, &result, expr);
    for (int i = 0; i != 4; ++i)
        os << operand[i] << ", ";
    os << operand[4] << ": ";
    os << result;
    if (result.den != 1)
        os << " = " << (double)result.num / result.den;
    os << std::endl;
    for (const auto& e : expr)
        os << '\t' << e << std::endl;
    return result;
}

同样,Integerint的类型别名,使程序可以处理100以内的整数;RationalExpression类的定义见上篇,不过去掉了对除法和取模运算的计数。

DFS依然有些难理解。24点中solve函数返回该路径下是否能计算出24,如果得到true,则调用者solve本身把当前操作的表达式写入expr数组,并直接return true,一路返回到test,并输出解法。但是,50点不能把50设置为唯一的目标,而是在每次获得结果时更新最优解。solve函数返回该路径下能否找到更优的解,如果为true,则调用者solve同样把当前操作的表达式写入expr数组,但不返回,而是继续试探下一路径。

不返回是比较好理解的,因为找到的不一定是最优解。不过如果找到50,则可以一路返回到底,避免不必要的搜索。由于可以算出50的输入占一大部分,这种优化可以显著加速全部输入的穷举,which原本需要十分钟。不过这一点是我一分钟前刚想出来的,还没放进代码。

无论递归深度,当路径中有更优解时,就立即更新expr数组,是正确的算法。这是因为,每一层的递归都只负责一个Expression空间,不同层互不干扰,因此这个写入只会覆盖本次调用中上一次写入或之前调用写入的表达式,其对应的结果没有当前找到的优,因此可以放心覆盖。由于从递归出口到最初调用的每一层调用都能得知这个最优解,因此最后获得的表达式是完整的。

对于单组输入,这个算法是NP的,对于所有输入而言更是。所以运算数个数、范围和运算符都受到严格限制,而且我感觉这个问题不会有P的算法。

5个1~20的数共有\(C_{20 + 5 - 1}^{5} = 42504\)种组合(没有同一个数最多4个的限制),全部求解一遍需要十分钟。下篇应该会解决一个规则更复杂的问题,由于5数50点已经跑得够慢了,我决定这个寒假里学习并发。

你可能感兴趣的:(玩玩24点(中))