From 98b75150140ca37a29a8244c7606846a9cb0af3c Mon Sep 17 00:00:00 2001 From: Thomas Vigouroux Date: Tue, 11 Aug 2020 23:20:21 +0200 Subject: fix: refactor after upstream refactor --- lua/nvim-treesitter/highlight.lua | 4 +- lua/nvim-treesitter/query.lua | 9 --- lua/nvim-treesitter/query_predicates.lua | 111 +++++-------------------------- 3 files changed, 19 insertions(+), 105 deletions(-) (limited to 'lua') diff --git a/lua/nvim-treesitter/highlight.lua b/lua/nvim-treesitter/highlight.lua index ac0ec5c07..965eccbe3 100644 --- a/lua/nvim-treesitter/highlight.lua +++ b/lua/nvim-treesitter/highlight.lua @@ -9,7 +9,7 @@ local M = { highlighters = {} } -local hlmap = vim.treesitter.TSHighlighter.hl_map +local hlmap = vim.treesitter.highlighter.hl_map -- Misc hlmap.error = "TSError" @@ -79,7 +79,7 @@ function M.attach(bufnr, lang) local query = queries.get_query(lang, "highlights") if not query then return end - M.highlighters[bufnr] = ts.TSHighlighter.new(query, bufnr, lang) + M.highlighters[bufnr] = ts.highlighter.new(query, bufnr, lang) end function M.detach(bufnr) diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 33637be3f..7463abfe4 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -2,7 +2,6 @@ local api = vim.api local ts = vim.treesitter local utils = require'nvim-treesitter.utils' local parsers = require'nvim-treesitter.parsers' -local predicates = require'nvim-treesitter.query_predicates' local M = {} @@ -158,14 +157,6 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) if pred[1] == "set!" and type(pred[2]) == "string" then insert_to_path(prepared_match, split(pred[2]), pred[3]) end - - -- predicates - if type(pred[1]) == 'string' then - if not predicates.check_predicate(query, prepared_match, pred) or - not predicates.check_negated_predicate(query, prepared_match, pred) then - return iterator() - end - end end end diff --git a/lua/nvim-treesitter/query_predicates.lua b/lua/nvim-treesitter/query_predicates.lua index e010e6dd8..67850deaf 100644 --- a/lua/nvim-treesitter/query_predicates.lua +++ b/lua/nvim-treesitter/query_predicates.lua @@ -1,100 +1,28 @@ -local utils = require'nvim-treesitter.utils' -local ts_utils = require'nvim-treesitter.ts_utils' +local query = require"vim.treesitter.query" -local M = {} - -local function get_nth_child(node, n) - if node:named_child_count() > n then - return node:named_child(n) - end -end - -local function get_node(query, match, pred_item) - return utils.get_at_path(match, query.captures[pred_item]..'.node') +local function error(str) + vim.api.nvim_err_writeln(str) end -local function create_adjacent_predicate(match_successive_nodes) - return function(query, match, pred) - if #pred < 3 then error("adjacent? must have at least two arguments!") end - local node = get_node(query, match, pred[2]) - if not node then return true end - - local adjacent_types = {unpack(pred, 3)} - local adjacent_node = ts_utils.get_next_node(node) - - if match_successive_nodes then - -- Move to the last node in a series that doesn't match the node type - -- and use that node to compare with. - while adjacent_node and adjacent_node:type() == node:type() do - node = adjacent_node - adjacent_node = ts_utils.get_next_node(node) - end - end - - if not adjacent_node then return false end - - for _, adjacent_type in ipairs(adjacent_types) do - if type(adjacent_type) == "number" then - if get_node(query, match, adjacent_type) == adjacent_node then - return true - end - elseif type(adjacent_type) == "string" then - if adjacent_node:type() == adjacent_type then - return true - end - end - end - - return false +query.add_predicate("nth?", function(match, pattern, bufnr, pred) + if #pred ~= 3 then + error("nth? must hav exactly two arguments") + return end -end -function M.check_predicate(query, match, pred) - local check_function = M[pred[1]] - if check_function then - return check_function(query, match, pred) - else - return true + local node = match[pred[2]] + local n = pred[3] - 1 + if node and node:parent() and node:named_child_count() > n then + return node:named_child(n) == node end -end -function M.check_negated_predicate(query, match, pred) - local check_function = M[string.sub(pred[1], #"not-" + 1)] - if check_function then - return not check_function(query, match, pred) - else - return true - end -end - -M['first?'] = function (query, match, pred) - if #pred ~= 2 then error("first? must have exactly one argument!") end - local node = get_node(query, match, pred[2]) - if node and node:parent() then - return get_nth_child(node:parent(), 0) == node - end -end - -M['last?'] = function (query, match, pred) - if #pred ~= 2 then error("first? must have exactly one argument!") end - local node = get_node(query, match, pred[2]) - if node and node:parent() then - local num_children = node:parent():named_child_count() - return get_nth_child(node:parent(), num_children - 1) == node - end -end + return false +end) - M['nth?'] = function(query, match, pred) - if #pred ~= 3 then error("nth? must have exactly two arguments!") end - local node = get_node(query, match, pred[2]) - if node and node:parent() then - return get_nth_child(node:parent(), pred[3] - 1) == node - end -end +query.add_predicate('has-ancestor?', function(match, pattern, bufnr, pred) + if #pred ~= 3 then error("has-ancestor? must have exactly two arguments!") return end -M['has-ancestor?'] = function(query, match, pred) - if #pred ~= 3 then error("has-ancestor? must have exactly two arguments!") end - local node = get_node(query, match, pred[2]) + local node = match[pred[2]] local ancestor_type = pred[3] if not node then return true end @@ -106,9 +34,4 @@ M['has-ancestor?'] = function(query, match, pred) node = node:parent() end return false -end - -M['adjacent?'] = create_adjacent_predicate(false) -M['adjacent-block?'] = create_adjacent_predicate(true) - -return M +end) -- cgit v1.2.3-70-g09d2