From ee082883d18a8990cec359862db4e93ea850cb8c Mon Sep 17 00:00:00 2001 From: William Boman Date: Wed, 13 Apr 2022 15:51:42 +0200 Subject: feat(async): add a.wait_all (#596) --- lua/nvim-lsp-installer/core/async/init.lua | 81 ++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 5 deletions(-) (limited to 'lua') diff --git a/lua/nvim-lsp-installer/core/async/init.lua b/lua/nvim-lsp-installer/core/async/init.lua index d448bc30..8537b909 100644 --- a/lua/nvim-lsp-installer/core/async/init.lua +++ b/lua/nvim-lsp-installer/core/async/init.lua @@ -73,11 +73,12 @@ local function new_execution_context(suspend_fn, callback, ...) local ok, promise_or_result = co.resume(thread, ...) if ok then if co.status(thread) == "suspended" then - assert( - getmetatable(promise_or_result) == Promise, - "Expected Promise to have been yielded in async coroutine." - ) - promise_or_result(step) + if getmetatable(promise_or_result) == Promise then + promise_or_result(step) + else + -- yield to parent coroutine + step(coroutine.yield(promise_or_result)) + end else callback(true, promise_or_result) thread = nil @@ -139,4 +140,74 @@ exports.scheduler = function() await(vim.schedule) end +---Creates a oneshot channel that can only send once. +local function oneshot_channel() + local has_sent = false + local sent_value + local saved_callback + + return { + send = function(...) + assert(not has_sent, "Oneshot channel can only send once.") + has_sent = true + sent_value = { ... } + if saved_callback then + saved_callback(unpack(sent_value)) + end + end, + receive = function() + return await(function(resolve) + if has_sent then + resolve(unpack(sent_value)) + else + saved_callback = resolve + end + end) + end, + } +end + +---@async +---@param suspend_fns async fun()[] +exports.wait_all = function(suspend_fns) + local channel = oneshot_channel() + + do + local results = {} + local threads = {} + local count = #suspend_fns + local completed = 0 + + local function callback(i) + return function(success, result) + if not success then + for _, cancel_thread in ipairs(threads) do + cancel_thread() + end + channel.send(false, result) + results = nil + threads = nil + else + results[i] = result + completed = completed + 1 + if completed >= count then + channel.send(true, results) + results = nil + threads = nil + end + end + end + end + for i, suspend_fn in ipairs(suspend_fns) do + threads[i] = exports.run(suspend_fn, callback(i)) + end + end + + local ok, results = channel.receive() + if not ok then + error(results, 2) + end + return unpack(results) +end + return exports -- cgit v1.2.3-70-g09d2