aboutsummaryrefslogtreecommitdiffstats
path: root/lua/mason-core/async/init.lua
blob: 63d4ec940366c3f18e1412c821e025a5278721c4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
local _ = require "mason-core.functional"
local co = coroutine

local exports = {}

---A deferred computation that is yielded by `await` and started by whoever
---drives the coroutine. Calling a Promise instance runs its resolver.
local Promise = {}
Promise.__index = Promise

---Creates a new, not-yet-started Promise.
---@param resolver fun(resolve: fun(...), reject: fun(...))
function Promise.new(resolver)
    return setmetatable({ resolver = resolver, has_resolved = false }, Promise)
end

---Builds a one-shot settle function. The first settle function (resolve or
---reject) to be called wins; every later call on the same Promise is ignored.
---@param success boolean
---@param cb fun(success: boolean, value: table)
function Promise:_wrap_resolver_cb(success, cb)
    return function(...)
        if self.has_resolved then
            return
        end
        self.has_resolved = true
        -- Record the argument count in `n` so that nil arguments (leading or
        -- trailing) survive the round trip; `#` is unspecified on tables with holes.
        cb(success, { n = select("#", ...), ... })
    end
end

---Starts the resolver, handing it a resolve and a reject function.
---@param callback fun(success: boolean, value: table) Called once with the settled arguments packed in `value`.
function Promise:__call(callback)
    self.resolver(self:_wrap_resolver_cb(true, callback), self:_wrap_resolver_cb(false, callback))
end

