diff options
Diffstat (limited to 'lua/mason-core/functional')
| -rw-r--r-- | lua/mason-core/functional/data.lua | 30 | ||||
| -rw-r--r-- | lua/mason-core/functional/function.lua | 89 | ||||
| -rw-r--r-- | lua/mason-core/functional/init.lua | 112 | ||||
| -rw-r--r-- | lua/mason-core/functional/list.lua | 175 | ||||
| -rw-r--r-- | lua/mason-core/functional/logic.lua | 63 | ||||
| -rw-r--r-- | lua/mason-core/functional/number.lua | 34 | ||||
| -rw-r--r-- | lua/mason-core/functional/relation.lua | 17 | ||||
| -rw-r--r-- | lua/mason-core/functional/string.lua | 74 | ||||
| -rw-r--r-- | lua/mason-core/functional/table.lua | 45 | ||||
| -rw-r--r-- | lua/mason-core/functional/type.lua | 14 |
10 files changed, 653 insertions, 0 deletions
diff --git a/lua/mason-core/functional/data.lua b/lua/mason-core/functional/data.lua new file mode 100644 index 00000000..da6f1efd --- /dev/null +++ b/lua/mason-core/functional/data.lua @@ -0,0 +1,30 @@ +local _ = {} + +_.table_pack = function(...) + return { n = select("#", ...), ... } +end + +---@generic T : string +---@param values T[] +---@return table<T, T> +_.enum = function(values) + local result = {} + for i = 1, #values do + local v = values[i] + result[v] = v + end + return result +end + +---@generic T +---@param list T[] +---@return table<T, boolean> +_.set_of = function(list) + local set = {} + for i = 1, #list do + set[list[i]] = true + end + return set +end + +return _ diff --git a/lua/mason-core/functional/function.lua b/lua/mason-core/functional/function.lua new file mode 100644 index 00000000..e85081ce --- /dev/null +++ b/lua/mason-core/functional/function.lua @@ -0,0 +1,89 @@ +local data = require "mason-core.functional.data" + +local _ = {} + +---@generic T : fun(...) +---@param fn T +---@param arity integer +---@return T +_.curryN = function(fn, arity) + return function(...) + local args = data.table_pack(...) + if args.n >= arity then + return fn(unpack(args, 1, arity)) + else + return _.curryN(_.partial(fn, unpack(args, 1, args.n)), arity - args.n) + end + end +end + +_.compose = function(...) + local functions = data.table_pack(...) + assert(functions.n > 0, "compose requires at least one function") + return function(...) + local result = data.table_pack(...) + for i = functions.n, 1, -1 do + result = data.table_pack(functions[i](unpack(result, 1, result.n))) + end + return unpack(result, 1, result.n) + end +end + +---@generic T +---@param fn fun(...): T +---@return fun(...): T +_.partial = function(fn, ...) + local bound_args = data.table_pack(...) + return function(...) + local args = data.table_pack(...) + local merged_args = {} + for i = 1, bound_args.n do + merged_args[i] = bound_args[i] + end + for i = 1, args.n do + merged_args[bound_args.n + i] = args[i] + end + return fn(unpack(merged_args, 1, bound_args.n + args.n)) + end +end + +_.identity = function(a) + return a +end + +_.always = function(a) + return function() + return a + end +end + +_.T = _.always(true) +_.F = _.always(false) + +---@generic T : fun(...) +---@param fn T +---@param cache_key_generator (fun(...): string | nil)|nil +---@return T +_.memoize = function(fn, cache_key_generator) + cache_key_generator = cache_key_generator or _.identity + local cache = {} + return function(...) + local key = cache_key_generator(...) + if not cache[key] then + cache[key] = data.table_pack(fn(...)) + end + return unpack(cache[key], 1, cache[key].n) + end +end + +---@generic T +---@param fn fun(): T +---@return fun(): T +_.lazy = function(fn) + local memoized = _.memoize(fn, _.always "lazyval") + return function() + return memoized() + end +end + +return _ diff --git a/lua/mason-core/functional/init.lua b/lua/mason-core/functional/init.lua new file mode 100644 index 00000000..a7b0a369 --- /dev/null +++ b/lua/mason-core/functional/init.lua @@ -0,0 +1,112 @@ +local _ = {} + +-- data +local data = require "mason-core.functional.data" +_.table_pack = data.table_pack +_.enum = data.enum +_.set_of = data.set_of + +-- function +local fun = require "mason-core.functional.function" +_.curryN = fun.curryN +_.compose = fun.compose +_.partial = fun.partial +_.identity = fun.identity +_.always = fun.always +_.T = fun.T +_.F = fun.F +_.memoize = fun.memoize +_.lazy = fun.lazy + +-- list +local list = require "mason-core.functional.list" +_.reverse = list.reverse +_.list_not_nil = list.list_not_nil +_.list_copy = list.list_copy +_.find_first = list.find_first +_.any = list.any +_.filter = list.filter +_.map = list.map +_.filter_map = list.filter_map +_.each = list.each +_.concat = list.concat +_.append = list.append +_.prepend = list.prepend +_.zip_table = list.zip_table +_.nth = list.nth +_.head = list.head +_.length = list.length +_.flatten = list.flatten +_.sort_by = list.sort_by +_.join = list.join + +-- relation +local relation = require "mason-core.functional.relation" +_.equals = relation.equals +_.prop_eq = relation.prop_eq +_.prop_satisfies = relation.prop_satisfies + +-- logic +local logic = require "mason-core.functional.logic" +_.all_pass = logic.all_pass +_.any_pass = logic.any_pass +_.if_else = logic.if_else +_.is_not = logic.is_not +_.complement = logic.complement +_.cond = logic.cond + +-- number +local number = require "mason-core.functional.number" +_.negate = number.negate +_.gt = number.gt +_.gte = number.gte +_.lt = number.lt +_.lte = number.lte +_.inc = number.inc +_.dec = number.dec + +-- string +local string = require "mason-core.functional.string" +_.matches = string.matches +_.format = string.format +_.split = string.split +_.gsub = string.gsub +_.trim = string.trim +_.dedent = string.dedent +_.starts_with = string.starts_with + +-- table +local tbl = require "mason-core.functional.table" +_.prop = tbl.prop +_.pick = tbl.pick +_.keys = tbl.keys +_.size = tbl.size +_.to_pairs = tbl.to_pairs +_.invert = tbl.invert + +-- type +local typ = require "mason-core.functional.type" +_.is_nil = typ.is_nil +_.is = typ.is + +-- TODO do something else with these + +_.coalesce = function(...) + local args = _.table_pack(...) + for i = 1, args.n do + local variable = args[i] + if variable ~= nil then + return variable + end + end +end + +_.when = function(condition, value) + return condition and value or nil +end + +_.lazy_when = function(condition, value) + return condition and value() or nil +end + +return _ diff --git a/lua/mason-core/functional/list.lua b/lua/mason-core/functional/list.lua new file mode 100644 index 00000000..14db386e --- /dev/null +++ b/lua/mason-core/functional/list.lua @@ -0,0 +1,175 @@ +local fun = require "mason-core.functional.function" +local data = require "mason-core.functional.data" + +local _ = {} + +---@generic T +---@param list T[] +---@return T[] +_.reverse = function(list) + local result = {} + for i = #list, 1, -1 do + result[#result + 1] = list[i] + end + return result +end + +_.list_not_nil = function(...) + local result = {} + local args = data.table_pack(...) + for i = 1, args.n do + if args[i] ~= nil then + result[#result + 1] = args[i] + end + end + return result +end + +---@generic T +---@param predicate fun(item: T): boolean +---@param list T[] +---@return T | nil +_.find_first = fun.curryN(function(predicate, list) + local result + for i = 1, #list do + local entry = list[i] + if predicate(entry) then + return entry + end + end + return result +end, 2) + +---@generic T +---@param predicate fun(item: T): boolean +---@param list T[] +---@return boolean +_.any = fun.curryN(function(predicate, list) + for i = 1, #list do + if predicate(list[i]) then + return true + end + end + return false +end, 2) + +---@generic T +---@type fun(filter_fn: (fun(item: T): boolean), items: T[]): T[] +_.filter = fun.curryN(vim.tbl_filter, 2) + +---@generic T, U +---@type fun(map_fn: (fun(item: T): U), items: T[]): U[] +_.map = fun.curryN(vim.tbl_map, 2) + +_.flatten = fun.curryN(vim.tbl_flatten, 1) + +---@generic T +---@param map_fn fun(item: T): Optional +---@param list T[] +---@return any[] +_.filter_map = fun.curryN(function(map_fn, list) + local ret = {} + for i = 1, #list do + map_fn(list[i]):if_present(function(value) + ret[#ret + 1] = value + end) + end + return ret +end, 2) + +---@generic T +---@param fn fun(item: T, index: integer) +---@param list T[] +_.each = fun.curryN(function(fn, list) + for k, v in pairs(list) do + fn(v, k) + end +end, 2) + +---@generic T +---@param list T[] +---@return T[] @A shallow copy of the list. +_.list_copy = _.map(fun.identity) + +_.concat = fun.curryN(function(a, b) + if type(a) == "table" then + assert(type(b) == "table", "concat: expected table") + return vim.list_extend(_.list_copy(a), b) + elseif type(a) == "string" then + assert(type(b) == "string", "concat: expected string") + return a .. b + end +end, 2) + +---@generic T +---@param value T +---@param list T[] +---@return T[] +_.append = fun.curryN(function(value, list) + local list_copy = _.list_copy(list) + list_copy[#list_copy + 1] = value + return list_copy +end, 2) + +---@generic T +---@param value T +---@param list T[] +---@return T[] +_.prepend = fun.curryN(function(value, list) + local list_copy = _.list_copy(list) + table.insert(list_copy, 1, value) + return list_copy +end, 2) + +---@generic T +---@generic U +---@param keys T[] +---@param values U[] +---@return table<T, U> +_.zip_table = fun.curryN(function(keys, values) + local res = {} + for i, key in ipairs(keys) do + res[key] = values[i] + end + return res +end, 2) + +---@generic T +---@param offset number +---@param value T[]|string +---@return T|string|nil +_.nth = fun.curryN(function(offset, value) + local index = offset < 0 and (#value + (offset + 1)) or offset + if type(value) == "string" then + return string.sub(value, index, index) + else + return value[index] + end +end, 2) + +_.head = _.nth(1) + +---@param value string|any[] +_.length = function(value) + return #value +end + +---@generic T +---@param comp fun(item: T): any +---@param list T[] +---@return T[] +_.sort_by = fun.curryN(function(comp, list) + local copied_list = _.list_copy(list) + table.sort(copied_list, function(a, b) + return comp(a) < comp(b) + end) + return copied_list +end, 2) + +---@param sep string +---@param list any[] +_.join = fun.curryN(function(sep, list) + return table.concat(list, sep) +end, 2) + +return _ diff --git a/lua/mason-core/functional/logic.lua b/lua/mason-core/functional/logic.lua new file mode 100644 index 00000000..0e0044d5 --- /dev/null +++ b/lua/mason-core/functional/logic.lua @@ -0,0 +1,63 @@ +local fun = require "mason-core.functional.function" + +local _ = {} + +---@generic T +---@param predicates (fun(item: T): boolean)[] +---@return fun(item: T): boolean +_.all_pass = fun.curryN(function(predicates, item) + for i = 1, #predicates do + if not predicates[i](item) then + return false + end + end + return true +end, 2) + +---@generic T +---@param predicates (fun(item: T): boolean)[] +---@return fun(item: T): boolean +_.any_pass = fun.curryN(function(predicates, item) + for i = 1, #predicates do + if predicates[i](item) then + return true + end + end + return false +end, 2) + +---@generic T +---@param predicate fun(item: T): boolean +---@param on_true fun(item: T): any +---@param on_false fun(item: T): any +---@param value T +_.if_else = fun.curryN(function(predicate, on_true, on_false, value) + if predicate(value) then + return on_true(value) + else + return on_false(value) + end +end, 4) + +---@param value boolean +_.is_not = function(value) + return not value +end + +---@generic T +---@param predicate fun(value: T): boolean +---@param value T +_.complement = fun.curryN(function(predicate, value) + return not predicate(value) +end, 2) + +_.cond = fun.curryN(function(predicate_transformer_pairs, value) + for _, pair in ipairs(predicate_transformer_pairs) do + local predicate, transformer = pair[1], pair[2] + if predicate(value) then + return transformer(value) + end + end +end, 2) + +return _ diff --git a/lua/mason-core/functional/number.lua b/lua/mason-core/functional/number.lua new file mode 100644 index 00000000..11e8f88a --- /dev/null +++ b/lua/mason-core/functional/number.lua @@ -0,0 +1,34 @@ +local fun = require "mason-core.functional.function" + +local _ = {} + +---@param number number +_.negate = function(number) + return -number +end + +_.gt = fun.curryN(function(number, value) + return value > number +end, 2) + +_.gte = fun.curryN(function(number, value) + return value >= number +end, 2) + +_.lt = fun.curryN(function(number, value) + return value < number +end, 2) + +_.lte = fun.curryN(function(number, value) + return value <= number +end, 2) + +_.inc = fun.curryN(function(increment, value) + return value + increment +end, 2) + +_.dec = fun.curryN(function(decrement, value) + return value - decrement +end, 2) + +return _ diff --git a/lua/mason-core/functional/relation.lua b/lua/mason-core/functional/relation.lua new file mode 100644 index 00000000..94913a13 --- /dev/null +++ b/lua/mason-core/functional/relation.lua @@ -0,0 +1,17 @@ +local fun = require "mason-core.functional.function" + +local _ = {} + +_.equals = fun.curryN(function(expected, value) + return value == expected +end, 2) + +_.prop_eq = fun.curryN(function(property, value, tbl) + return tbl[property] == value +end, 3) + +_.prop_satisfies = fun.curryN(function(predicate, property, tbl) + return predicate(tbl[property]) +end, 3) + +return _ diff --git a/lua/mason-core/functional/string.lua b/lua/mason-core/functional/string.lua new file mode 100644 index 00000000..7726c8e1 --- /dev/null +++ b/lua/mason-core/functional/string.lua @@ -0,0 +1,74 @@ +local fun = require "mason-core.functional.function" + +local _ = {} + +---@param pattern string +---@param str string +_.matches = fun.curryN(function(pattern, str) + return str:match(pattern) ~= nil +end, 2) + +---@param template string +---@param str string +_.format = fun.curryN(function(template, str) + return template:format(str) +end, 2) + +---@param sep string +---@param str string +_.split = fun.curryN(function(sep, str) + return vim.split(str, sep) +end, 2) + +---@param pattern string +---@param repl string|function|table +---@param str string +_.gsub = fun.curryN(function(pattern, repl, str) + return string.gsub(str, pattern, repl) +end, 3) + +_.trim = fun.curryN(function(str) + return vim.trim(str) +end, 1) + +---https://github.com/nvim-lua/nvim-package-specification/blob/93475e47545b579fd20b6c5ce13c4163e7956046/lua/packspec/schema.lua#L8-L37 +---@param str string +---@return string +_.dedent = fun.curryN(function(str) + local lines = {} + local indent = nil + + for line in str:gmatch "[^\n]*\n?" do + if indent == nil then + if not line:match "^%s*$" then + -- save pattern for indentation from the first non-empty line + indent, line = line:match "^(%s*)(.*)$" + indent = "^" .. indent .. "(.*)$" + table.insert(lines, line) + end + else + if line:match "^%s*$" then + -- replace empty lines with a single newline character. + -- empty lines are handled separately to allow the + -- closing "]]" to be one indentation level lower. + table.insert(lines, "\n") + else + -- strip indentation on non-empty lines + line = assert(line:match(indent), "inconsistent indentation") + table.insert(lines, line) + end + end + end + + lines = table.concat(lines) + -- trim trailing whitespace + return lines:match "^(.-)%s*$" +end, 1) + +---@param prefix string +---@str string +_.starts_with = fun.curryN(function(prefix, str) + return vim.startswith(str, prefix) +end, 2) + +return _ diff --git a/lua/mason-core/functional/table.lua b/lua/mason-core/functional/table.lua new file mode 100644 index 00000000..65d05cc8 --- /dev/null +++ b/lua/mason-core/functional/table.lua @@ -0,0 +1,45 @@ +local fun = require "mason-core.functional.function" + +local _ = {} + +---@param index any +---@param tbl table +_.prop = fun.curryN(function(index, tbl) + return tbl[index] +end, 2) + +---@param keys any[] +---@param tbl table +_.pick = fun.curryN(function(keys, tbl) + local ret = {} + for _, key in ipairs(keys) do + ret[key] = tbl[key] + end + return ret +end, 2) + +_.keys = fun.curryN(vim.tbl_keys, 1) +_.size = fun.curryN(vim.tbl_count, 1) + +---@param tbl table<any, any> +---@return any[][] +_.to_pairs = fun.curryN(function(tbl) + local result = {} + for k, v in pairs(tbl) do + result[#result + 1] = { k, v } + end + return result +end, 1) + +---@generic K, V +---@param tbl table<K, V> +---@return table<V, K> +_.invert = fun.curryN(function(tbl) + local result = {} + for k, v in pairs(tbl) do + result[v] = k + end + return result +end, 1) + +return _ diff --git a/lua/mason-core/functional/type.lua b/lua/mason-core/functional/type.lua new file mode 100644 index 00000000..e3bf5fe7 --- /dev/null +++ b/lua/mason-core/functional/type.lua @@ -0,0 +1,14 @@ +local fun = require "mason-core.functional.function" +local rel = require "mason-core.functional.relation" + +local _ = {} + +_.is_nil = rel.equals(nil) + +---@param typ type +---@param value any +_.is = fun.curryN(function(typ, value) + return type(value) == typ +end, 2) + +return _ |
