1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
|
local _ = require "mason-core.functional"
local co = coroutine
local exports = {}

---A deferred computation that a coroutine yields to its execution context.
---Calling the promise (via `__call`) runs the resolver with a `resolve` and a
---`reject` function; whichever of the two fires first settles the promise and
---every later call is ignored.
local Promise = {}
Promise.__index = Promise

---Creates an unsettled promise around `resolver`.
---@param resolver fun(resolve: fun(...), reject: fun(...))
function Promise.new(resolver)
    return setmetatable({ resolver = resolver, has_resolved = false }, Promise)
end

---Builds a one-shot settle function that forwards its arguments to `cb`.
---The arguments are packed together with their count in `n`, so nil values
---(in the middle or at the end) are not lost when they are unpacked later.
---@param success boolean
---@param cb fun(success: boolean, value: table)
function Promise:_wrap_resolver_cb(success, cb)
    return function(...)
        if self.has_resolved then
            -- Already settled through resolve or reject; ignore this call.
            return
        end
        self.has_resolved = true
        cb(success, { n = select("#", ...), ... })
    end
end

---Runs the resolver. `callback` receives `(true, values)` when the resolver
---resolves and `(false, values)` when it rejects.
---@param callback fun(success: boolean, value: table)
function Promise:__call(callback)
    self.resolver(self:_wrap_resolver_cb(true, callback), self:_wrap_resolver_cb(false, callback))
end

