Lua中 table.sort排序算法原理详解2022

table.sort的介绍

table.sort是lua内置的一个对table排序的函数(仅对数组),函数原型:
table.sort (list [, comp])
list:待排序的数组
comp:比较函数,接受list的两个元素,返回true/false

table.sort的算法原理

table.sort的排序包含了快速排序的迭代实现和递归实现两种思路
算法源码在lua源码的ltablib.c文件

static void auxsort (lua_State *L, IdxT lo, IdxT up,
                                   unsigned int rnd) {
  while (lo < up) {  /* loop for tail recursion */
    IdxT p;  /* Pivot index */
    IdxT n;  /* to be used later */
    /* sort elements 'lo', 'p', and 'up' */
    lua_geti(L, 1, lo);
    lua_geti(L, 1, up);
    if (sort_comp(L, -1, -2))  /* a[up] < a[lo]? */
      set2(L, lo, up);  /* swap a[lo] - a[up] */
    else
      lua_pop(L, 2);  /* remove both values */
    if (up - lo == 1)  /* only 2 elements? */
      return;  /* already sorted */
    if (up - lo < RANLIMIT || rnd == 0)  /* small interval or no randomize? */
      p = (lo + up)/2;  /* middle element is a good pivot */
    else  /* for larger intervals, it is worth a random pivot */
      p = choosePivot(lo, up, rnd);
    lua_geti(L, 1, p);
    lua_geti(L, 1, lo);
    if (sort_comp(L, -2, -1))  /* a[p] < a[lo]? */
      set2(L, p, lo);  /* swap a[p] - a[lo] */
    else {
      lua_pop(L, 1);  /* remove a[lo] */
      lua_geti(L, 1, up);
      if (sort_comp(L, -1, -2))  /* a[up] < a[p]? */
        set2(L, p, up);  /* swap a[up] - a[p] */
      else
        lua_pop(L, 2);
    }
    if (up - lo == 2)  /* only 3 elements? */
      return;  /* already sorted */
    lua_geti(L, 1, p);  /* get middle element (Pivot) */
    lua_pushvalue(L, -1);  /* push Pivot */
    lua_geti(L, 1, up - 1);  /* push a[up - 1] */
    set2(L, p, up - 1);  /* swap Pivot (a[p]) with a[up - 1] */
    p = partition(L, lo, up);
    /* a[lo .. p - 1] <= a[p] == P <= a[p + 1 .. up] */
    if (p - lo < up - p) {  /* lower interval is smaller? */
      auxsort(L, lo, p - 1, rnd);  /* call recursively for lower interval */
      n = p - lo;  /* size of smaller interval */
      lo = p + 1;  /* tail call for [p + 1 .. up] (upper interval) */
    }
    else {
      auxsort(L, p + 1, up, rnd);  /* call recursively for upper interval */
      n = up - p;  /* size of smaller interval */
      up = p - 1;  /* tail call for [lo .. p - 1]  (lower interval) */
    }
    if ((up - lo) / 128 > n) /* partition too imbalanced? */
      rnd = l_randomizePivot();  /* try a new randomization */
  }  /* tail call auxsort(L, lo, up, rnd) */
}

table.sort的本质是快速排序,对常规的快速排序增加了如下优化:
1.保证数组 arr[1] <= arr[p] <= arr[len],p:参照值下标,len:数组大小
2.对只有2个或者3个元素的序列,直接比较大小交换后返回
3.根据参照值区分元素后,元素少的半边继续递归排序,元素多的半边则继续循环排序,减少递归次数
4.试图通过随机参照值避免左边和右边元素数量差距过大,但实际并未生效;初始的一个无序数组,无论取哪个位置做参照值,最终还是可能导致左边和右边元素数量差距过大;而且只有做完一轮排序,才能知道左右元素数量差距是否过大,这时rnd已经不起作用了,个人认为rnd参数是个败笔

算法lua实现如下(附超详细注释,可调试运行):

