diff options
| author | Stephan Seitz <stephan.seitz@fau.de> | 2020-07-20 17:48:58 +0200 |
|---|---|---|
| committer | Thomas Vigouroux <39092278+vigoux@users.noreply.github.com> | 2020-07-27 10:15:33 +0200 |
| commit | 5462fc92cbb6e94d93a7e20d15f81f68d918f71d (patch) | |
| tree | bad50530f61969a2f2691bcd487b4343a2925aa6 /lua | |
| parent | Parsers: add reStructuredText (diff) | |
| download | nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.gz nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.bz2 nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.lz nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.xz nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.tar.zst nvim-treesitter-5462fc92cbb6e94d93a7e20d15f81f68d918f71d.zip | |
Add predicates module
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 13 | ||||
| -rw-r--r-- | lua/nvim-treesitter/query_predicates.lua | 58 |
2 files changed, 70 insertions, 1 deletions
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 896b11dad..678c7b683 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -2,6 +2,7 @@ local api = vim.api local ts = vim.treesitter local utils = require'nvim-treesitter.utils' local parsers = require'nvim-treesitter.parsers' +local predicates = require'nvim-treesitter.query_predicates' local M = {} @@ -129,7 +130,7 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) local matches = query:iter_matches(qnode, bufnr, start_row, end_row) - return function() + local function iterator() local pattern, match = matches() if pattern ~= nil then local prepared_match = {} @@ -147,15 +148,25 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) local preds = query.info.patterns[pattern] if preds then for _, pred in pairs(preds) do + -- functions if pred[1] == "set!" and type(pred[2]) == "string" then insert_to_path(prepared_match, split(pred[2]), pred[3]) end + + -- predicates + if type(pred[1]) == 'string' then + if not predicates.check_predicate(query, prepared_match, pred) or + not predicates.check_negated_predicate(query, prepared_match, pred) then + return iterator() + end + end end end return prepared_match end end + return iterator end --- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type) diff --git a/lua/nvim-treesitter/query_predicates.lua b/lua/nvim-treesitter/query_predicates.lua new file mode 100644 index 000000000..00daa50d3 --- /dev/null +++ b/lua/nvim-treesitter/query_predicates.lua @@ -0,0 +1,58 @@ +local utils = require'nvim-treesitter.utils' + +local M = {} + +local function get_nth_child(node, n) + if node:named_child_count() > n then + return node:named_child(n) + end +end + +local function get_node(query, match, pred_item) + return utils.get_at_path(match, query.captures[pred_item]..'.node') +end + +function M.check_predicate(query, match, pred) + local check_function = M[string.gsub('check_'..pred[1], "%?$", '')] + if check_function then + return check_function(query, match, pred) + else + return true + end +end + +function M.check_negated_predicate(query, match, pred) + local check_function = M[string.gsub('check_'..string.sub(pred[1], #"not-" + 1), "%?$", '')] + if check_function then + return not check_function(query, match, pred) + else + return true + end +end + +function M.check_first(query, match, pred) + if #pred ~= 2 then error("first? must have exactly one argument!") end + local node = get_node(query, match, pred[2]) + if node and node:parent() then + return get_nth_child(node:parent(), 0) == node + end +end + +function M.check_last(query, match, pred) + if #pred ~= 2 then error("first? must have exactly one argument!") end + local node = get_node(query, match, pred[2]) + if node and node:parent() then + local num_children = node:parent():named_child_count() + return get_nth_child(node:parent(), num_children - 1) == node + end +end + +function M.check_nth(query, match, pred) + if #pred ~= 3 then error("nth? must have exactly two arguments!") end + local node = get_node(query, match, pred[2]) + if node and node:parent() then + return get_nth_child(node:parent(), pred[3] - 1) == node + end +end + +return M |
