From ddcfc019792e67ca632b5698ea739bbff51518c0 Mon Sep 17 00:00:00 2001 From: Santos Gallegos Date: Sun, 22 Aug 2021 20:22:20 -0500 Subject: Query: allow to pass a list to get_capture_matches (#1693) --- lua/nvim-treesitter/query.lua | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 1d8099ec9..0d2e8cb3d 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -195,22 +195,35 @@ end --- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type) -- Works like M.get_references or M.get_scopes except you can choose the capture --- Can also be a nested capture like @definition.function to get all nodes defining a function -function M.get_capture_matches(bufnr, capture_string, query_group, root, lang) - if not string.sub(capture_string, 1, 1) == "@" then - print 'capture_string must start with "@"' - return +-- Can also be a nested capture like @definition.function to get all nodes defining a function. +-- +-- @param bufnr the buffer +-- @param captures a single string or a list of strings +-- @param query_group the name of query group (highlights or injections for example) +-- @param root (optional) node from where to start the search +-- @param lang (optional) the language from where to get the captures. +-- Root nodes can have several languages. +function M.get_capture_matches(bufnr, captures, query_group, root, lang) + if type(captures) == "string" then + captures = { captures } + end + local strip_captures = {} + for i, capture in ipairs(captures) do + if not capture:sub(1, 1) == "@" then + error 'Captures must start with "@"' + return + end + -- Remove leading "@". + strip_captures[i] = capture:sub(2) end - - --remove leading "@" - capture_string = string.sub(capture_string, 2) local matches = {} for match in M.iter_group_results(bufnr, query_group, root, lang) do - local insert = utils.get_at_path(match, capture_string) - - if insert then - table.insert(matches, insert) + for _, capture in ipairs(strip_captures) do + local insert = utils.get_at_path(match, capture) + if insert then + table.insert(matches, insert) + end end end return matches -- cgit v1.2.3-70-g09d2