---Suspends the running coroutine by yielding a Promise for `resolver`, and
---resumes with the values passed to `resolve`. Raises the first value passed
---to `reject` (without position information) if the promise is rejected.
---@param resolver fun(resolve: fun(...), reject: fun(...))
---@return any ... The values passed to `resolve`.
local function await(resolver)
    local ok, value = co.yield(Promise.new(resolver))
    if not ok then
        error(value[1], 0)
    end
    -- Fall back to `#value` in case the coroutine was resumed with a plain table.
    return unpack(value, 1, value.n or #value)
end

---Collects varargs into a table and stores the exact argument count
---(including nils) in the `n` field.
---@param ... any
---@return table packed
local function table_pack(...)
    local packed = { ... }
    packed.n = select("#", ...)
    return packed
end

---Turns a callback-style function into one that can be called from within a
---coroutine context: the callback is appended as the last argument and the
---caller is suspended until that callback fires.
---@generic T
---@param async_fn T
---@param should_reject_err boolean? Whether the provided async_fn takes a callback with the signature `fun(err, result)`
---@return T
local function promisify(async_fn, should_reject_err)
    return function(...)
        local argc = select("#", ...)
        local argv = { ... }
        return await(function(resolve, reject)
            -- By default the callback's arguments are forwarded verbatim.
            local on_done = resolve
            if should_reject_err then
                -- Node-style callback: a truthy first argument is a failure.
                on_done = function(err, result)
                    if not err then
                        resolve(result)
                    else
                        reject(err)
                    end
                end
            end
            argv[argc + 1] = on_done
            -- Errors raised synchronously by async_fn reject the promise.
            local succeeded, failure = pcall(async_fn, unpack(argv, 1, argc + 1))
            if not succeeded then
                reject(failure)
            end
        end)
    end
end

---Runs `suspend_fn` inside a new coroutine and keeps resuming it until it
---either finishes or raises an error.
---
---`callback` is invoked with `(true, result)` when the coroutine finishes, or
---with `(false, err)` when it raises. Only the first value returned by the
---coroutine is captured.
---@param suspend_fn function The function to run inside the coroutine.
---@param callback fun(success: boolean, result: any)
---@param ... any Arguments passed to `suspend_fn` on its first resume.
---@return fun() cancel Stops any further resumption of the coroutine.
local function new_execution_context(suspend_fn, callback, ...)
    ---@type thread?
    local thread = co.create(suspend_fn)
    local cancelled = false
    local step
    step = function(...)
        -- Once cancelled or finished, late promise resolutions are ignored.
        if cancelled or not thread then
            return
        end
        local ok, promise_or_result = co.resume(thread, ...)
        if ok then
            if co.status(thread) == "suspended" then
                if getmetatable(promise_or_result) == Promise then
                    -- Start the promise; it calls `step` with (success, value) once settled.
                    promise_or_result(step)
                else
                    -- yield to parent coroutine
                    step(coroutine.yield(promise_or_result))
                end
            else
                -- The coroutine ran to completion.
                callback(true, promise_or_result)
                thread = nil
            end
        else
            -- The coroutine raised; promise_or_result holds the error value.
            callback(false, promise_or_result)
            thread = nil
        end
    end

    step(...)
    return function()
        cancelled = true
        thread = nil
    end
end

---Starts `suspend_fn` in a new execution context.
---@param suspend_fn function The function to run.
---@param callback fun(success: boolean, result: any) Invoked once the function finishes or raises.
---@param ... any Arguments forwarded to `suspend_fn`.
---@return fun() cancel Cancels the execution context.
function exports.run(suspend_fn, callback, ...)
    local cancel = new_execution_context(suspend_fn, callback, ...)
    return cancel
end

---Wraps `suspend_fn` in a plain function that, when called, starts it in a new
---execution context. A failure of `suspend_fn` is re-raised (without added
---position information) from the completion callback.
---@generic T
---@param suspend_fn T
exports.scope = function(suspend_fn)
    -- Successful results are discarded; only failures are surfaced.
    local function rethrow_on_failure(success, err)
        if success then
            return
        end
        error(err, 0)
    end

    return function(...)
        return new_execution_context(suspend_fn, rethrow_on_failure, ...)
    end
end

---Runs `suspend_fn` and blocks the caller until it has settled.
---Returns the function's (first) result, or raises its error at the caller's
---position. If the wait ends without the function settling, the execution
---context is cancelled and an error is raised.
---@param suspend_fn function
---@param ... any Arguments forwarded to `suspend_fn`.
---@return any result
function exports.run_blocking(suspend_fn, ...)
    local is_settled, succeeded, payload = false, nil, nil
    local abort = new_execution_context(suspend_fn, function(success, value)
        is_settled, succeeded, payload = true, success, value
    end, ...)

    local function has_settled()
        return is_settled
    end

    -- Skip vim.wait entirely when the function already settled synchronously.
    local finished = is_settled or vim.wait(0x7FFFFFFF, has_settled, 50)
    if not finished then
        abort()
        error("async function failed to resolve in time.", 2)
    end
    if not succeeded then
        error(payload, 2)
    end
    return payload
end

exports.wait = await
exports.promisify = promisify

---Suspends the current coroutine until `vim.defer_fn` fires after `ms`.
---@param ms integer Delay passed to `vim.defer_fn`.
function exports.sleep(ms)
    local function start_timer(resolve)
        vim.defer_fn(resolve, ms)
    end
    await(start_timer)
end

---When called from a fast event, suspends the current coroutine until
---`vim.schedule` invokes the continuation. Otherwise returns immediately.
function exports.scheduler()
    if not vim.in_fast_event() then
        return
    end
    await(vim.schedule)
end

---Runs all `suspend_fns` concurrently and suspends the caller until they settle.
---
---In `"all"` mode, returns the (first) result of every function, in the same
---order as `suspend_fns`. In `"first"` mode, returns the result of whichever
---function finishes first. In both modes the first failure cancels the other
---functions and is re-raised. Returns nothing when `suspend_fns` is empty.
---@async
---@param suspend_fns async fun()[]
---@param mode '"first"' | '"all"'
local function wait(suspend_fns, mode)
    local count = #suspend_fns
    if count == 0 then
        return
    end
    local channel = require("mason-core.async.control").OneShotChannel.new()
    -- Number of values to hand back. Passing this explicitly to `unpack`
    -- keeps nil results (holes in `results`) from truncating the return list.
    local result_count = mode == "first" and 1 or count

    do
        local results = {}
        local thread_cancellations = {}
        local completed = 0

        local function cancel()
            for _, cancel_thread in ipairs(thread_cancellations) do
                cancel_thread()
            end
        end

        for i, suspend_fn in ipairs(suspend_fns) do
            if channel:is_closed() then
                -- An earlier function already settled the outcome synchronously.
                -- Anything started now could no longer be cancelled, so stop here.
                break
            end
            thread_cancellations[i] = exports.run(suspend_fn, function(success, result)
                completed = completed + 1
                if channel:is_closed() then
                    return
                end
                if not success then
                    cancel()
                    channel:send(false, result)
                    results = nil
                    thread_cancellations = {}
                else
                    results[i] = result
                    if mode == "first" or completed >= count then
                        cancel()
                        channel:send(true, mode == "first" and { result } or results)
                        results = nil
                        thread_cancellations = {}
                    end
                end
            end)
        end
    end

    local ok, results = channel:receive()
    if not ok then
        error(results, 2)
    end
    return unpack(results, 1, result_count)
end

---Runs every function concurrently and suspends until all of them have
---finished, returning their results in the order the functions were given.
---@async
---@param suspend_fns async fun()[]
exports.wait_all = function(suspend_fns)
    local mode = "all"
    return wait(suspend_fns, mode)
end

---Runs every function concurrently and suspends until the first one has
---finished, returning that function's result.
---@async
---@param suspend_fns async fun()[]
exports.wait_first = function(suspend_fns)
    local mode = "first"
    return wait(suspend_fns, mode)
end

---Returns a function that runs `suspend_fn` through `run_blocking`, with
---`suspend_fn` pre-applied as the first argument.
---@param suspend_fn function
---@return function
exports.blocking = function(suspend_fn)
    local run_blocking = exports.run_blocking
    return _.partial(run_blocking, suspend_fn)
end

return exports