aboutsummaryrefslogtreecommitdiffstats
path: root/lua/mason-core/lock/restore.lua
blob: f38c49c1c564502c4ea087e7aff91e021d40e984 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
local EventEmitter = require "mason-core.EventEmitter"
local FileRegistrySource = require "mason-registry.sources.file"
local GitHubRegistrySource = require "mason-registry.sources.github"
local LuaRegistrySource = require "mason-registry.sources.lua"
local Optional = require "mason-core.optional"
local Registry = require "mason-registry"
local Result = require "mason-core.result"
local _ = require "mason-core.functional"
local a = require "mason-core.async"

-- Factories that build a registry source for each lockfile registry protocol.
-- Each factory receives the lockfile's registry description and a cache bucket that is callable as
-- `cache(key, init)`. The bucket returns the source stored under `key`, creating it with `init` when
-- no source is stored there yet.
local providers = {
    ---@param registry_info LockfileRegistryGitHub
    ---@param cache fun(key: string, init: fun(): GitHubRegistrySource): GitHubRegistrySource
    ---@return GitHubRegistrySource
    github = function(registry_info, cache)
        -- `integrity` has the form "<version>~<checksum>". Only the version identifies the source;
        -- the checksum part is consumed separately when the source is installed.
        local version = _.split("~", registry_info.integrity)[1]
        -- Separators keep distinct (namespace, name, version) triples from mapping to the same key.
        -- Plain concatenation made e.g. name="ab", namespace="c" collide with name="a", namespace="bc".
        local cache_key = registry_info.namespace .. "/" .. registry_info.name .. "@" .. version
        return cache(cache_key, function()
            return GitHubRegistrySource:new {
                id = ("%s/%s"):format(registry_info.namespace, registry_info.name),
                name = registry_info.name,
                namespace = registry_info.namespace,
                version = version,
            }
        end)
    end,
    ---@param registry_info LockfileRegistryFile
    ---@param cache fun(key: string, init: fun(): FileRegistrySource): FileRegistrySource
    ---@return FileRegistrySource
    file = function(registry_info, cache)
        -- File sources are identified (and cached) by their path.
        return cache(registry_info.path, function()
            return FileRegistrySource:new {
                id = registry_info.path,
                path = registry_info.path,
            }
        end)
    end,
    ---@param registry_info LockfileRegistryLua
    ---@param cache fun(key: string, init: fun(): LuaRegistrySource): LuaRegistrySource
    ---@return LuaRegistrySource
    lua = function(registry_info, cache)
        -- Lua sources are identified (and cached) by their module name.
        return cache(registry_info.mod, function()
            return LuaRegistrySource:new {
                id = registry_info.mod,
                mod = registry_info.mod,
            }
        end)
    end,
}

---A set of packages resolved from a lockfile, plus the bookkeeping used while installing them.
---@class LockfileInstallGroup
---@field packages table<Package, LockfilePackage>
---@field unavailable_packages table<string, { error: string, metadata: LockfilePackage }>
---@field handles table<Package, InstallHandle>
---@field installed { completed: Package[], failed: Package[] }
local LockfileInstallGroup = {}
LockfileInstallGroup.__index = LockfileInstallGroup

---Creates a group with empty `handles` and empty `installed.completed` / `installed.failed` lists.
---@param packages table<Package, LockfilePackage> Resolved packages mapped to their lockfile entries.
---@param unavailable_packages table<string, { error: string, metadata: LockfilePackage }> Entries that could not be resolved, keyed by package name.
---@return LockfileInstallGroup
function LockfileInstallGroup:new(packages, unavailable_packages)
    ---@type LockfileInstallGroup
    return setmetatable({
        packages = packages,
        unavailable_packages = unavailable_packages,
        handles = {},
        installed = { completed = {}, failed = {} },
    }, self)
end

---@alias LockfileInstallHandlers { on_install?: fun(pkg: Package, metadata: LockfilePackage), on_handle?: fun(handle: InstallHandle), on_completion?: fun(pkg: Package, success: boolean, result: any) }

