aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/query.lua37
1 files changed, 25 insertions, 12 deletions
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua
index 1d8099ec9..0d2e8cb3d 100644
--- a/lua/nvim-treesitter/query.lua
+++ b/lua/nvim-treesitter/query.lua
@@ -195,22 +195,35 @@ end
--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
-- Works like M.get_references or M.get_scopes except you can choose the capture
--- Can also be a nested capture like @definition.function to get all nodes defining a function
-function M.get_capture_matches(bufnr, capture_string, query_group, root, lang)
- if not string.sub(capture_string, 1, 1) == "@" then
- print 'capture_string must start with "@"'
- return
+-- Can also be a nested capture like @definition.function to get all nodes defining a function.
+--
+-- @param bufnr the buffer
+-- @param captures a single string or a list of strings
+-- @param query_group the name of query group (highlights or injections for example)
+-- @param root (optional) node from where to start the search
+-- @param lang (optional) the language from where to get the captures.
+-- Root nodes can have several languages.
+function M.get_capture_matches(bufnr, captures, query_group, root, lang)
+ if type(captures) == "string" then
+ captures = { captures }
+ end
+ local strip_captures = {}
+ for i, capture in ipairs(captures) do
+ if not capture:sub(1, 1) == "@" then
+ error 'Captures must start with "@"'
+ return
+ end
+ -- Remove leading "@".
+ strip_captures[i] = capture:sub(2)
end
-
- --remove leading "@"
- capture_string = string.sub(capture_string, 2)
local matches = {}
for match in M.iter_group_results(bufnr, query_group, root, lang) do
- local insert = utils.get_at_path(match, capture_string)
-
- if insert then
- table.insert(matches, insert)
+ for _, capture in ipairs(strip_captures) do
+ local insert = utils.get_at_path(match, capture)
+ if insert then
+ table.insert(matches, insert)
+ end
end
end
return matches