diff options
| -rw-r--r-- | lua/nvim-treesitter/locals.lua | 59 |
1 files changed, 49 insertions, 10 deletions
diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua index d30953348..b52d42b3f 100644 --- a/lua/nvim-treesitter/locals.lua +++ b/lua/nvim-treesitter/locals.lua @@ -5,16 +5,30 @@ local api = vim.api local queries = require'nvim-treesitter.query' local parsers = require'nvim-treesitter.parsers' +local utils = require'nvim-treesitter.utils' -local M = { - locals = {} +local default_dict = { + __index = function(table, key) + local exists = rawget(table, key) + if not exists then + table[key] = {} + end + return rawget(table, key) + end } -function M.collect_locals(bufnr) +local query_cache = {} +setmetatable(query_cache, default_dict) + +local M = {} + +function M.collect_locals(bufnr, query_kind) + query_kind = query_kind or 'locals' + local lang = parsers.ft_to_lang(api.nvim_buf_get_option(bufnr, "ft")) if not lang then return end - local query = queries.get_query(lang, 'locals') + local query = queries.get_query(lang, query_kind) if not query then return end local parser = parsers.get_parser(bufnr, lang) @@ -32,18 +46,20 @@ function M.collect_locals(bufnr) return locals end -local function update_cached_locals(bufnr, changed_tick) - M.locals[bufnr] = {tick=changed_tick, cache=( M.collect_locals(bufnr) or {} )} +local function update_cached_locals(bufnr, changed_tick, query_kind) + query_cache[query_kind][bufnr] = {tick=changed_tick, cache=( M.collect_locals(bufnr, query_kind) or {} )} end -function M.get_locals(bufnr) +function M.get_locals(bufnr, query_kind) + query_kind = query_kind or 'locals' + local bufnr = bufnr or api.nvim_get_current_buf() - local cached_local = M.locals[bufnr] + local cached_local = query_cache[query_kind][bufnr] if not cached_local or api.nvim_buf_get_changedtick(bufnr) > cached_local.tick then - update_cached_locals(bufnr,api.nvim_buf_get_changedtick(bufnr)) + update_cached_locals(bufnr,api.nvim_buf_get_changedtick(bufnr), query_kind) end - return M.locals[bufnr].cache + return query_cache[query_kind][bufnr].cache end function M.get_definitions(bufnr) @@ -88,4 +104,27 @@ function M.get_references(bufnr) return refs end +--- Return all nodes in locals corresponding to a specific capture (like @scope, @reference) +-- Works like M.get_references or M.get_scopes except you can choose the capture +-- Can also be a nested capture like @definition.function to get all nodes defining a function +function M.get_capture_matches(bufnr, capture_string, query_kind) + if not string.sub(capture_string, 1,2) == '@' then + print('capture_string must start with "@"') + return + end + + --remove leading "@" + capture_string = string.sub(capture_string, 2) + + local matches = {} + for _, match in pairs(M.get_locals(bufnr, query_kind)) do + local insert = utils.get_at_path(match, capture_string..'.node') + + if insert then + table.insert(matches, insert) + end + end + return matches +end + return M |