---Suspends the running coroutine by yielding a promise for `resolver`, and
---returns the resolved values once the coroutine is resumed.
---Raises the first rejected value (without position information) when the
---coroutine is resumed with a falsy first argument.
---@param resolver fun(resolve: fun(...), reject: fun(...))
---@return any ...
local function await(resolver)
    local ok, value = co.yield(Promise.new(resolver))
    if not ok then
        error(value[1], 0)
    end
    -- Use the recorded count so nil holes do not truncate the result list;
    -- fall back to the length operator for tables packed without `n`.
    return unpack(value, 1, value.n or #value)
end
---Collects varargs into a table and stores the argument count in `n`, so nil
---arguments (including trailing ones) can be restored when unpacking.
---@param ... any
---@return table
local function table_pack(...)
    local packed = { ... }
    packed.n = select("#", ...)
    return packed
end
---Turns a callback-style function into one that can be called from inside a
---coroutine: the wrapper appends a completion callback after the caller's
---arguments and suspends until that callback fires.
---@generic T
---@param async_fn T
---@param should_reject_err boolean? Whether the provided async_fn takes a callback with the signature `fun(err, result)`
---@return T
local function promisify(async_fn, should_reject_err)
    return function(...)
        local argv = table_pack(...)
        -- The completion callback always goes right after the caller's
        -- arguments, even when some of those arguments are nil.
        local callback_slot = argv.n + 1
        return await(function(resolve, reject)
            local on_done = resolve
            if should_reject_err then
                -- Node-style callback: a truthy first argument is a failure.
                on_done = function(err, result)
                    if err then
                        reject(err)
                    else
                        resolve(result)
                    end
                end
            end
            argv[callback_slot] = on_done
            -- An error thrown while starting the operation becomes a rejection.
            local ok, err = pcall(async_fn, unpack(argv, 1, callback_slot))
            if not ok then
                reject(err)
            end
        end)
    end
end
---Starts `suspend_fn` in a fresh coroutine and drives it until it finishes.
---Each yielded Promise is invoked with the stepping function, so settling it
---resumes the coroutine; any other yielded value is passed on to the parent
---coroutine. `callback` receives `(true, first_return_value)` when the
---coroutine returns and `(false, err)` when it raises.
---@param suspend_fn function
---@param callback fun(success: boolean, result: any)
---@param ... any Arguments passed to `suspend_fn` on its first resume.
---@return fun() cancel Stops any further resumption of the coroutine.
local function new_execution_context(suspend_fn, callback, ...)
    ---@type thread?
    local thread = co.create(suspend_fn)
    local cancelled = false

    local function step(...)
        if cancelled or not thread then
            return
        end
        local ok, promise_or_result = co.resume(thread, ...)
        if not ok then
            -- The coroutine raised; report the error.
            callback(false, promise_or_result)
            thread = nil
            return
        end
        if co.status(thread) ~= "suspended" then
            -- The coroutine ran to completion.
            callback(true, promise_or_result)
            thread = nil
            return
        end
        if getmetatable(promise_or_result) == Promise then
            promise_or_result(step)
        else
            -- yield to parent coroutine
            step(coroutine.yield(promise_or_result))
        end
    end

    step(...)

    return function()
        cancelled = true
        thread = nil
    end
end
---Runs `suspend_fn` in a new execution context and reports the outcome to
---`callback` as `(success, result_or_error)`.
---@param suspend_fn function
---@param callback fun(success: boolean, result: any)
---@param ... any Arguments forwarded to `suspend_fn`.
---@return fun() cancel Function that cancels the execution context.
function exports.run(suspend_fn, callback, ...)
    return new_execution_context(suspend_fn, callback, ...)
end
---Wraps `suspend_fn` in a plain function that starts it in a new execution
---context each time it is called. If the coroutine fails, the error is
---re-raised without position information; successful results are discarded.
---@generic T
---@param suspend_fn T
exports.scope = function(suspend_fn)
    local function rethrow_on_failure(success, err)
        if success then
            return
        end
        error(err, 0)
    end

    return function(...)
        return new_execution_context(suspend_fn, rethrow_on_failure, ...)
    end
end
---Runs `suspend_fn` and blocks the caller until it completes, using
---`vim.wait` (called with a 50 ms interval) when the function does not
---finish synchronously.
---Returns the first return value of `suspend_fn`, raises its error if it
---fails, and raises a timeout error if `vim.wait` gives up first.
---@param suspend_fn function
---@param ... any Arguments forwarded to `suspend_fn`.
exports.run_blocking = function(suspend_fn, ...)
    local resolved, ok, result
    local cancel_coroutine = new_execution_context(suspend_fn, function(success, value)
        resolved = true
        ok = success
        result = value
    end, ...)

    -- Skip waiting entirely when the coroutine already finished synchronously.
    local finished = resolved or vim.wait(0x7FFFFFFF, function()
        return resolved == true
    end, 50)

    if not finished then
        cancel_coroutine()
        error("async function failed to resolve in time.", 2)
    end
    if not ok then
        error(result, 2)
    end
    return result
end
-- Expose the local coroutine helpers on the module's export table.
exports.wait = await
exports.promisify = promisify
---Suspends the current coroutine until `vim.defer_fn` invokes the resume
---callback, scheduled with a delay of `ms`.
---@async
---@param ms integer Delay passed to `vim.defer_fn`.
exports.sleep = function(ms)
    local function start_timer(resolve)
        vim.defer_fn(resolve, ms)
    end
    await(start_timer)
end
---When called while `vim.in_fast_event()` is true, suspends the coroutine and
---resumes it through `vim.schedule`. Otherwise returns immediately.
---@async
exports.scheduler = function()
    if not vim.in_fast_event() then
        return
    end
    await(vim.schedule)
end
---Runs every function in `suspend_fns` in its own execution context and
---waits on a one-shot channel for the combined outcome.
---
---* mode "all": returns one value per function, in list order, once all of
---  them have completed.
---* mode "first": returns the result of whichever function completes first.
---
---If any function fails before the outcome is sent, the contexts registered
---so far are cancelled and the error is raised. Returns nothing for an empty
---list.
---
---Results are unpacked with an explicit count, so a function that returns nil
---still occupies its position in the returned list instead of truncating it.
---@async
---@param suspend_fns async fun()[]
---@param mode '"first"' | '"all"'
local function wait(suspend_fns, mode)
    local count = #suspend_fns
    if count == 0 then
        -- Nothing to run; skip creating the channel altogether.
        return
    end
    local channel = require("mason-core.async.control").OneShotChannel.new()
    -- Number of values to return. Tracked separately because `#results` is
    -- unreliable as soon as any individual result is nil.
    local result_count = mode == "first" and 1 or count
    do
        local results = {}
        local thread_cancellations = {}
        local completed = 0
        local function cancel()
            for _, cancel_thread in ipairs(thread_cancellations) do
                cancel_thread()
            end
        end
        for i, suspend_fn in ipairs(suspend_fns) do
            thread_cancellations[i] = exports.run(suspend_fn, function(success, result)
                completed = completed + 1
                if channel:is_closed() then
                    -- The channel reports closed (presumably an outcome was
                    -- already sent); ignore late completions.
                    return
                end
                if not success then
                    cancel()
                    channel:send(false, result)
                    -- Drop references that are no longer needed.
                    results = nil
                    thread_cancellations = {}
                else
                    results[i] = result
                    if mode == "first" or completed >= count then
                        cancel()
                        channel:send(true, mode == "first" and { result } or results)
                        results = nil
                        thread_cancellations = {}
                    end
                end
            end)
        end
    end
    local ok, results = channel:receive()
    if not ok then
        error(results, 2)
    end
    return unpack(results, 1, result_count)
end
---Runs all given functions and suspends until every one has completed,
---returning their results in list order. Raises if any of them fails.
---@async
---@param suspend_fns async fun()[]
exports.wait_all = function(suspend_fns)
    return wait(suspend_fns, "all")
end
---Runs all given functions and suspends until the first one completes,
---returning that function's result. Raises if a function fails first.
---@async
---@param suspend_fns async fun()[]
exports.wait_first = function(suspend_fns)
    return wait(suspend_fns, "first")
end
---Returns the result of partially applying `exports.run_blocking` to
---`suspend_fn` via `_.partial`.
---@param suspend_fn function
exports.blocking = function(suspend_fn)
    return _.partial(exports.run_blocking, suspend_fn)
end
-- Module result: the table of public async helpers assembled above.
return exports
|