aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
authorSteven Sojka <Steven.Sojka@tdameritrade.com>2020-07-31 08:01:07 -0500
committerSteven Sojka <Steven.Sojka@tdameritrade.com>2020-07-31 11:54:09 -0500
commit67509d4627bf67b477b501e2789b0d27ee21dcfb (patch)
treee5f6fa987d1591803f3994f68cc89e2354392e0d /lua
parentDocs: rework readme (diff)
downloadnvim-treesitter-67509d4627bf67b477b501e2789b0d27ee21dcfb.tar
nvim-treesitter-67509d4627bf67b477b501e2789b0d27ee21dcfb.tar.gz
nvim-treesitter-67509d4627bf67b477b501e2789b0d27ee21dcfb.tar.bz2
nvim-treesitter-67509d4627bf67b477b501e2789b0d27ee21dcfb.tar.lz
nvim-treesitter-67509d4627bf67b477b501e2789b0d27ee21dcfb.tar.xz
nvim-treesitter-67509d4627bf67b477b501e2789b0d27ee21dcfb.tar.zst
nvim-treesitter-67509d4627bf67b477b501e2789b0d27ee21dcfb.zip
feat(predicates): add adjacent predicate
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/query_predicates.lua42
1 files changed, 41 insertions, 1 deletions
diff --git a/lua/nvim-treesitter/query_predicates.lua b/lua/nvim-treesitter/query_predicates.lua
index 30a27af27..e010e6dd8 100644
--- a/lua/nvim-treesitter/query_predicates.lua
+++ b/lua/nvim-treesitter/query_predicates.lua
@@ -1,4 +1,5 @@
local utils = require'nvim-treesitter.utils'
+local ts_utils = require'nvim-treesitter.ts_utils'
local M = {}
@@ -12,6 +13,42 @@ local function get_node(query, match, pred_item)
return utils.get_at_path(match, query.captures[pred_item]..'.node')
end
+local function create_adjacent_predicate(match_successive_nodes)
+ return function(query, match, pred)
+ if #pred < 3 then error("adjacent? must have at least two arguments!") end
+ local node = get_node(query, match, pred[2])
+ if not node then return true end
+
+ local adjacent_types = {unpack(pred, 3)}
+ local adjacent_node = ts_utils.get_next_node(node)
+
+ if match_successive_nodes then
+ -- Move to the last node in a series that doesn't match the node type
+ -- and use that node to compare with.
+ while adjacent_node and adjacent_node:type() == node:type() do
+ node = adjacent_node
+ adjacent_node = ts_utils.get_next_node(node)
+ end
+ end
+
+ if not adjacent_node then return false end
+
+ for _, adjacent_type in ipairs(adjacent_types) do
+ if type(adjacent_type) == "number" then
+ if get_node(query, match, adjacent_type) == adjacent_node then
+ return true
+ end
+ elseif type(adjacent_type) == "string" then
+ if adjacent_node:type() == adjacent_type then
+ return true
+ end
+ end
+ end
+
+ return false
+ end
+end
+
function M.check_predicate(query, match, pred)
local check_function = M[pred[1]]
if check_function then
@@ -55,7 +92,7 @@ end
end
end
-M['has_ancestor?'] = function(query, match, pred)
+M['has-ancestor?'] = function(query, match, pred)
if #pred ~= 3 then error("has-ancestor? must have exactly two arguments!") end
local node = get_node(query, match, pred[2])
local ancestor_type = pred[3]
@@ -71,4 +108,7 @@ M['has_ancestor?'] = function(query, match, pred)
return false
end
+M['adjacent?'] = create_adjacent_predicate(false)
+M['adjacent-block?'] = create_adjacent_predicate(true)
+
return M