From 67509d4627bf67b477b501e2789b0d27ee21dcfb Mon Sep 17 00:00:00 2001 From: Steven Sojka Date: Fri, 31 Jul 2020 08:01:07 -0500 Subject: feat(predicates): add adjacent predicate --- lua/nvim-treesitter/query_predicates.lua | 42 +++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) (limited to 'lua') diff --git a/lua/nvim-treesitter/query_predicates.lua b/lua/nvim-treesitter/query_predicates.lua index 30a27af27..e010e6dd8 100644 --- a/lua/nvim-treesitter/query_predicates.lua +++ b/lua/nvim-treesitter/query_predicates.lua @@ -1,4 +1,5 @@ local utils = require'nvim-treesitter.utils' +local ts_utils = require'nvim-treesitter.ts_utils' local M = {} @@ -12,6 +13,42 @@ local function get_node(query, match, pred_item) return utils.get_at_path(match, query.captures[pred_item]..'.node') end +local function create_adjacent_predicate(match_successive_nodes) + return function(query, match, pred) + if #pred < 3 then error("adjacent? must have at least two arguments!") end + local node = get_node(query, match, pred[2]) + if not node then return true end + + local adjacent_types = {unpack(pred, 3)} + local adjacent_node = ts_utils.get_next_node(node) + + if match_successive_nodes then + -- Move to the last node in a series that doesn't match the node type + -- and use that node to compare with. + while adjacent_node and adjacent_node:type() == node:type() do + node = adjacent_node + adjacent_node = ts_utils.get_next_node(node) + end + end + + if not adjacent_node then return false end + + for _, adjacent_type in ipairs(adjacent_types) do + if type(adjacent_type) == "number" then + if get_node(query, match, adjacent_type) == adjacent_node then + return true + end + elseif type(adjacent_type) == "string" then + if adjacent_node:type() == adjacent_type then + return true + end + end + end + + return false + end +end + function M.check_predicate(query, match, pred) local check_function = M[pred[1]] if check_function then @@ -55,7 +92,7 @@ end end end -M['has_ancestor?'] = function(query, match, pred) +M['has-ancestor?'] = function(query, match, pred) if #pred ~= 3 then error("has-ancestor? must have exactly two arguments!") end local node = get_node(query, match, pred[2]) local ancestor_type = pred[3] @@ -71,4 +108,7 @@ M['has_ancestor?'] = function(query, match, pred) return false end +M['adjacent?'] = create_adjacent_predicate(false) +M['adjacent-block?'] = create_adjacent_predicate(true) + return M -- cgit v1.2.3-70-g09d2