diff options
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/nvim-treesitter/async.lua | 769 | ||||
| -rw-r--r-- | lua/nvim-treesitter/install.lua | 117 |
2 files changed, 760 insertions, 126 deletions
diff --git a/lua/nvim-treesitter/async.lua b/lua/nvim-treesitter/async.lua index 8a5e7c8af..e78c0f964 100644 --- a/lua/nvim-treesitter/async.lua +++ b/lua/nvim-treesitter/async.lua @@ -1,112 +1,725 @@ -local co = coroutine +local pcall = copcall or pcall +--- @param ... any +--- @return {[integer]: any, n: integer} +local function pack_len(...) + return { n = select('#', ...), ... } +end + +--- like unpack() but use the length set by F.pack_len if present +--- @param t? { [integer]: any, n?: integer } +--- @param first? integer +--- @return ...any +local function unpack_len(t, first) + if t then + return unpack(t, first or 1, t.n or table.maxn(t)) + end +end + +--- @class async local M = {} ----Executes a future with a callback when it is done ----@param func function ----@param callback function ----@param ... unknown -local function execute(func, callback, ...) - local thread = co.create(func) +--- Weak table to keep track of running tasks +--- @type table<thread,async.Task?> +local threads = setmetatable({}, { __mode = 'k' }) + +--- @return async.Task? +local function running() + local task = threads[coroutine.running()] + if task and not (task:_completed() or task._closing) then + return task + end +end + +--- Base class for async tasks. Async functions should return a subclass of +--- this. This is designed specifically to be a base class of uv_handle_t +--- @class async.Handle +--- @field close fun(self: async.Handle, callback?: fun()) +--- @field is_closing? fun(self: async.Handle): boolean + +--- @alias async.CallbackFn fun(...: any): async.Handle? + +--- @class async.Task : async.Handle +--- @field package _callbacks table<integer,fun(err?: any, ...: any)> +--- @field package _callback_pos integer +--- @field private _thread thread +--- +--- Tasks can call other async functions (task of callback functions) +--- when we are waiting on a child, we store the handle to it here so we can +--- cancel it. +--- @field private _current_child? async.Handle +--- +--- Error result of the task is an error occurs. +--- Must use `await` to get the result. +--- @field private _err? any +--- +--- Result of the task. +--- Must use `await` to get the result. +--- @field private _result? any[] +local Task = {} +Task.__index = Task + +--- @private +--- @param func function +--- @return async.Task +function Task._new(func) + local thread = coroutine.create(func) + + local self = setmetatable({ + _closing = false, + _thread = thread, + _callbacks = {}, + _callback_pos = 1, + }, Task) + + threads[thread] = self + + return self +end + +--- @param callback fun(err?: any, ...: any) +function Task:await(callback) + if self._closing then + callback('closing') + elseif self:_completed() then -- TODO(lewis6991): test + -- Already finished or closed + callback(self._err, unpack_len(self._result)) + else + self._callbacks[self._callback_pos] = callback + self._callback_pos = self._callback_pos + 1 + end +end + +--- @package +function Task:_completed() + return (self._err or self._result) ~= nil +end + +-- Use max 32-bit signed int value to avoid overflow on 32-bit systems. +-- Do not use `math.huge` as it is not interpreted as a positive integer on all +-- platforms. +local MAX_TIMEOUT = 2 ^ 31 - 1 + +--- Synchronously wait (protected) for a task to finish (blocking) +--- +--- If an error is returned, `Task:traceback()` can be used to get the +--- stack trace of the error. +--- +--- Example: +--- ```lua +--- +--- local ok, err_or_result = task:pwait(10) +--- +--- if not ok then +--- error(task:traceback(err_or_result)) +--- end +--- +--- local _, result = assert(task:pwait(10)) +--- ``` +--- +--- Can be called if a task is closing. +--- @param timeout? integer +--- @return boolean status +--- @return any ... result or error +function Task:pwait(timeout) + local done = vim.wait(timeout or MAX_TIMEOUT, function() + -- Note we use self:_completed() instead of self:await() to avoid creating a + -- callback. This avoids having to cleanup/unregister any callback in the + -- case of a timeout. + return self:_completed() + end) + + if not done then + return false, 'timeout' + elseif self._err then + return false, self._err + else + return true, unpack_len(self._result) + end +end + +--- Synchronously wait for a task to finish (blocking) +--- +--- Example: +--- ```lua +--- local result = task:wait(10) -- wait for 10ms or else error +--- +--- local result = task:wait() -- wait indefinitely +--- ``` +--- @param timeout? integer Timeout in milliseconds +--- @return any ... result +function Task:wait(timeout) + local res = pack_len(self:pwait(timeout)) + local stat = res[1] + + if not stat then + error(self:traceback(res[2])) + end + + return unpack_len(res, 2) +end + +--- @private +--- @param msg? string +--- @param _lvl? integer +--- @return string +function Task:_traceback(msg, _lvl) + _lvl = _lvl or 0 + + local thread = ('[%s] '):format(self._thread) + + local child = self._current_child + if getmetatable(child) == Task then + --- @cast child async.Task + msg = child:_traceback(msg, _lvl + 1) + end + + local tblvl = getmetatable(child) == Task and 2 or nil + msg = (msg or '') .. debug.traceback(self._thread, '', tblvl):gsub('\n\t', '\n\t' .. thread) + + if _lvl == 0 then + --- @type string + msg = msg + :gsub('\nstack traceback:\n', '\nSTACK TRACEBACK:\n', 1) + :gsub('\nstack traceback:\n', '\n') + :gsub('\nSTACK TRACEBACK:\n', '\nstack traceback:\n', 1) + end + + return msg +end - local function step(...) - local ret = { co.resume(thread, ...) } - ---@type boolean, any - local stat, nargs_or_err = unpack(ret) +--- Get the traceback of a task when it is not active. +--- Will also get the traceback of nested tasks. +--- +--- @param msg? string +--- @return string +function Task:traceback(msg) + return self:_traceback(msg) +end - if not stat then - error( - string.format( - 'The coroutine failed with this message: %s\n%s', - nargs_or_err, - debug.traceback(thread) - ) - ) +--- If a task completes with an error, raise the error +function Task:raise_on_error() + self:await(function(err) + if err then + error(self:_traceback(err), 0) end + end) + return self +end - if co.status(thread) == 'dead' then - if callback then - callback(unpack(ret, 3, table.maxn(ret))) +--- @private +--- @param err? any +--- @param result? {[integer]: any, n: integer} +function Task:_finish(err, result) + self._current_child = nil + self._err = err + self._result = result + threads[self._thread] = nil + + local errs = {} --- @type string[] + for _, cb in pairs(self._callbacks) do + --- @type boolean, string + local ok, cb_err = pcall(cb, err, unpack_len(result)) + if not ok then + errs[#errs + 1] = cb_err + end + end + + if #errs > 0 then + error(table.concat(errs, '\n'), 0) + end +end + +--- @return boolean +function Task:is_closing() + return self._closing +end + +--- Close the task and all its children. +--- If callback is provided it will run asynchronously, +--- else it will run synchronously. +--- +--- @param callback? fun() +function Task:close(callback) + if self:_completed() then + if callback then + callback() + end + return + end + + if self._closing then + return + end + + self._closing = true + + if callback then -- async + if self._current_child then + self._current_child:close(function() + self:_finish('closed') + callback() + end) + else + self:_finish('closed') + callback() + end + else -- sync + if self._current_child then + self._current_child:close(function() + self:_finish('closed') + end) + else + self:_finish('closed') + end + vim.wait(0, function() + return self:_completed() + end) + end +end + +--- @param obj any +--- @return boolean +local function is_async_handle(obj) + local ty = type(obj) + return (ty == 'table' or ty == 'userdata') and vim.is_callable(obj.close) +end + +--- @param ... any +function Task:_resume(...) + --- @type [boolean, string|async.CallbackFn] + local ret = pack_len(coroutine.resume(self._thread, ...)) + local stat = ret[1] + + if not stat then + -- Coroutine had error + self:_finish(ret[2]) + elseif coroutine.status(self._thread) == 'dead' then + -- Coroutine finished + local result = pack_len(unpack_len(ret, 2)) + self:_finish(nil, result) + else + local fn = ret[2] + --- @cast fn -string |