辅助函数
local swap = function(arr, i, j)
    if i == j then return end
    local tmp = arr[i]
    arr[i] = arr[j]
    arr[j] = tmp
end

local comp_fun -- 用户自定义比较函数

local sort_comp = function(arr, a, b)
    if not comp_fun then
        return arr[a] < arr[b]
    else
        return comp_fun(arr[a], arr[b])
    end
end

--随机一个rnd,lua源码使用一个clock_t和time_t转换成的随机数,这里直接用随机数模拟
local function l_randomizePivot()
    return math.rand(1, 10000000)
end

--在数组中间区域随机一个值做pivot
local function choosePivot(low, upper, rnd)
    assert(rnd >= 0 and rnd == math.floor(rnd))
    local r4 = (upper - low) / 4
    local p = rnd % (r4 * 2) + (low + r4)
    assert(low + r4 <= p and p <= upper - r4)
    return p
end
算法核心代码
--以pivot区分arr的元素
local function partition(arr, low, upper)
    local i = low
    local j = upper - 1 -- j - i >= 2
    local p = j
    while true do
        i = i + 1
        while sort_comp(arr, i, p) do --在左边找到一个大于arr[p]的元素
            if i == p then -- arr[i] < arr[p]时 assert(i < p)
                error("invalid order function for sorting")
            end
            i = i + 1
        end
        j = j - 1
        while sort_comp(arr, p, j) do --在右边找到一个小于arr[p]的元素
            if j < i then -- arr[j] > arr[p]时 assert(j > i), arr[0~i] <= arr[p]
                error("invalid order function for sorting")
            end
            j = j - 1
        end
        if j < i then --i = j + 1,i到了比arr[p]大的区域,j到了比arr[p]小的区域,
            swap(arr, p, i) --把参照值换回中间,p左边的元素小于arr[p],p右边的元素大于arr[p] (arr[i]大于arr[p])
            return i
        end
        swap(arr, i, j)
    end
end

local auxsort --新版本的table.sort 与老版无太大区别,多了个rnd来随机pivot,但并未开放该功能
auxsort = function(arr, low, upper, rnd) --rnd暂未使用,随机pivot
    assert(low == math.floor(low))
    assert(upper == math.floor(upper))
    assert(rnd >= 0 and rnd == math.floor(rnd)) --无符号整型
    while low < upper do
        local p, n
        if sort_comp(arr, upper, low) then --满足arr[low] < arr[upper]
            swap(arr, low, upper)
        end
        if upper - low == 1 then break end --仅有2个元素
        if upper - low < 100 or rnd == 0 then
            p = math.floor((low + upper) / 2)
        else
            p = choosePivot(low, upper, rnd) --随机参照数pivot
        end
        if sort_comp(arr, p, low) then --使 arr[low] <= arr[p] <= arr[upper]
            swap(arr, p, low)
        elseif sort_comp(arr, upper, p) then
            swap(arr, p, upper)
        end
        if upper - low == 2 then break end --仅有3个元素,已有序
        swap(arr, p, upper - 1) --参照值放置在参与比较序列的尾部,保证参照值不会被交换
        p = partition(arr, low, upper)
        if p - low < upper - p then --元素少的半边继续递归,元素多的继续循环
            auxsort(arr, low, p - 1, rnd)
            n = p - low
            low = p + 1
        else
            auxsort(arr, p + 1, upper, rnd)
            n = upper - p
            upper = p - 1
        end
        if (upper - low) / 128 > n then-- 元素多的部分比元素少的部分大太多
            rnd = l_randomizePivot() -- rnd不会被任何地方使用,仅仅表达rnd需要被随机
        end
    end
end

local function sort(arr, fun)
    local n = #arr
    if n > 1 then
        if fun then
            assert(type(fun) == "function")
        end
        comp_fun = fun
        auxsort(arr, 1, n, 0)
        comp_fun = nil
    end
