diff options
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/nvim-lsp-installer/core/async/init.lua | 81 |
1 files changed, 76 insertions, 5 deletions
diff --git a/lua/nvim-lsp-installer/core/async/init.lua b/lua/nvim-lsp-installer/core/async/init.lua index d448bc30..8537b909 100644 --- a/lua/nvim-lsp-installer/core/async/init.lua +++ b/lua/nvim-lsp-installer/core/async/init.lua @@ -73,11 +73,12 @@ local function new_execution_context(suspend_fn, callback, ...) local ok, promise_or_result = co.resume(thread, ...) if ok then if co.status(thread) == "suspended" then - assert( - getmetatable(promise_or_result) == Promise, - "Expected Promise to have been yielded in async coroutine." - ) - promise_or_result(step) + if getmetatable(promise_or_result) == Promise then + promise_or_result(step) + else + -- yield to parent coroutine + step(coroutine.yield(promise_or_result)) + end else callback(true, promise_or_result) thread = nil @@ -139,4 +140,74 @@ exports.scheduler = function() await(vim.schedule) end +---Creates a oneshot channel that can only send once. +local function oneshot_channel() + local has_sent = false + local sent_value + local saved_callback + + return { + send = function(...) + assert(not has_sent, "Oneshot channel can only send once.") + has_sent = true + sent_value = { ... } + if saved_callback then + saved_callback(unpack(sent_value)) + end + end, + receive = function() + return await(function(resolve) + if has_sent then + resolve(unpack(sent_value)) + else + saved_callback = resolve + end + end) + end, + } +end + +---@async +---@param suspend_fns async fun()[] +exports.wait_all = function(suspend_fns) + local channel = oneshot_channel() + + do + local results = {} + local threads = {} + local count = #suspend_fns + local completed = 0 + + local function callback(i) + return function(success, result) + if not success then + for _, cancel_thread in ipairs(threads) do + cancel_thread() + end + channel.send(false, result) + results = nil + threads = nil + else + results[i] = result + completed = completed + 1 + if completed >= count then + channel.send(true, results) + results = nil + threads = nil + end + end + end + end + for i, suspend_fn in ipairs(suspend_fns) do + threads[i] = exports.run(suspend_fn, callback(i)) + end + end + + local ok, results = channel.receive() + if not ok then + error(results, 2) + end + return unpack(results) +end + return exports |
