aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/async.lua769
-rw-r--r--lua/nvim-treesitter/install.lua117
2 files changed, 760 insertions, 126 deletions
diff --git a/lua/nvim-treesitter/async.lua b/lua/nvim-treesitter/async.lua
index 8a5e7c8af..e78c0f964 100644
--- a/lua/nvim-treesitter/async.lua
+++ b/lua/nvim-treesitter/async.lua
@@ -1,112 +1,725 @@
-local co = coroutine
+local pcall = copcall or pcall
+--- @param ... any
+--- @return {[integer]: any, n: integer}
+local function pack_len(...)
+ return { n = select('#', ...), ... }
+end
+
+--- like unpack() but use the length set by F.pack_len if present
+--- @param t? { [integer]: any, n?: integer }
+--- @param first? integer
+--- @return ...any
+local function unpack_len(t, first)
+ if t then
+ return unpack(t, first or 1, t.n or table.maxn(t))
+ end
+end
+
+--- @class async
local M = {}
----Executes a future with a callback when it is done
----@param func function
----@param callback function
----@param ... unknown
-local function execute(func, callback, ...)
- local thread = co.create(func)
+--- Weak table to keep track of running tasks
+--- @type table<thread,async.Task?>
+local threads = setmetatable({}, { __mode = 'k' })
+
+--- @return async.Task?
+local function running()
+ local task = threads[coroutine.running()]
+ if task and not (task:_completed() or task._closing) then
+ return task
+ end
+end
+
+--- Base class for async tasks. Async functions should return a subclass of
+--- this. This is designed specifically to be a base class of uv_handle_t
+--- @class async.Handle
+--- @field close fun(self: async.Handle, callback?: fun())
+--- @field is_closing? fun(self: async.Handle): boolean
+
+--- @alias async.CallbackFn fun(...: any): async.Handle?
+
+--- @class async.Task : async.Handle
+--- @field package _callbacks table<integer,fun(err?: any, ...: any)>
+--- @field package _callback_pos integer
+--- @field private _thread thread
+---
+--- Tasks can call other async functions (task of callback functions)
+--- when we are waiting on a child, we store the handle to it here so we can
+--- cancel it.
+--- @field private _current_child? async.Handle
+---
+--- Error result of the task is an error occurs.
+--- Must use `await` to get the result.
+--- @field private _err? any
+---
+--- Result of the task.
+--- Must use `await` to get the result.
+--- @field private _result? any[]
+local Task = {}
+Task.__index = Task
+
+--- @private
+--- @param func function
+--- @return async.Task
+function Task._new(func)
+ local thread = coroutine.create(func)
+
+ local self = setmetatable({
+ _closing = false,
+ _thread = thread,
+ _callbacks = {},
+ _callback_pos = 1,
+ }, Task)
+
+ threads[thread] = self
+
+ return self
+end
+
+--- @param callback fun(err?: any, ...: any)
+function Task:await(callback)
+ if self._closing then
+ callback('closing')
+ elseif self:_completed() then -- TODO(lewis6991): test
+ -- Already finished or closed
+ callback(self._err, unpack_len(self._result))
+ else
+ self._callbacks[self._callback_pos] = callback
+ self._callback_pos = self._callback_pos + 1
+ end
+end
+
+--- @package
+function Task:_completed()
+ return (self._err or self._result) ~= nil
+end
+
+-- Use max 32-bit signed int value to avoid overflow on 32-bit systems.
+-- Do not use `math.huge` as it is not interpreted as a positive integer on all
+-- platforms.
+local MAX_TIMEOUT = 2 ^ 31 - 1
+
+--- Synchronously wait (protected) for a task to finish (blocking)
+---
+--- If an error is returned, `Task:traceback()` can be used to get the
+--- stack trace of the error.
+---
+--- Example:
+--- ```lua
+---
+--- local ok, err_or_result = task:pwait(10)
+---
+--- if not ok then
+--- error(task:traceback(err_or_result))
+--- end
+---
+--- local _, result = assert(task:pwait(10))
+--- ```
+---
+--- Can be called if a task is closing.
+--- @param timeout? integer
+--- @return boolean status
+--- @return any ... result or error
+function Task:pwait(timeout)
+ local done = vim.wait(timeout or MAX_TIMEOUT, function()
+ -- Note we use self:_completed() instead of self:await() to avoid creating a
+ -- callback. This avoids having to cleanup/unregister any callback in the
+ -- case of a timeout.
+ return self:_completed()
+ end)
+
+ if not done then
+ return false, 'timeout'
+ elseif self._err then
+ return false, self._err
+ else
+ return true, unpack_len(self._result)
+ end
+end
+
+--- Synchronously wait for a task to finish (blocking)
+---
+--- Example:
+--- ```lua
+--- local result = task:wait(10) -- wait for 10ms or else error
+---
+--- local result = task:wait() -- wait indefinitely
+--- ```
+--- @param timeout? integer Timeout in milliseconds
+--- @return any ... result
+function Task:wait(timeout)
+ local res = pack_len(self:pwait(timeout))
+ local stat = res[1]
+
+ if not stat then
+ error(self:traceback(res[2]))
+ end
+
+ return unpack_len(res, 2)
+end
+
+--- @private
+--- @param msg? string
+--- @param _lvl? integer
+--- @return string
+function Task:_traceback(msg, _lvl)
+ _lvl = _lvl or 0
+
+ local thread = ('[%s] '):format(self._thread)
+
+ local child = self._current_child
+ if getmetatable(child) == Task then
+ --- @cast child async.Task
+ msg = child:_traceback(msg, _lvl + 1)
+ end
+
+ local tblvl = getmetatable(child) == Task and 2 or nil
+ msg = (msg or '') .. debug.traceback(self._thread, '', tblvl):gsub('\n\t', '\n\t' .. thread)
+
+ if _lvl == 0 then
+ --- @type string
+ msg = msg
+ :gsub('\nstack traceback:\n', '\nSTACK TRACEBACK:\n', 1)
+ :gsub('\nstack traceback:\n', '\n')
+ :gsub('\nSTACK TRACEBACK:\n', '\nstack traceback:\n', 1)
+ end
+
+ return msg
+end
- local function step(...)
- local ret = { co.resume(thread, ...) }
- ---@type boolean, any
- local stat, nargs_or_err = unpack(ret)
+--- Get the traceback of a task when it is not active.
+--- Will also get the traceback of nested tasks.
+---
+--- @param msg? string
+--- @return string
+function Task:traceback(msg)
+ return self:_traceback(msg)
+end
- if not stat then
- error(
- string.format(
- 'The coroutine failed with this message: %s\n%s',
- nargs_or_err,
- debug.traceback(thread)
- )
- )
+--- If a task completes with an error, raise the error
+function Task:raise_on_error()
+ self:await(function(err)
+ if err then
+ error(self:_traceback(err), 0)
end
+ end)
+ return self
+end
- if co.status(thread) == 'dead' then
- if callback then
- callback(unpack(ret, 3, table.maxn(ret)))
+--- @private
+--- @param err? any
+--- @param result? {[integer]: any, n: integer}
+function Task:_finish(err, result)
+ self._current_child = nil
+ self._err = err
+ self._result = result
+ threads[self._thread] = nil
+
+ local errs = {} --- @type string[]
+ for _, cb in pairs(self._callbacks) do
+ --- @type boolean, string
+ local ok, cb_err = pcall(cb, err, unpack_len(result))
+ if not ok then
+ errs[#errs + 1] = cb_err
+ end
+ end
+
+ if #errs > 0 then
+ error(table.concat(errs, '\n'), 0)
+ end
+end
+
+--- @return boolean
+function Task:is_closing()
+ return self._closing
+end
+
+--- Close the task and all its children.
+--- If callback is provided it will run asynchronously,
+--- else it will run synchronously.
+---
+--- @param callback? fun()
+function Task:close(callback)
+ if self:_completed() then
+ if callback then
+ callback()
+ end
+ return
+ end
+
+ if self._closing then
+ return
+ end
+
+ self._closing = true
+
+ if callback then -- async
+ if self._current_child then
+ self._current_child:close(function()
+ self:_finish('closed')
+ callback()
+ end)
+ else
+ self:_finish('closed')
+ callback()
+ end
+ else -- sync
+ if self._current_child then
+ self._current_child:close(function()
+ self:_finish('closed')
+ end)
+ else
+ self:_finish('closed')
+ end
+ vim.wait(0, function()
+ return self:_completed()
+ end)
+ end
+end
+
+--- @param obj any
+--- @return boolean
+local function is_async_handle(obj)
+ local ty = type(obj)
+ return (ty == 'table' or ty == 'userdata') and vim.is_callable(obj.close)
+end
+
+--- @param ... any
+function Task:_resume(...)
+ --- @type [boolean, string|async.CallbackFn]
+ local ret = pack_len(coroutine.resume(self._thread, ...))
+ local stat = ret[1]
+
+ if not stat then
+ -- Coroutine had error
+ self:_finish(ret[2])
+ elseif coroutine.status(self._thread) == 'dead' then
+ -- Coroutine finished
+ local result = pack_len(unpack_len(ret, 2))
+ self:_finish(nil, result)
+ else
+ local fn = ret[2]
+ --- @cast fn -string
+
+ -- TODO(lewis6991): refine error handler to be more specific
+ local ok, r
+ ok, r = pcall(fn, function(...)
+ if is_async_handle(r) then
+ --- @cast r async.Handle
+ -- We must close children before we resume to ensure
+ -- all resources are collected.
+ local args = pack_len(...)
+ r:close(function()
+ self:_resume(unpack_len(args))
+ end)
+ else
+ self:_resume(...)
end
- return
+ end)
+
+ if not ok then
+ self:_finish(r)
+ elseif is_async_handle(r) then
+ self._current_child = r
end
+ end
+end
+
+--- @return 'running'|'suspended'|'normal'|'dead'?
+function Task:status()
+ return coroutine.status(self._thread)
+end
+
+--- Run a function in an async context, asynchronously.
+---
+--- Examples:
+--- ```lua
+--- -- The two below blocks are equivalent:
+---
+--- -- Run a uv function and wait for it
+--- local stat = async.arun(function()
+--- return async.await(2, vim.uv.fs_stat, 'foo.txt')
+--- end):wait()
+---
+--- -- Since uv functions have sync versions. You can just do:
+--- local stat = vim.fs_stat('foo.txt')
+--- ```
+--- @param func function
+--- @param ... any
+--- @return async.Task
+function M.arun(func, ...)
+ local task = Task._new(func)
+ task:_resume(...)
+ return task
+end
+
+--- @class async.TaskFun
+--- @field package _fun fun(...: any): any
+--- @operator call(...): any
+local TaskFun = {}
+TaskFun.__index = TaskFun
+
+function TaskFun:__call(...)
+ return M.arun(self._fun, ...)
+end
+
+--- Create an async function
+--- @param fun function
+--- @return async.TaskFun
+function M.async(fun)
+ return setmetatable({ _fun = fun }, TaskFun)
+end
+
+--- Returns the status of a task’s thread.
+---
+--- @param task? async.Task
+--- @return 'running'|'suspended'|'normal'|'dead'?
+function M.status(task)
+ task = task or running()
+ if task then
+ assert(getmetatable(task) == Task, 'Expected Task')
+ return task:status()
+ end
+end
+
+--- @async
+--- @generic R1, R2, R3, R4
+--- @param fun fun(callback: fun(r1: R1, r2: R2, r3: R3, r4: R4)): any?
+--- @return R1, R2, R3, R4
+local function yield(fun)
+ assert(type(fun) == 'function', 'Expected function')
+ return coroutine.yield(fun)
+end
+
+--- @async
+--- @param task async.Task
+--- @return any ...
+local function await_task(task)
+ --- @param callback fun(err?: string, ...: any)
+ --- @return function
+ local res = pack_len(yield(function(callback)
+ task:await(callback)
+ return task
+ end))
+
+ local err = res[1]
+
+ if err then
+ -- TODO(lewis6991): what is the correct level to pass?
+ error(err, 0)
+ end
+
+ return unpack_len(res, 2)
+end
+
+--- Asynchronous blocking wait
+--- @param argc integer
+--- @param fun async.CallbackFn
+--- @param ... any func arguments
+--- @return any ...
+local function await_cbfun(argc, fun, ...)
+ local args = pack_len(...)
+
+ --- @param callback fun(...:any)
+ --- @return any?
+ return yield(function(callback)
+ args[argc] = callback
+ args.n = math.max(args.n, argc)
+ return fun(unpack_len(args))
+ end)
+end
- ---@type function, any[]
- local fn, args = ret[3], { unpack(ret, 4, table.maxn(ret)) }
- args[nargs_or_err] = step
- fn(unpack(args, 1, nargs_or_err))
+--- @param taskfun async.TaskFun
+--- @param ... any
+--- @return any ...
+local function await_taskfun(taskfun, ...)
+ return taskfun._fun(...)
+end
+
+--- Asynchronous blocking wait
+---
+--- Example:
+--- ```lua
+--- local task = async.arun(function()
+--- return 1, 'a'
+--- end)
+---
+--- local task_fun = async.async(function(arg)
+--- return 2, 'b', arg
+--- end)
+---
+--- async.arun(function()
+--- do -- await a callback function
+--- async.await(1, vim.schedule)
+--- end
+---
+--- do -- await a task (new async context)
+--- local n, s = async.await(task)
+--- assert(n == 1 and s == 'a')
+--- end
+---
+--- do -- await a started task function (new async context)
+--- local n, s, arg = async.await(task_fun('A'))
+--- assert(n == 2)
+--- assert(s == 'b')
+--- assert(args == 'A')
+--- end
+---
+--- do -- await a task function (re-using the current async context)
+--- local n, s, arg = async.await(task_fun, 'B')
+--- assert(n == 2)
+--- assert(s == 'b')
+--- assert(args == 'B')
+--- end
+--- end)
+--- ```
+--- @async
+--- @overload fun(argc: integer, func: async.CallbackFn, ...:any): any ...
+--- @overload fun(task: async.Task): any ...
+--- @overload fun(taskfun: async.TaskFun): any ...
+function M.await(...)
+ assert(running(), 'Not in async context')
+
+ local arg1 = select(1, ...)
+
+ if type(arg1) == 'number' then
+ return await_cbfun(...)
+ elseif getmetatable(arg1) == Task then
+ return await_task(...)
+ elseif getmetatable(arg1) == TaskFun then
+ return await_taskfun(...)
end
- step(...)
+ error('Invalid arguments, expected Task or (argc, func) got: ' .. type(arg1), 2)
end
--- Creates an async function with a callback style function.
----@generic F: function
----@param func F
----@param argc integer
----@return F
-function M.wrap(func, argc)
- vim.validate('func', func, 'function')
- vim.validate('argc', argc, 'number')
- ---@param ... unknown
- ---@return unknown
+---
+--- Example:
+---
+--- ```lua
+--- --- Note the callback argument is not present in the return function
+--- --- @type fun(timeout: integer)
+--- local sleep = async.awrap(2, function(timeout, callback)
+--- local timer = vim.uv.new_timer()
+--- timer:start(timeout * 1000, 0, callback)
+--- -- uv_timer_t provides a close method so timer will be
+--- -- cleaned up when this function finishes
+--- return timer
+--- end)
+---
+--- async.arun(function()
+--- print('hello')
+--- sleep(2)
+--- print('world')
+--- end)
+--- ```
+---
+--- local atimer = async.awrap(
+--- @param argc integer
+--- @param func async.CallbackFn
+--- @return async function
+function M.awrap(argc, func)
+ assert(type(argc) == 'number')
+ assert(type(func) == 'function')
+ --- @async
return function(...)
- return co.yield(argc, func, ...)
+ return M.await(argc, func, ...)
end
end
----Use this to create a function which executes in an async context but
----called from a non-async context. Inherently this cannot return anything
----since it is non-blocking
----@generic F: function
----@param func async F
----@param nargs? integer
----@return F
-function M.sync(func, nargs)
- nargs = nargs or 0
- return function(...)
- local callback = select(nargs + 1, ...)
- execute(func, callback, unpack({ ... }, 1, nargs))
+if vim.schedule then
+ --- An async function that when called will yield to the Neovim scheduler to be
+ --- able to call the API.
+ M.schedule = M.awrap(1, vim.schedule)
+end
+
+--- Create a function that runs a function when it is garbage collected.
+--- @generic F
+--- @param f F
+--- @param gc fun()
+--- @return F
+local function gc_fun(f, gc)
+ local proxy = newproxy(true)
+ local proxy_mt = getmetatable(proxy)
+ proxy_mt.__gc = gc
+ proxy_mt.__call = function(_, ...)
+ return f(...)
end
+
+ return proxy
end
----@param n integer max number of concurrent jobs
----@param interrupt_check? function
----@param thunks function[]
----@return any
-function M.join(n, interrupt_check, thunks)
- return co.yield(1, function(finish)
- if #thunks == 0 then
- return finish()
+--- @param task_cbs table<async.Task,function>
+local function gc_cbs(task_cbs)
+ for task, tcb in pairs(task_cbs) do
+ for j, cb in pairs(task._callbacks) do
+ if cb == tcb then
+ task._callbacks[j] = nil
+ break
+ end
end
+ end
+end
+
+--- @async
+--- Example:
+--- ```lua
+--- local task1 = async.arun(function()
+--- return 1, 'a'
+--- end)
+---
+--- local task2 = async.arun(function()
+--- return 1, 'a'
+--- end)
+---
+--- local task3 = async.arun(function()
+--- error('task3 error')
+--- end)
+---
+--- async.arun(function()
+--- for i, err, r1, r2 in async.iter({task1, task2, task3})
+--- print(i, err, r1, r2)
+--- end
+--- end)
+--- ```
+---
+--- Prints:
+--- ```
+--- 1 nil 1 'a'
+--- 2 nil 2 'b'
+--- 3 'task3 error' nil nil
+--- ```
+---
+--- @param tasks async.Task[]
+--- @return fun(): (integer?, any?, ...)
+function M.iter(tasks)
+ assert(running(), 'Not in async context')
- local remaining = { select(n + 1, unpack(thunks)) }
- local to_go = #thunks
+ local results = {} --- @type [integer, any, ...][]
- local ret = {} ---@type any[]
+ -- Iter blocks in an async context so only one waiter is needed
+ local waiter = nil
+ local task_cbs = {} --- @type table<async.Task,function>
+ local remaining = #tasks
- local function cb(...)
- ret[#ret + 1] = { ... }
- to_go = to_go - 1
- if to_go == 0 then
- finish(ret)
- elseif not interrupt_check or not interrupt_check() then
- if #remaining > 0 then
- local next_task = table.remove(remaining)
- next_task(cb)
- end
+ --- If can_gc_cbs is true, then the iterator function has been garbage
+ --- collected and means any awaiters can also be garbage collected. The
+ --- only time we can't do this is if with the special case when iter() is
+ --- called anonymously (`local i = async.iter(tasks)()`), so we should not
+ --- garbage collect the callbacks until at least one awaiter is called.
+ local can_gc_cbs = false
+
+ for i, task in ipairs(tasks) do
+ local function cb(err, ...)
+ if can_gc_cbs == true then
+ gc_cbs(task_cbs)
+ end
+
+ local callback = waiter
+
+ -- Clear waiter before calling it
+ waiter = nil
+
+ remaining = remaining - 1
+ if callback then
+ -- Iterator is waiting, yield to it
+ callback(i, err, ...)
+ else
+ -- Task finished before Iterator was called. Store results.
+ table.insert(results, pack_len(i, err, ...))
end
end
- for i = 1, math.min(n, #thunks) do
- thunks[i](cb)
+ task_cbs[task] = cb
+ task:await(cb)
+ end
+
+ return gc_fun(
+ M.awrap(1, function(callback)
+ if next(results) then
+ local res = table.remove(results, 1)
+ callback(unpack_len(res))
+ elseif remaining == 0 then
+ callback() -- finish
+ else
+ assert(not waiter, 'internal error: waiter already set')
+ waiter = callback
+ end
+ end),
+ function()
+ -- Don't gc callbacks just yet. Wait until at least one of them is called.
+ can_gc_cbs = true
end
- end, 1)
+ )
end
----An async function that when called will yield to the Neovim scheduler to be
----able to call the API.
----@type fun()
-M.main = M.wrap(vim.schedule, 1)
+do -- join()
+ --- @param results table<integer,table>
+ --- @param i integer
+ --- @param ... any
+ --- @return boolean
+ local function collect(results, i, ...)
+ if i then
+ results[i] = pack_len(...)
+ end
+ return i ~= nil
+ end
+
+ --- @param iter fun(): ...
+ --- @return table<integer,table>
+ local function drain_iter(iter)
+ local results = {} --- @type table<integer,table>
+ while collect(results, iter()) do
+ end
+ return results
+ end
+
+ --- @async
+ --- Wait for all tasks to finish and return their results.
+ ---
+ --- Example:
+ --- ```lua
+ --- local task1 = async.arun(function()
+ --- return 1, 'a'
+ --- end)
+ ---
+ --- local task2 = async.arun(function()
+ --- return 1, 'a'
+ --- end)
+ ---
+ --- local task3 = async.arun(function()
+ --- error('task3 error')
+ --- end)
+ ---
+ --- async.arun(function()
+ --- local results = async.join({task1, task2, task3})
+ --- print(vim.inspect(results))
+ --- end)
+ --- ```
+ ---
+ --- Prints:
+ --- ```
+ --- {
+ --- [1] = { nil, 1, 'a' },
+ --- [2] = { nil, 2, 'b' },
+ --- [3] = { 'task2 error' },
+ --- }
+ --- ```
+ --- @param tasks async.Task[]
+ --- @return table<integer,[any?,...?]>
+ function M.join(tasks)
+ assert(running(), 'Not in async context')
+ return drain_iter(M.iter(tasks))
+ end
+
+ --- @async
+ --- @param tasks async.Task[]
+ --- @return integer?, any?, ...?
+ function M.joinany(tasks)
+ return M.iter(tasks)()
+ end
+end
return M
diff --git a/lua/nvim-treesitter/install.lua b/lua/nvim-treesitter/install.lua
index 68f4c5ea4..6d60b7921 100644
--- a/lua/nvim-treesitter/install.lua
+++ b/lua/nvim-treesitter/install.lua
@@ -9,23 +9,53 @@ local parsers = require('nvim-treesitter.parsers')
local util = require('nvim-treesitter.util')
---@type fun(path: string, new_path: string, flags?: table): string?
-local uv_copyfile = a.wrap(uv.fs_copyfile, 4)
+local uv_copyfile = a.awrap(4, uv.fs_copyfile)
---@type fun(path: string, mode: integer): string?
-local uv_mkdir = a.wrap(uv.fs_mkdir, 3)
+local uv_mkdir = a.awrap(3, uv.fs_mkdir)
---@type fun(path: string, new_path: string): string?
-local uv_rename = a.wrap(uv.fs_rename, 3)
+local uv_rename = a.awrap(3, uv.fs_rename)
---@type fun(path: string, new_path: string, flags?: table): string?
-local uv_symlink = a.wrap(uv.fs_symlink, 4)
+local uv_symlink = a.awrap(4, uv.fs_symlink)
---@type fun(path: string): string?
-local uv_unlink = a.wrap(uv.fs_unlink, 2)
+local uv_unlink = a.awrap(2, uv.fs_unlink)
local MAX_JOBS = 100
local INSTALL_TIMEOUT = 60000
+--- @async
+--- @param max_jobs integer
+--- @param task_funs async.TaskFun[]
+local function join(max_jobs, task_funs)
+ if #task_funs == 0 then
+ return
+ end
+
+ max_jobs = math.min(max_jobs, #task_funs)
+
+ local remaining = { select(max_jobs + 1, unpack(task_funs)) }
+ local to_go = #task_funs
+
+ a.await(1, function(finish)
+ local function cb()
+ to_go = to_go - 1
+ if to_go == 0 then
+ finish()
+ elseif #remaining > 0 then
+ local next_task = table.remove(remaining)
+ next_task():await(cb)
+ end
+ end
+
+ for i = 1, max_jobs do
+ task_funs[i]():await(cb)
+ end
+ end)
+end
+
---@async
---@param cmd string[]
---@param opts? vim.SystemOpts
@@ -33,8 +63,8 @@ local INSTALL_TIMEOUT = 60000
local function system(cmd, opts)
local cwd = opts and opts.cwd or uv.cwd()
log.trace('running job: (cwd=%s) %s', cwd, table.concat(cmd, ' '))
- local r = a.wrap(vim.system, 3)(cmd, opts) --[[@as vim.SystemCompleted]]
- a.main()
+ local r = a.await(3, vim.system, cmd, opts) --[[@as vim.SystemCompleted]]
+ a.schedule()
if r.stdout and r.stdout ~= '' then
log.trace('stdout -> %s', r.stdout)
end
@@ -190,7 +220,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
do -- Create tmp dir
logger:debug('Creating temporary directory: %s', tmp)
local err = mkpath(tmp)
- a.main()
+ a.schedule()
if err then
return logger:error('Could not create %s-tmp: %s', project_name, err)
end
@@ -211,7 +241,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
do -- Remove tarball
logger:debug('Removing %s...', tarball_path)
local err = uv_unlink(tarball_path)
- a.main()
+ a.schedule()
if err then
return logger:error('Could not remove tarball: %s', err)
end
@@ -223,7 +253,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
local extracted = fs.joinpath(tmp, repo_project_name .. '-' .. dir_rev)
logger:debug('Moving %s to %s/...', extracted, output_dir)
local err = uv_rename(extracted, output_dir)
- a.main()
+ a.schedule()
if err then
return logger:error('Could not rename temp: %s', err)
end
@@ -265,7 +295,7 @@ local function do_install(logger, compile_location, target_location)
end
local err = uv_copyfile(compile_location, target_location)
- a.main()
+ a.schedule()
if err then
return logger:error('Error during parser installation: %s', err)
end
@@ -343,7 +373,7 @@ local function try_install_lang(lang, cache_dir, install_dir, generate)
local queries_src = M.get_package_path('runtime', 'queries', lang)
uv_unlink(queries)
local err = uv_symlink(queries_src, queries, { dir = true, junction = true })
- a.main()
+ a.schedule()
if err then
return logger:error(err)
end
@@ -403,20 +433,20 @@ end
---@field max_jobs? integer
--- Install a parser
+---@async
---@param languages string[]
---@param options? InstallOptions
----@param callback? fun(boolean)
-local function install(languages, options, callback)
+local function install(languages, options)
options = options or {}
local cache_dir = fs.normalize(fn.stdpath('cache'))
local install_dir = config.get_install_dir('parser')
- local tasks = {} ---@type fun()[]
+ local task_funs = {} ---@type async.TaskFun[]
local done = 0
for _, lang in ipairs(languages) do
- tasks[#tasks + 1] = a.sync(function()
- a.main()
+ task_funs[#task_funs + 1] = a.async(function()
+ a.schedule()
local status = install_lang(lang, cache_dir, install_dir, options.force, options.generate)
if status ~= 'failed' then
done = done + 1
@@ -424,29 +454,24 @@ local function install(languages, options, callback)
end)
end
- a.join(options and options.max_jobs or MAX_JOBS, nil, tasks)
- if #tasks > 1 then
- a.main()
- log.info('Installed %d/%d languages', done, #tasks)
- end
- if callback then
- callback(done == #tasks)
+ join(options and options.max_jobs or MAX_JOBS, task_funs)
+ if #task_funs > 1 then
+ a.schedule()
+ log.info('Installed %d/%d languages', done, #task_funs)
end
+ return done == #task_funs
end
---@param languages string[]|string
---@param options? InstallOptions
----@param callback? fun(boolean)
-M.install = a.sync(function(languages, options, callback)
+M.install = a.async(function(languages, options)
reload_parsers()
languages = config.norm_languages(languages, { unsupported = true })
- install(languages, options, callback)
-end, 3)
+ return install(languages, options)
+end)
---@param languages? string[]|string
----@param _options? table
----@param callback? function
-M.update = a.sync(function(languages, _options, callback)
+M.update = a.async(function(languages)
reload_parsers()
if not languages or #languages == 0 then
languages = 'all'
@@ -455,14 +480,12 @@ M.update = a.sync(function(languages, _options, callback)
languages = vim.tbl_filter(needs_update, languages) ---@type string[]
if #languages > 0 then
- install(languages, { force = true }, callback)
+ return install(languages, { force = true })
else
log.info('All parsers are up-to-date')
- if callback then
- callback(true)
- end
+ return true
end
-end, 3)
+end)
---@async
---@param logger Logger
@@ -477,7 +500,7 @@ local function uninstall_lang(logger, lang, parser, queries)
if fn.filereadable(parser) == 1 then
logger:debug('Unlinking ' .. parser)
local perr = uv_unlink(parser)
- a.main()
+ a.schedule()
if perr then
return logger:error(perr)
@@ -487,7 +510,7 @@ local function uninstall_lang(logger, lang, parser, queries)
if fn.isdirectory(queries) == 1 then
logger:debug('Unlinking ' .. queries)
local qerr = uv_unlink(queries)
- a.main()
+ a.schedule()
if qerr then
return logger:error(qerr)
@@ -498,16 +521,14 @@ local function uninstall_lang(logger, lang, parser, queries)
end
---@param languages string[]|string
----@param _options? table
----@param _callback? fun()
-M.uninstall = a.sync(function(languages, _options, _callback)
+M.uninstall = a.async(function(languages)
languages = config.norm_languages(languages or 'all', { missing = true, dependencies = true })
local parser_dir = config.get_install_dir('parser')
local query_dir = config.get_install_dir('queries')
local installed = config.installed_parsers()
- local tasks = {} ---@type fun()[]
+ local task_funs = {} ---@type async.TaskFun[]
local done = 0
for _, lang in ipairs(languages) do
local logger = log.new('uninstall/' .. lang)
@@ -516,7 +537,7 @@ M.uninstall = a.sync(function(languages, _options, _callback)
else
local parser = fs.joinpath(parser_dir, lang) .. '.so'
local queries = fs.joinpath(query_dir, lang)
- tasks[#tasks + 1] = a.sync(function()
+ task_funs[#task_funs + 1] = a.async(function()
local err = uninstall_lang(logger, lang, parser, queries)
if not err then
done = done + 1
@@ -525,11 +546,11 @@ M.uninstall = a.sync(function(languages, _options, _callback)
end
end
- a.join(MAX_JOBS, nil, tasks)
- if #tasks > 1 then
- a.main()
- log.info('Uninstalled %d/%d languages', done, #tasks)
+ join(MAX_JOBS, task_funs)
+ if #task_funs > 1 then
+ a.schedule()
+ log.info('Uninstalled %d/%d languages', done, #task_funs)
end
-end, 2)
+end)
return M