aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/async.lua769
-rw-r--r--lua/nvim-treesitter/install.lua117
2 files changed, 760 insertions, 126 deletions
diff --git a/lua/nvim-treesitter/async.lua b/lua/nvim-treesitter/async.lua
index 8a5e7c8af..e78c0f964 100644
--- a/lua/nvim-treesitter/async.lua
+++ b/lua/nvim-treesitter/async.lua
@@ -1,112 +1,725 @@
-local co = coroutine
+local pcall = copcall or pcall
+--- @param ... any
+--- @return {[integer]: any, n: integer}
+local function pack_len(...)
+ return { n = select('#', ...), ... }
+end
+
+--- like unpack() but use the length set by F.pack_len if present
+--- @param t? { [integer]: any, n?: integer }
+--- @param first? integer
+--- @return ...any
+local function unpack_len(t, first)
+ if t then
+ return unpack(t, first or 1, t.n or table.maxn(t))
+ end
+end
+
+--- @class async
local M = {}
----Executes a future with a callback when it is done
----@param func function
----@param callback function
----@param ... unknown
-local function execute(func, callback, ...)
- local thread = co.create(func)
+--- Weak table to keep track of running tasks
+--- @type table<thread,async.Task?>
+local threads = setmetatable({}, { __mode = 'k' })
+
+--- @return async.Task?
+local function running()
+ local task = threads[coroutine.running()]
+ if task and not (task:_completed() or task._closing) then
+ return task
+ end
+end
+
+--- Base class for async tasks. Async functions should return a subclass of
+--- this. This is designed specifically to be a base class of uv_handle_t
+--- @class async.Handle
+--- @field close fun(self: async.Handle, callback?: fun())
+--- @field is_closing? fun(self: async.Handle): boolean
+
+--- @alias async.CallbackFn fun(...: any): async.Handle?
+
+--- @class async.Task : async.Handle
+--- @field package _callbacks table<integer,fun(err?: any, ...: any)>
+--- @field package _callback_pos integer
+--- @field private _thread thread
+---
+--- Tasks can call other async functions (task of callback functions)
+--- when we are waiting on a child, we store the handle to it here so we can
+--- cancel it.
+--- @field private _current_child? async.Handle
+---
+--- Error result of the task is an error occurs.
+--- Must use `await` to get the result.
+--- @field private _err? any
+---
+--- Result of the task.
+--- Must use `await` to get the result.
+--- @field private _result? any[]
+local Task = {}
+Task.__index = Task
+
+--- @private
+--- @param func function
+--- @return async.Task
+function Task._new(func)
+ local thread = coroutine.create(func)
+
+ local self = setmetatable({
+ _closing = false,
+ _thread = thread,
+ _callbacks = {},
+ _callback_pos = 1,
+ }, Task)
+
+ threads[thread] = self
+
+ return self
+end
+
+--- @param callback fun(err?: any, ...: any)
+function Task:await(callback)
+ if self._closing then
+ callback('closing')
+ elseif self:_completed() then -- TODO(lewis6991): test
+ -- Already finished or closed
+ callback(self._err, unpack_len(self._result))
+ else
+ self._callbacks[self._callback_pos] = callback
+ self._callback_pos = self._callback_pos + 1
+ end
+end
+
+--- @package
+function Task:_completed()
+ return (self._err or self._result) ~= nil
+end
+
+-- Use max 32-bit signed int value to avoid overflow on 32-bit systems.
+-- Do not use `math.huge` as it is not interpreted as a positive integer on all
+-- platforms.
+local MAX_TIMEOUT = 2 ^ 31 - 1
+
+--- Synchronously wait (protected) for a task to finish (blocking)
+---
+--- If an error is returned, `Task:traceback()` can be used to get the
+--- stack trace of the error.
+---
+--- Example:
+--- ```lua
+---
+--- local ok, err_or_result = task:pwait(10)
+---
+--- if not ok then
+--- error(task:traceback(err_or_result))
+--- end
+---
+--- local _, result = assert(task:pwait(10))
+--- ```
+---
+--- Can be called if a task is closing.
+--- @param timeout? integer
+--- @return boolean status
+--- @return any ... result or error
+function Task:pwait(timeout)
+ local done = vim.wait(timeout or MAX_TIMEOUT, function()
+ -- Note we use self:_completed() instead of self:await() to avoid creating a
+ -- callback. This avoids having to cleanup/unregister any callback in the
+ -- case of a timeout.
+ return self:_completed()
+ end)
+
+ if not done then
+ return false, 'timeout'
+ elseif self._err then
+ return false, self._err
+ else
+ return true, unpack_len(self._result)
+ end
+end
+
+--- Synchronously wait for a task to finish (blocking)
+---
+--- Example:
+--- ```lua
+--- local result = task:wait(10) -- wait for 10ms or else error
+---
+--- local result = task:wait() -- wait indefinitely
+--- ```
+--- @param timeout? integer Timeout in milliseconds
+--- @return any ... result
+function Task:wait(timeout)
+ local res = pack_len(self:pwait(timeout))
+ local stat = res[1]
+
+ if not stat then
+ error(self:traceback(res[2]))
+ end
+
+ return unpack_len(res, 2)
+end
+
+--- @private
+--- @param msg? string
+--- @param _lvl? integer
+--- @return string
+function Task:_traceback(msg, _lvl)
+ _lvl = _lvl or 0
+
+ local thread = ('[%s] '):format(self._thread)
+
+ local child = self._current_child
+ if getmetatable(child) == Task then
+ --- @cast child async.Task
+ msg = child:_traceback(msg, _lvl + 1)
+ end
+
+ local tblvl = getmetatable(child) == Task and 2 or nil
+ msg = (msg or '') .. debug.traceback(self._thread, '', tblvl):gsub('\n\t', '\n\t' .. thread)
+
+ if _lvl == 0 then
+ --- @type string
+ msg = msg
+ :gsub('\nstack traceback:\n', '\nSTACK TRACEBACK:\n', 1)
+ :gsub('\nstack traceback:\n', '\n')
+ :gsub('\nSTACK TRACEBACK:\n', '\nstack traceback:\n', 1)
+ end
+
+ return msg
+end
- local function step(...)
- local ret = { co.resume(thread, ...) }
- ---@type boolean, any
- local stat, nargs_or_err = unpack(ret)
+--- Get the traceback of a task when it is not active.
+--- Will also get the traceback of nested tasks.
+---
+--- @param msg? string
+--- @return string
+function Task:traceback(msg)
+ return self:_traceback(msg)
+end
- if not stat then
- error(
- string.format(
- 'The coroutine failed with this message: %s\n%s',
- nargs_or_err,
- debug.traceback(thread)
- )
- )
+--- If a task completes with an error, raise the error
+function Task:raise_on_error()
+ self:await(function(err)
+ if err then
+ error(self:_traceback(err), 0)
end
+ end)
+ return self
+end
- if co.status(thread) == 'dead' then
- if callback then
- callback(unpack(ret, 3, table.maxn(ret)))
+--- @private
+--- @param err? any
+--- @param result? {[integer]: any, n: integer}
+function Task:_finish(err, result)
+ self._current_child = nil
+ self._err = err
+ self._result = result
+ threads[self._thread] = nil
+
+ local errs = {} --- @type string[]
+ for _, cb in pairs(self._callbacks) do
+ --- @type boolean, string
+ local ok, cb_err = pcall(cb, err, unpack_len(result))
+ if not ok then
+ errs[#errs + 1] = cb_err
+ end
+ end
+
+ if #errs > 0 then
+ error(table.concat(errs, '\n'), 0)
+ end
+end
+
+--- @return boolean
+function Task:is_closing()
+ return self._closing
+end
+
+--- Close the task and all its children.
+--- If callback is provided it will run asynchronously,
+--- else it will run synchronously.
+---
+--- @param callback? fun()
+function Task:close(callback)
+ if self:_completed() then
+ if callback then
+ callback()
+ end
+ return
+ end
+
+ if self._closing then
+ return
+ end
+
+ self._closing = true
+
+ if callback then -- async
+ if self._current_child then
+ self._current_child:close(function()
+ self:_finish('closed')
+ callback()
+ end)
+ else
+ self:_finish('closed')
+ callback()
+ end
+ else -- sync
+ if self._current_child then
+ self._current_child:close(function()
+ self:_finish('closed')
+ end)
+ else
+ self:_finish('closed')
+ end
+ vim.wait(0, function()
+ return self:_completed()
+ end)
+ end
+end
+
+--- @param obj any
+--- @return boolean
+local function is_async_handle(obj)
+ local ty = type(obj)
+ return (ty == 'table' or ty == 'userdata') and vim.is_callable(obj.close)
+end
+
+--- @param ... any
+function Task:_resume(...)
+ --- @type [boolean, string|async.CallbackFn]
+ local ret = pack_len(coroutine.resume(self._thread, ...))
+ local stat = ret[1]
+
+ if not stat then
+ -- Coroutine had error
+ self:_finish(ret[2])
+ elseif coroutine.status(self._thread) == 'dead' then
+ -- Coroutine finished
+ local result = pack_len(unpack_len(ret, 2))
+ self:_finish(nil, result)
+ else
+ local fn = ret[2]
+ --- @cast fn -string