---Installs every package in the group at the version recorded in its lockfile entry.
---It then waits (via `a.wait_all`) until every installation has reported completion.
---One job is built per package, in order of package `name`. Each install is started with `no_lock = true`.
---Each install handle is stored in `self.handles[pkg]`. Each package is appended to
---`self.installed.completed` or `self.installed.failed` depending on the reported outcome.
---@async
---@param handlers LockfileInstallHandlers? Optional callbacks; each hook is only invoked if present.
function LockfileInstallGroup:install(handlers)
    -- Calls `handlers[hook](...)` only when the caller supplied that hook.
    local function notify(hook, ...)
        if handlers and handlers[hook] then
            handlers[hook](...)
        end
    end

    local group = self
    ---@type Package[]
    local ordered = _.sort_by(_.prop "name", _.keys(group.packages))
    local jobs = {}
    for position = 1, #ordered do
        local pkg = ordered[position]
        jobs[position] = function()
            local metadata = group.packages[pkg]
            a.wait(function(resolve)
                local handle = pkg:install({
                    no_lock = true,
                    version = metadata.version,
                }, function(success, result)
                    local outcomes = success and group.installed.completed or group.installed.failed
                    outcomes[#outcomes + 1] = pkg
                    notify("on_completion", pkg, success, result)
                    resolve()
                end)
                group.handles[pkg] = handle
                notify("on_handle", group.handles[pkg])
                notify("on_install", pkg, metadata)
            end)
        end
    end
    a.wait_all(jobs)
end

-- Metatable for `LockfileRestore.registry_cache`.
-- Reading a root key that is not present yet (e.g. `"github"`) creates an empty bucket and stores it
-- under that key. A bucket is callable as `bucket(key, init)`: it returns `bucket[key]`, first storing
-- `init()` there when the current value is nil or false.
local RegistryCache = {
    __index = function(root, root_key)
        local bucket = setmetatable({}, {
            __call = function(entries, key, init)
                local cached = entries[key]
                if not cached then
                    cached = init()
                    entries[key] = cached
                end
                return cached
            end,
        })
        root[root_key] = bucket
        return bucket
    end,
}

---Restores the packages recorded in a lockfile, creating registry sources as needed.
---@class LockfileRestore
---@field lockfile Lockfile
---@field registry_cache table
local LockfileRestore = {}
LockfileRestore.__index = LockfileRestore

---Creates a restore session for `lockfile` with an empty registry cache (metatable `RegistryCache`).
---@param lockfile Lockfile
---@return LockfileRestore
function LockfileRestore:new(lockfile)
    ---@type LockfileRestore
    local restore = setmetatable({ lockfile = lockfile }, self)
    restore.registry_cache = setmetatable({}, RegistryCache)
    return restore
end

---Returns the number of entries in the lockfile body, as computed by `_.size`.
---@return number
function LockfileRestore:get_package_count()
    local body = self.lockfile.body
    return _.size(body)
end

---Returns the lockfile body (package name -> lockfile entry) by reference; it is not copied.
---@return table<string, LockfilePackage>
function LockfileRestore:get_packages()
    local lockfile = self.lockfile
    return lockfile.body
end

---Returns the registry source described by a lockfile registry entry, reusing one cached per protocol.
---Returns nothing when `registry_info.proto` is not "github", "lua" or "file".
---@param registry_info LockfileRegistry
---@return RegistrySource?
function LockfileRestore:get_registry(registry_info)
    local proto = registry_info.proto
    local provider = providers[proto]
    if provider ~= nil then
        -- Only the bucket for a supported protocol is materialized in the cache.
        return provider(registry_info, self.registry_cache[proto])
    end
end

---Resolves a lockfile entry to a `Package` from the registry recorded in that entry.
---The registry source comes from `get_registry`. If it is not installed yet, it is installed first;
---GitHub sources are installed with the checksum part of the entry's "<version>~<checksum>" integrity.
---@async
---@param pkg_name string
---@param metadata LockfilePackage
---@return Result # Success carries the `Package`; failure carries the error.
function LockfileRestore:resolve_package(pkg_name, metadata)
    return Result.try(function(try)
        local registry_info = metadata.registry
        if not registry_info then
            return Result.failure "Lockfile entry has no registry information."
        end
        local ephemeral_registry = self:get_registry(registry_info)
        if not ephemeral_registry then
            -- get_registry returns nothing for protocols it has no provider for. Fail with a clear
            -- message instead of raising "attempt to index a nil value" below.
            return Result.failure(("Unsupported registry protocol: %s"):format(tostring(registry_info.proto)))
        end
        if not ephemeral_registry:is_installed() then
            if registry_info.proto == "github" then
                -- Bind the unused version to `__` so the `_` (mason-core.functional) module is not shadowed.
                local __, checksum = unpack(_.split("~", registry_info.integrity))
                try(ephemeral_registry:install { checksum = checksum })
            else
                try(ephemeral_registry:install())
            end
        end
        return Optional.of_nilable(ephemeral_registry:get_package(pkg_name)):ok_or "Unable to find package."
    end)
end

---Resolves every lockfile entry and partitions the outcomes into a `LockfileInstallGroup`.
---Successfully resolved entries are keyed by their `Package`. Failed entries are keyed by package
---name and keep both the error and the original lockfile entry.
---@async
---@return LockfileInstallGroup
function LockfileRestore:prepare()
    local resolved, unresolved = {}, {}
    for name, entry in pairs(self.lockfile.body) do
        local function keep(pkg)
            resolved[pkg] = entry
        end
        local function reject(err)
            unresolved[name] = { error = err, metadata = entry }
        end
        self:resolve_package(name, entry):on_success(keep):on_failure(reject)
    end
    return LockfileInstallGroup:new(resolved, unresolved)
end

---Uninstalls every cached GitHub registry source that `Registry.sources` does not contain.
---Sources of other protocols are left alone.
function LockfileRestore:cleanup()
    local github_sources = self.registry_cache.github
    for __, source in pairs(github_sources) do
        local still_registered = Registry.sources:contains(source)
        if not still_registered then
            source:uninstall()
        end
    end
end

return LockfileRestore