diff options
| -rw-r--r-- | lua/nvim-treesitter/install.lua | 90 |
1 files changed, 41 insertions, 49 deletions
diff --git a/lua/nvim-treesitter/install.lua b/lua/nvim-treesitter/install.lua index 336f83717..c179600b6 100644 --- a/lua/nvim-treesitter/install.lua +++ b/lua/nvim-treesitter/install.lua @@ -40,7 +40,7 @@ end ---@async ---@param path string -local function rmdir(path) +local function rmpath(path) local stat = uv.fs_lstat(path) if not stat then return @@ -48,7 +48,7 @@ local function rmdir(path) if stat.type == 'directory' then for file in fs.dir(path) do - rmdir(fs.joinpath(path, file)) + rmpath(fs.joinpath(path, file)) end return uv_rmdir(path) else @@ -61,16 +61,16 @@ local INSTALL_TIMEOUT = 60000 --- @async --- @param max_jobs integer ---- @param task_funs async.TaskFun[] -local function join(max_jobs, task_funs) - if #task_funs == 0 then +--- @param tasks async.TaskFun[] +local function join(max_jobs, tasks) + if #tasks == 0 then return end - max_jobs = math.min(max_jobs, #task_funs) + max_jobs = math.min(max_jobs, #tasks) - local remaining = { select(max_jobs + 1, unpack(task_funs)) } - local to_go = #task_funs + local remaining = { select(max_jobs + 1, unpack(tasks)) } + local to_go = #tasks a.await(1, function(finish) local function cb() @@ -84,7 +84,7 @@ local function join(max_jobs, task_funs) end for i = 1, max_jobs do - task_funs[i]():await(cb) + tasks[i]():await(cb) end end) end @@ -108,26 +108,6 @@ local function system(cmd, opts) return r end ----@async ----@param url string ----@param output string ----@return string? err -local function download_file(url, output) - local r = system({ - 'curl', - '--silent', - '--fail', - '--show-error', - '-L', -- follow redirects - url, - '--output', - output, - }) - if r.code > 0 then - return r.stderr - end -end - local M = {} --- @@ -220,7 +200,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu local tmp = output_dir .. '-tmp' - rmdir(tmp) + rmpath(tmp) a.schedule() url = url:gsub('.git$', '') @@ -232,9 +212,18 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu do -- Download tarball logger:info('Downloading %s...', project_name) - local err = download_file(target, tarball_path) - if err then - return logger:error('Error during download: %s', err) + local r = system({ + 'curl', + '--silent', + '--fail', + '--show-error', + '-L', -- follow redirects + target, + '--output', + tarball_path, + }) + if r.code > 0 then + return logger:error('Error during download: %s', r.stderr) end end @@ -280,7 +269,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu end end - rmdir(tmp) + rmpath(tmp) a.schedule() end @@ -343,7 +332,7 @@ end ---@param query_dir string ---@return string? err local function do_copy_queries(logger, query_src, query_dir) - rmdir(query_dir) + rmpath(query_dir) local err = uv_mkdir(query_dir, 493) -- tonumber('755', 8) for f in fs.dir(query_src) do @@ -374,7 +363,7 @@ local function try_install_lang(lang, cache_dir, install_dir, generate) compile_location = fs.normalize(repo.path) else local project_dir = fs.joinpath(cache_dir, project_name) - rmdir(project_dir) + rmpath(project_dir) revision = revision or repo.branch or 'main' @@ -443,7 +432,7 @@ local function try_install_lang(lang, cache_dir, install_dir, generate) -- clean up if repo and not repo.path then - rmdir(fs.joinpath(cache_dir, project_name)) + rmpath(fs.joinpath(cache_dir, project_name)) a.schedule() end @@ -517,10 +506,10 @@ local function install(languages, options) local install_dir = config.get_install_dir('parser') - local task_funs = {} ---@type async.TaskFun[] + local tasks = {} ---@type async.TaskFun[] local done = 0 for _, lang in ipairs(languages) do - task_funs[#task_funs + 1] = a.async(function() + tasks[#tasks + 1] = a.async(--[[@async]] function() a.schedule() local status = install_lang(lang, cache_dir, install_dir, options.force, options.generate) if status ~= 'failed' then @@ -529,16 +518,17 @@ local function install(languages, options) end) end - join(options and options.max_jobs or MAX_JOBS, task_funs) - if #task_funs > 1 then + join(options and options.max_jobs or MAX_JOBS, tasks) + if #tasks > 1 then a.schedule() if options and options.summary then - log.info('Installed %d/%d languages', done, #task_funs) + log.info('Installed %d/%d languages', done, #tasks) end end - return done == #task_funs + return done == #tasks end +---@async ---@param languages string[]|string ---@param options? InstallOptions M.install = a.async(function(languages, options) @@ -547,6 +537,7 @@ M.install = a.async(function(languages, options) return install(languages, options) end) +---@async ---@param languages? string[]|string ---@param options? InstallOptions M.update = a.async(function(languages, options) @@ -597,7 +588,7 @@ local function uninstall_lang(logger, lang, parser, queries) if stat.type == 'link' then qerr = uv_unlink(queries) else - qerr = rmdir(queries) + qerr = rmpath(queries) end a.schedule() if qerr then @@ -608,6 +599,7 @@ local function uninstall_lang(logger, lang, parser, queries) logger:info('Language uninstalled') end +---@async ---@param languages string[]|string ---@param options? InstallOptions M.uninstall = a.async(function(languages, options) @@ -618,7 +610,7 @@ M.uninstall = a.async(function(languages, options) local query_dir = config.get_install_dir('queries') local installed = config.get_installed() - local task_funs = {} ---@type async.TaskFun[] + local tasks = {} ---@type async.TaskFun[] local done = 0 for _, lang in ipairs(languages) do local logger = log.new('uninstall/' .. lang) @@ -627,7 +619,7 @@ M.uninstall = a.async(function(languages, options) else local parser = fs.joinpath(parser_dir, lang) .. '.so' local queries = fs.joinpath(query_dir, lang) - task_funs[#task_funs + 1] = a.async(function() + tasks[#tasks + 1] = a.async(--[[@async]] function() local err = uninstall_lang(logger, lang, parser, queries) if not err then done = done + 1 @@ -636,11 +628,11 @@ M.uninstall = a.async(function(languages, options) end end - join(MAX_JOBS, task_funs) - if #task_funs > 1 then + join(MAX_JOBS, tasks) + if #tasks > 1 then a.schedule() if options and options.summary then - log.info('Uninstalled %d/%d languages', done, #task_funs) + log.info('Uninstalled %d/%d languages', done, #tasks) end end end) |
