aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorWilliam Boman <william@redwill.se>2023-08-17 19:34:23 +0200
committerGitHub <noreply@github.com>2023-08-17 19:34:23 +0200
commitb5bb138312dbd3f7729197ca659cbe5221d36a03 (patch)
treea0d5a79d6e743693e4f064e5a091ed4f0e47ac89
parentfix(ui): properly reset new package version state (#1454) (diff)
downloadmason-b5bb138312dbd3f7729197ca659cbe5221d36a03.tar
mason-b5bb138312dbd3f7729197ca659cbe5221d36a03.tar.gz
mason-b5bb138312dbd3f7729197ca659cbe5221d36a03.tar.bz2
mason-b5bb138312dbd3f7729197ca659cbe5221d36a03.tar.lz
mason-b5bb138312dbd3f7729197ca659cbe5221d36a03.tar.xz
mason-b5bb138312dbd3f7729197ca659cbe5221d36a03.tar.zst
mason-b5bb138312dbd3f7729197ca659cbe5221d36a03.zip
chore(async): add Channel (#1456)
-rw-r--r--lua/mason-core/async/control.lua62
-rw-r--r--tests/mason-core/async/async_spec.lua55
2 files changed, 115 insertions, 2 deletions
diff --git a/lua/mason-core/async/control.lua b/lua/mason-core/async/control.lua
index c9479540..57aa88db 100644
--- a/lua/mason-core/async/control.lua
+++ b/lua/mason-core/async/control.lua
@@ -15,9 +15,14 @@ function Condvar:wait()
end)
end
+function Condvar:notify()
+ local handle = table.remove(self.handles)
+ pcall(handle)
+end
+
function Condvar:notify_all()
- for _, handle in ipairs(self.handles) do
- pcall(handle)
+ while #self.handles > 0 do
+ self:notify()
end
self.handles = {}
end
@@ -97,8 +102,61 @@ function OneShotChannel:receive()
return unpack(self.value)
end
+---@class Channel
+---@field private condvar Condvar
+---@field private buffer any?
+---@field is_closed boolean
+local Channel = {}
+Channel.__index = Channel
+function Channel.new()
+ return setmetatable({
+ condvar = Condvar.new(),
+ buffer = nil,
+ is_closed = false,
+ }, Channel)
+end
+
+function Channel:close()
+ self.is_closed = true
+end
+
+---@async
+function Channel:send(value)
+ assert(not self.is_closed, "Channel is closed.")
+ while self.buffer ~= nil do
+ self.condvar:wait()
+ end
+ self.buffer = value
+ self.condvar:notify()
+ while self.buffer ~= nil do
+ self.condvar:wait()
+ end
+end
+
+---@async
+function Channel:receive()
+ assert(not self.is_closed, "Channel is closed.")
+ while self.buffer == nil do
+ self.condvar:wait()
+ end
+ local value = self.buffer
+ self.buffer = nil
+ self.condvar:notify()
+ return value
+end
+
+---@async
+function Channel:iter()
+ return function()
+ while not self.is_closed do
+ return self:receive()
+ end
+ end
+end
+
return {
Condvar = Condvar,
Semaphore = Semaphore,
OneShotChannel = OneShotChannel,
+ Channel = Channel,
}
diff --git a/tests/mason-core/async/async_spec.lua b/tests/mason-core/async/async_spec.lua
index 29ffd946..61eeeb1b 100644
--- a/tests/mason-core/async/async_spec.lua
+++ b/tests/mason-core/async/async_spec.lua
@@ -311,3 +311,58 @@ describe("async :: OneShotChannel", function()
assert.equals(42, channel:receive())
end)
end)
+
+describe("async :: Channel", function()
+ local Channel = control.Channel
+
+ it("should suspend send until buffer is received", function()
+ local channel = Channel.new()
+ spy.on(channel, "send")
+ local guard = spy.new()
+
+ a.run(function()
+ channel:send "message"
+ guard()
+ channel:send "another message"
+ end, function() end)
+
+ assert.spy(channel.send).was_called(1)
+ assert.spy(channel.send).was_called_with(match.is_ref(channel), "message")
+ assert.spy(guard).was_not_called()
+ end)
+
+ it("should send subsequent messages after they're received", function()
+ local channel = Channel.new()
+ spy.on(channel, "send")
+
+ a.run(function()
+ channel:send "message"
+ channel:send "another message"
+ end, function() end)
+
+ local value = channel:receive()
+ assert.equals(value, "message")
+
+ assert.spy(channel.send).was_called(2)
+ assert.spy(channel.send).was_called_with(match.is_ref(channel), "message")
+ assert.spy(channel.send).was_called_with(match.is_ref(channel), "another message")
+ end)
+
+ it("should suspend receive until message is sent", function()
+ local channel = Channel.new()
+
+ a.run(function()
+ a.sleep(100)
+ channel:send "hello world"
+ end, function() end)
+
+ local start = timestamp()
+ local value = a.run_blocking(function()
+ return channel:receive()
+ end)
+ local stop = timestamp()
+
+ assert.is_true((stop - start) > 80)
+ assert.equals(value, "hello world")
+ end)
+end)