aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
authorMunif Tanjim <hello@muniftanjim.dev>2022-01-19 02:22:29 +0600
committerChristian Clason <christian.clason@uni-due.de>2022-01-21 10:51:51 +0100
commit85140a7a479c30b872fd562b299a4afefc58576f (patch)
treebd1bb6d289f3acacb40b5c1e00eb37c952dcaf43 /lua
parentfeat: rewrite indent module (diff)
downloadnvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar
nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.gz
nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.bz2
nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.lz
nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.xz
nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.zst
nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.zip
feat(indent): use native Query:iter_captures
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/indent.lua35
-rw-r--r--lua/nvim-treesitter/query.lua120
2 files changed, 95 insertions, 60 deletions
diff --git a/lua/nvim-treesitter/indent.lua b/lua/nvim-treesitter/indent.lua
index 7f7891261..da3b73162 100644
--- a/lua/nvim-treesitter/indent.lua
+++ b/lua/nvim-treesitter/indent.lua
@@ -5,24 +5,19 @@ local tsutils = require "nvim-treesitter.ts_utils"
local M = {}
local get_indents = tsutils.memoize_by_buf_tick(function(bufnr, root, lang)
- local get_map = function(capture)
- local matches = queries.get_capture_matches(bufnr, capture, "indents", root, lang) or {}
- local map = {}
- for _, node in ipairs(matches) do
- map[node:id()] = true
- end
- return map
+ local map = {
+ auto = {},
+ indent = {},
+ dedent = {},
+ branch = {},
+ ignore = {},
+ }
+
+ for name, node in queries.iter_captures(bufnr, "indents", root, lang) do
+ map[name][node:id()] = true
end
- return {
- autos = get_map "@auto.node",
- indents = get_map "@indent.node",
- dedents = get_map "@dedent.node",
- branches = get_map "@branch.node",
- ignores = get_map "@ignore.node",
- aligned_indents = get_map "@aligned_indent.node",
- hanging_indents = get_map "@hanging_indent.node",
- }
+ return map
end, {
-- Memoize by bufnr and lang together.
key = function(bufnr, root, lang)
@@ -69,14 +64,14 @@ function M.get_indent(lnum)
while node do
-- do 'autoindent' if not marked as @indent
- if not q.indents[node:id()] and q.autos[node:id()] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then
+ if not q.indent[node:id()] and q.auto[node:id()] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then
return -1
end
-- Do not indent if we are inside an @ignore block.
-- If a node spans from L1,C1 to L2,C2, we know that lines where L1 < line <= L2 would
-- have their indentations contained by the node.
- if not q.indents[node:id()] and q.ignores[node:id()] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then
+ if not q.indent[node:id()] and q.ignore[node:id()] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then
return 0
end
@@ -86,14 +81,14 @@ function M.get_indent(lnum)
if
not is_processed_by_row[srow]
- and ((q.branches[node:id()] and srow == lnum - 1) or (q.dedents[node:id()] and srow ~= lnum - 1))
+ and ((q.branch[node:id()] and srow == lnum - 1) or (q.dedent[node:id()] and srow ~= lnum - 1))
then
indent = indent - indent_size
is_processed = true
end
-- do not indent for nodes that starts-and-ends on same line and starts on target line (lnum)
- if not is_processed_by_row[srow] and (q.indents[node:id()] and srow ~= erow and srow ~= lnum - 1) then
+ if not is_processed_by_row[srow] and (q.indent[node:id()] and srow ~= erow and srow ~= lnum - 1) then
indent = indent + indent_size
is_processed = true
end
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua
index 0d2e8cb3d..7009e9f2b 100644
--- a/lua/nvim-treesitter/query.lua
+++ b/lua/nvim-treesitter/query.lua
@@ -128,6 +128,59 @@ function M.invalidate_query_file(fname)
M.invalidate_query_cache(fnamemodify(fname, ":p:h:t"), fnamemodify(fname, ":t:r"))
end
+local function prepare_query(bufnr, query_name, root, root_lang)
+ local buf_lang = parsers.get_buf_lang(bufnr)
+
+ if not buf_lang then
+ return
+ end
+
+ local parser = parsers.get_parser(bufnr, buf_lang)
+ if not parser then
+ return
+ end
+
+ if not root then
+ local first_tree = parser:trees()[1]
+
+ if first_tree then
+ root = first_tree:root()
+ end
+ end
+
+ if not root then
+ return
+ end
+
+ local range = { root:range() }
+
+ if not root_lang then
+ local lang_tree = parser:language_for_range(range)
+
+ if lang_tree then
+ root_lang = lang_tree:lang()
+ end
+ end
+
+ if not root_lang then
+ return
+ end
+
+ local query = M.get_query(root_lang, query_name)
+ if not query then
+ return
+ end
+
+ return query,
+ {
+ root = root,
+ source = bufnr,
+ start = range[1],
+ -- The end row is exclusive so we need to add 1 to it.
+ stop = range[3] + 1,
+ }
+end
+
function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
-- A function that splits a string on '.'
local function split(string)
@@ -229,6 +282,31 @@ function M.get_capture_matches(bufnr, captures, query_group, root, lang)
return matches
end
+function M.iter_captures(bufnr, query_name, root, lang)
+ local query, params = prepare_query(bufnr, query_name, root, lang)
+ if not query then
+ return EMPTY_ITER
+ end
+
+ local iter = query:iter_captures(params.root, params.source, params.start, params.stop)
+
+ local function wrapped_iter()
+ local id, node, metadata = iter()
+ if not id then
+ return
+ end
+
+ local name = query.captures[id]
+ if string.sub(name, 1, 1) == "_" then
+ return wrapped_iter()
+ end
+
+ return name, node, metadata
+ end
+
+ return wrapped_iter
+end
+
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root)
if string.sub(capture_string, 1, 1) == "@" then
--remove leading "@"
@@ -262,50 +340,12 @@ end
-- @param root the root node
-- @param root the root node lang, if known
function M.iter_group_results(bufnr, query_group, root, root_lang)
- local buf_lang = parsers.get_buf_lang(bufnr)
-
- if not buf_lang then
- return EMPTY_ITER
- end
-
- local parser = parsers.get_parser(bufnr, buf_lang)
- if not parser then
- return EMPTY_ITER
- end
-
- if not root then
- local first_tree = parser:trees()[1]
-
- if first_tree then
- root = first_tree:root()
- end
- end
-
- if not root then
- return EMPTY_ITER
- end
-
- local range = { root:range() }
-
- if not root_lang then
- local lang_tree = parser:language_for_range(range)
-
- if lang_tree then
- root_lang = lang_tree:lang()
- end
- end
-
- if not root_lang then
- return EMPTY_ITER
- end
-
- local query = M.get_query(root_lang, query_group)
+ local query, params = prepare_query(bufnr, query_group, root, root_lang)
if not query then
return EMPTY_ITER
end
- -- The end row is exclusive so we need to add 1 to it.
- return M.iter_prepared_matches(query, root, bufnr, range[1], range[3] + 1)
+ return M.iter_prepared_matches(query, params.root, params.source, params.start, params.stop)
end
function M.collect_group_results(bufnr, query_group, root, lang)