aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
authorStephan Seitz <stephan.seitz@fau.de>2020-07-20 17:48:58 +0200
committerThomas Vigouroux <39092278+vigoux@users.noreply.github.com>2020-07-27 10:15:33 +0200
commit5462fc92cbb6e94d93a7e20d15f81f68d918f71d (patch)
treebad50530f61969a2f2691bcd487b4343a2925aa6 /lua
parentParsers: add reStructuredText (diff)
downloadnvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar
nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.gz
nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.bz2
nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.lz
nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.xz
nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.zst
nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.zip
Add predicates module
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/query.lua13
-rw-r--r--lua/nvim-treesitter/query_predicates.lua58
2 files changed, 70 insertions, 1 deletions
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua
index 896b11dad..678c7b683 100644
--- a/lua/nvim-treesitter/query.lua
+++ b/lua/nvim-treesitter/query.lua
@@ -2,6 +2,7 @@ local api = vim.api
local ts = vim.treesitter
local utils = require'nvim-treesitter.utils'
local parsers = require'nvim-treesitter.parsers'
+local predicates = require'nvim-treesitter.query_predicates'
local M = {}
@@ -129,7 +130,7 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
local matches = query:iter_matches(qnode, bufnr, start_row, end_row)
- return function()
+ local function iterator()
local pattern, match = matches()
if pattern ~= nil then
local prepared_match = {}
@@ -147,15 +148,25 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
local preds = query.info.patterns[pattern]
if preds then
for _, pred in pairs(preds) do
+ -- functions
if pred[1] == "set!" and type(pred[2]) == "string" then
insert_to_path(prepared_match, split(pred[2]), pred[3])
end
+
+ -- predicates
+ if type(pred[1]) == 'string' then
+ if not predicates.check_predicate(query, prepared_match, pred) or
+ not predicates.check_negated_predicate(query, prepared_match, pred) then
+ return iterator()
+ end
+ end
end
end
return prepared_match
end
end
+ return iterator
end
--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
diff --git a/lua/nvim-treesitter/query_predicates.lua b/lua/nvim-treesitter/query_predicates.lua
new file mode 100644
index 000000000..00daa50d3
--- /dev/null
+++ b/lua/nvim-treesitter/query_predicates.lua
@@ -0,0 +1,58 @@
+local utils = require'nvim-treesitter.utils'
+
+local M = {}
+
+local function get_nth_child(node, n)
+ if node:named_child_count() > n then
+ return node:named_child(n)
+ end
+end
+
+local function get_node(query, match, pred_item)
+ return utils.get_at_path(match, query.captures[pred_item]..'.node')
+end
+
+function M.check_predicate(query, match, pred)
+ local check_function = M[string.gsub('check_'..pred[1], "%?$", '')]
+ if check_function then
+ return check_function(query, match, pred)
+ else
+ return true
+ end
+end
+
+function M.check_negated_predicate(query, match, pred)
+ local check_function = M[string.gsub('check_'..string.sub(pred[1], #"not-" + 1), "%?$", '')]
+ if check_function then
+ return not check_function(query, match, pred)
+ else
+ return true
+ end
+end
+
+function M.check_first(query, match, pred)
+ if #pred ~= 2 then error("first? must have exactly one argument!") end
+ local node = get_node(query, match, pred[2])
+ if node and node:parent() then
+ return get_nth_child(node:parent(), 0) == node
+ end
+end
+
+function M.check_last(query, match, pred)
+ if #pred ~= 2 then error("first? must have exactly one argument!") end
+ local node = get_node(query, match, pred[2])
+ if node and node:parent() then
+ local num_children = node:parent():named_child_count()
+ return get_nth_child(node:parent(), num_children - 1) == node
+ end
+end
+
+function M.check_nth(query, match, pred)
+ if #pred ~= 3 then error("nth? must have exactly two arguments!") end
+ local node = get_node(query, match, pred[2])
+ if node and node:parent() then
+ return get_nth_child(node:parent(), pred[3] - 1) == node
+ end
+end
+
+return M