用 Lua 实现函数式编程中的柯里化(Currying)

核心思路(见 `bind` 函数)是构造一个 argt 表:先把 bound 中指定位置的参数填入,再把调用时传入的参数按顺序填入剩余的空位。

代码如下,保存为 curry.lua 后 `require "curry"` 即可,之后通过全局表 `fun` 使用(`fun.bind`、`fun.curry`)。用例见注释:

--[[
  summary: a simple implementation of curry functions
  author: Hustlion([email protected])
  date: 2019-7-8 17:18:3
  based on: 
  https://github.com/lua-stdlib/lua-stdlib/blob/499cf61feb9aa6c44f42d546990266d38871270e/lib/std/functional.lua
  https://www.sitepoint.com/currying-in-functional-javascript/

]]
local table_unpack	= table.unpack or unpack

-- Return the function that would run if `x` were called: `x` itself when
-- it is a function, otherwise the `__call` field of its metatable (nil if
-- there is no metatable or no `__call`).
--
-- Only one level is inspected. The original author relied on Lua refusing
-- to call through a chain of functables, e.g.:
--   func = function () print "called" end
--   functable = setmetatable ({}, {__call = func})
--   functable()  --> "called"
--   nested = setmetatable ({}, {__call = functable})
--   nested()     --> error: attempt to call a table value
-- NOTE(review): this matches Lua 5.1-5.3; newer versions may differ.
local function callable (x)
  local kind = type (x)
  if kind ~= "function" then
    local mt = getmetatable (x) or {}
    return mt.__call
  end
  return x
end

--- Look up metamethod `n` on `x`, returning it only when it is callable.
-- Declared local: the original assignment leaked `getmetamethod` into the
-- global environment.
-- @param x any value
-- @param n metamethod name, e.g. "__len"
-- @return the metamethod when present and callable, otherwise nothing
local function getmetamethod (x, n)
  local m = (getmetatable (x) or {})[n]
  if callable (m) then return m end
end

-- Largest positive integer key in `t` (0 when there is none).
-- Stand-in for `table.maxn`, which Lua 5.2+ no longer provides; `unpack`
-- needs it when no explicit upper bound is given.
local function maxn (t)
  local largest = 0
  for k in pairs (t) do
    if type (k) == "number" and k > largest and k < math.huge
       and math.floor (k) == k then
      largest = k
    end
  end
  return largest
end

--- Unpack `t[i..j]`, honouring metamethods.
-- When `j` is nil the upper bound comes from a `__len` metamethod if one
-- exists, otherwise from the largest positive integer key, so holes below
-- that key are preserved as nils. An `__unpack` metamethod, if present,
-- replaces the default `table.unpack`.
-- @param t table to unpack
-- @param i first index (converted with tonumber; defaults to 1)
-- @param j last index (optional; converted with tonumber)
-- @return t[i], ..., t[j]
local function unpack (t, i, j)
  if j == nil then
    local len = getmetamethod (t, "__len")
    j = len and len (t) or maxn (t)
  end
  local fn = getmetamethod (t, "__unpack") or table_unpack
  return fn (t, tonumber (i) or 1, tonumber (j))
end

--- Partially apply `fn` with the arguments stored in `bound`.
-- On every call a fresh argument table `argt` is built. First the integer
-- keys of `bound` are copied into their positions. Then the call's own
-- arguments are placed, in order, into the remaining empty slots.
-- @param fn function to call
-- @param bound table mapping argument positions to values; an optional
--   field `n` forces at least that many arguments to be passed
-- @return a function forwarding to `fn` with the merged arguments
-- @usage local cube = bind (pow, {[2] = 3}); cube (2) --> pow (2, 3)
local function bind (fn, bound)
  return function (...)
    -- select/{...} instead of table.pack, which Lua 5.1/LuaJIT lack; the
    -- explicit count keeps trailing nil arguments.
    local argt, nunbound, unbound = {}, select ("#", ...), {...}

    -- Copy only the integer keys of `bound`, tracking the largest one
    -- (i.e. `maxn (bound)`) in the same pass.
    local n = bound.n or 0
    for k, v in pairs (bound) do
      if type (k) == "number" and math.ceil (k) == k then
        argt[k] = v
        if k > n then n = k end
      end
    end

    -- Fill gaps in argt with the unbound arguments, left to right.
    local i = 1
    for j = 1, nunbound do
      while argt[i] ~= nil do i = i + 1 end
      argt[i], i = unbound[j], i + 1
    end

    -- Even if gaps remain above i, pass at least n arguments.
    if n >= i then return fn (unpack (argt, 1, n)) end

    -- Otherwise gaps were filled beyond n; pass up to the last filled one.
    return fn (unpack (argt, 1, i - 1))
  end
end

--- Curry `fn` over its first `n` parameters.
-- Returns a chain of `n - 1` one-argument functions. Each call fixes the
-- next leading argument, and extra arguments to these steps are ignored.
-- The function at the end of the chain takes the remaining argument(s)
-- and calls `fn`. When `n <= 1`, `fn` itself is returned.
-- @param fn function to curry
-- @param n number of leading parameters to curry over
-- @return the curried function
-- @usage curry (pow, 2) (2) (3) --> pow (2, 3) == 8
local function curry (fn, n)
  if type (n) ~= "number" then
    error ("bad argument #2 to 'curry' (number expected, got "
           .. type (n) .. ")", 2)
  end
  if n <= 1 then return fn end
  return function (x)
    -- Close over x directly rather than using bind (fn, {x}): {x} has no
    -- slot 1 when x is nil, so bind would let the next argument slide
    -- into x's position.
    return curry (function (...) return fn (x, ...) end, n - 1)
  end
end


-- Namespace table for the functional helpers. It is a deliberate global so
-- that `require "curry"` makes the helpers available; an existing `fun`
-- table is reused rather than replaced.
fun = fun or {}

-- fun.bind (fn, bound): return a partially applied version of `fn`.
-- `bound` is a table mapping argument positions to preset values; the
-- arguments of each later call fill the remaining positions in order.
-- Example:
--[[
local pow = function (x, n)
  local res = 1
  for _ = 1, n do
    res = x * res
  end
  return res
end
local cube = fun.bind (pow, {[2] = 3})
print (cube (2)) -- 8, i.e. 2^3
]]
fun.bind = bind

-- fun.curry (fn, n): curry `fn` over its first `n` parameters. Each call
-- fixes the next leading argument, one per call. After `n - 1` such calls,
-- the returned function takes the remaining argument(s) and invokes `fn`.
-- Example:
--[[
local pow = function (x, n)
  local res = 1
  for _ = 1, n do
    res = x * res
  end
  return res
end
local curried_pow = fun.curry (pow, 2)
print (curried_pow (2) (3)) -- 8, i.e. 2^3
]]
fun.curry = curry

你可能感兴趣的:(Lua)