end
测试代码,随机n个数字构建一个数组
local function test()
    local rand = math.random
    local tinsert = table.insert
    local arr = {}
    local len = 10
    print("排序前:")
    for i = 1, len do
        local rn = rand(1, 1000)
        tinsert(arr, rn)
        print(rn)
    end
    sort(arr)
    print("排序后:")
    for i, v in ipairs(arr) do
        print(i, v)
        if i > 1 and arr[i - 1] > v then
            error(string.format("序列不是有序的%d", 1))
        end
    end
    print("----------------")
    -- if true then return end
    local tarr = {}
    print("tb排序前:")
    for i = 1, len do
        local rn = rand(1, 10)
        local t = {
            val = rn,
            tm = os.time() + i --模拟数据创建时间不一致
        }
        tinsert(tarr, t)
        print(i, t.val, t.tm)
    end
    sort(tarr, 
        function(a, b)
            if a.val ~= b.val then
                return a.val > b.val
            else
                return a.tm < b.tm
            end
        end
    )
    print("tb排序后:")
    for i, v in ipairs(tarr) do
        print(i, v.val, v.tm)
    end
end

local function main()
    math.randomseed(tostring(os.time()):reverse():sub(1, 7)) --设置时间种子
    test()
end

main()

老版table.sort实现算法

源码如下,与新版区别不大
static void auxsort (lua_State *L, int l, int u) {
    while (l < u) {  /* for tail recursion */
      int i, j;
      /* sort elements a[l], a[(l+u)/2] and a[u] */
      lua_rawgeti(L, 1, l);
      lua_rawgeti(L, 1, u);
      if (sort_comp(L, -1, -2))  /* a[u] < a[l]? */
        set2(L, l, u);  /* swap a[l] - a[u] */
      else
        lua_pop(L, 2);
      if (u-l == 1) break;  /* only 2 elements */
      i = (l+u)/2;
      lua_rawgeti(L, 1, i);
      lua_rawgeti(L, 1, l);
      if (sort_comp(L, -2, -1))  /* a[i]= P */
        while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) {
          if (i>u) luaL_error(L, "invalid order function for sorting");
          lua_pop(L, 1);  /* remove a[i] */
        }
        /* repeat --j until a[j] <= P */
        while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) {
          if (j
算法lua代码实现,附超详细注释,可调试运行
local auxsort --老lua版本的table.sort 与新版差别不大,只是无rnd
auxsort = function(arr, low, upper)
    assert(low == math.floor(low))
    assert(upper == math.floor(upper))
    while low < upper do
        local i, j
        if sort_comp(arr, upper, low) then --保证头比尾部小
            swap(arr, low, upper)
        end
        if upper - low == 1 then break end --仅有2个元素
        i = math.floor((low + upper) / 2)
        if sort_comp(arr, i, low) then --使 arr[low] <= arr[i] <= arr[upper]
            swap(arr, i, low)
        elseif sort_comp(arr, upper, i) then
            swap(arr, i, upper)
        end
        if upper - low == 2 then break end --仅有3个元素,已有序
        swap(arr, i, upper - 1) --参照值放置在参与比较序列的尾部,保证参照值不会被交换
        i = low
        j = upper - 1
        local p = j
        while true do
            i = i + 1
            while sort_comp(arr, i, p) do --在左边找到一个大于arr[p]的元素
                i = i + 1
            end
            j = j - 1
            while sort_comp(arr, p, j) do --在右边找到一个小于arr[p]的元素
                j = j - 1
            end
            if i > j then break end
            swap(arr, i, j)
        end
        swap(arr, i, p) --把参照值换回中间,p左边的元素小于arr[p],p右边的元素大于arr[p] (arr[i]大于arr[p])
        if i - low < upper - i then --元素少的半边继续递归,元素多的继续循环
            j = low
            i = i - 1
            low = i + 2
        else
            j = i + 1
            i = upper
            upper = j - 2
        end
        auxsort(arr, j, i)
    end
end

你可能感兴趣的:(Lua中 table.sort排序算法原理详解2022)