aboutsummaryrefslogtreecommitdiffstats
path: root/lua/nvim-lsp-installer/core/async/init.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/nvim-lsp-installer/core/async/init.lua')
-rw-r--r--lua/nvim-lsp-installer/core/async/init.lua81
1 files changed, 76 insertions, 5 deletions
diff --git a/lua/nvim-lsp-installer/core/async/init.lua b/lua/nvim-lsp-installer/core/async/init.lua
index d448bc30..8537b909 100644
--- a/lua/nvim-lsp-installer/core/async/init.lua
+++ b/lua/nvim-lsp-installer/core/async/init.lua
@@ -73,11 +73,12 @@ local function new_execution_context(suspend_fn, callback, ...)
local ok, promise_or_result = co.resume(thread, ...)
if ok then
if co.status(thread) == "suspended" then
- assert(
- getmetatable(promise_or_result) == Promise,
- "Expected Promise to have been yielded in async coroutine."
- )
- promise_or_result(step)
+ if getmetatable(promise_or_result) == Promise then
+ promise_or_result(step)
+ else
+ -- yield to parent coroutine
+ step(coroutine.yield(promise_or_result))
+ end
else
callback(true, promise_or_result)
thread = nil
@@ -139,4 +140,74 @@ exports.scheduler = function()
await(vim.schedule)
end
+---Creates a oneshot channel that can only send once.
+local function oneshot_channel()
+ local has_sent = false
+ local sent_value
+ local saved_callback
+
+ return {
+ send = function(...)
+ assert(not has_sent, "Oneshot channel can only send once.")
+ has_sent = true
+ sent_value = { ... }
+ if saved_callback then
+ saved_callback(unpack(sent_value))
+ end
+ end,
+ receive = function()
+ return await(function(resolve)
+ if has_sent then
+ resolve(unpack(sent_value))
+ else
+ saved_callback = resolve
+ end
+ end)
+ end,
+ }
+end
+
+---@async
+---@param suspend_fns async fun()[]
+exports.wait_all = function(suspend_fns)
+ local channel = oneshot_channel()
+
+ do
+ local results = {}
+ local threads = {}
+ local count = #suspend_fns
+ local completed = 0
+
+ local function callback(i)
+ return function(success, result)
+ if not success then
+ for _, cancel_thread in ipairs(threads) do
+ cancel_thread()
+ end
+ channel.send(false, result)
+ results = nil
+ threads = nil
+ else
+ results[i] = result
+ completed = completed + 1
+ if completed >= count then
+ channel.send(true, results)
+ results = nil
+ threads = nil
+ end
+ end
+ end
+ end
+ for i, suspend_fn in ipairs(suspend_fns) do
+ threads[i] = exports.run(suspend_fn, callback(i))
+ end
+ end
+
+ local ok, results = channel.receive()
+ if not ok then
+ error(results, 2)
+ end
+ return unpack(results)
+end
+
return exports