diff options
| author | Munif Tanjim <hello@muniftanjim.dev> | 2022-01-19 02:22:29 +0600 |
|---|---|---|
| committer | Christian Clason <christian.clason@uni-due.de> | 2022-01-21 10:51:51 +0100 |
| commit | 85140a7a479c30b872fd562b299a4afefc58576f (patch) | |
| tree | bd1bb6d289f3acacb40b5c1e00eb37c952dcaf43 /lua | |
| parent | feat: rewrite indent module (diff) | |
| download | nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.gz nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.bz2 nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.lz nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.xz nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.tar.zst nvim-treesitter-85140a7a479c30b872fd562b299a4afefc58576f.zip | |
feat(indent): use native Query:iter_captures
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/nvim-treesitter/indent.lua | 35 | ||||
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 120 |
2 files changed, 95 insertions, 60 deletions
diff --git a/lua/nvim-treesitter/indent.lua b/lua/nvim-treesitter/indent.lua index 7f7891261..da3b73162 100644 --- a/lua/nvim-treesitter/indent.lua +++ b/lua/nvim-treesitter/indent.lua @@ -5,24 +5,19 @@ local tsutils = require "nvim-treesitter.ts_utils" local M = {} local get_indents = tsutils.memoize_by_buf_tick(function(bufnr, root, lang) - local get_map = function(capture) - local matches = queries.get_capture_matches(bufnr, capture, "indents", root, lang) or {} - local map = {} - for _, node in ipairs(matches) do - map[node:id()] = true - end - return map + local map = { + auto = {}, + indent = {}, + dedent = {}, + branch = {}, + ignore = {}, + } + + for name, node in queries.iter_captures(bufnr, "indents", root, lang) do + map[name][node:id()] = true end - return { - autos = get_map "@auto.node", - indents = get_map "@indent.node", - dedents = get_map "@dedent.node", - branches = get_map "@branch.node", - ignores = get_map "@ignore.node", - aligned_indents = get_map "@aligned_indent.node", - hanging_indents = get_map "@hanging_indent.node", - } + return map end, { -- Memoize by bufnr and lang together. key = function(bufnr, root, lang) @@ -69,14 +64,14 @@ function M.get_indent(lnum) while node do -- do 'autoindent' if not marked as @indent - if not q.indents[node:id()] and q.autos[node:id()] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then + if not q.indent[node:id()] and q.auto[node:id()] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then return -1 end -- Do not indent if we are inside an @ignore block. -- If a node spans from L1,C1 to L2,C2, we know that lines where L1 < line <= L2 would -- have their indentations contained by the node. - if not q.indents[node:id()] and q.ignores[node:id()] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then + if not q.indent[node:id()] and q.ignore[node:id()] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then return 0 end @@ -86,14 +81,14 @@ function M.get_indent(lnum) if not is_processed_by_row[srow] - and ((q.branches[node:id()] and srow == lnum - 1) or (q.dedents[node:id()] and srow ~= lnum - 1)) + and ((q.branch[node:id()] and srow == lnum - 1) or (q.dedent[node:id()] and srow ~= lnum - 1)) then indent = indent - indent_size is_processed = true end -- do not indent for nodes that starts-and-ends on same line and starts on target line (lnum) - if not is_processed_by_row[srow] and (q.indents[node:id()] and srow ~= erow and srow ~= lnum - 1) then + if not is_processed_by_row[srow] and (q.indent[node:id()] and srow ~= erow and srow ~= lnum - 1) then indent = indent + indent_size is_processed = true end diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 0d2e8cb3d..7009e9f2b 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -128,6 +128,59 @@ function M.invalidate_query_file(fname) M.invalidate_query_cache(fnamemodify(fname, ":p:h:t"), fnamemodify(fname, ":t:r")) end +local function prepare_query(bufnr, query_name, root, root_lang) + local buf_lang = parsers.get_buf_lang(bufnr) + + if not buf_lang then + return + end + + local parser = parsers.get_parser(bufnr, buf_lang) + if not parser then + return + end + + if not root then + local first_tree = parser:trees()[1] + + if first_tree then + root = first_tree:root() + end + end + + if not root then + return + end + + local range = { root:range() } + + if not root_lang then + local lang_tree = parser:language_for_range(range) + + if lang_tree then + root_lang = lang_tree:lang() + end + end + + if not root_lang then + return + end + + local query = M.get_query(root_lang, query_name) + if not query then + return + end + + return query, + { + root = root, + source = bufnr, + start = range[1], + -- The end row is exclusive so we need to add 1 to it. + stop = range[3] + 1, + } +end + function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) -- A function that splits a string on '.' local function split(string) @@ -229,6 +282,31 @@ function M.get_capture_matches(bufnr, captures, query_group, root, lang) return matches end +function M.iter_captures(bufnr, query_name, root, lang) + local query, params = prepare_query(bufnr, query_name, root, lang) + if not query then + return EMPTY_ITER + end + + local iter = query:iter_captures(params.root, params.source, params.start, params.stop) + + local function wrapped_iter() + local id, node, metadata = iter() + if not id then + return + end + + local name = query.captures[id] + if string.sub(name, 1, 1) == "_" then + return wrapped_iter() + end + + return name, node, metadata + end + + return wrapped_iter +end + function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root) if string.sub(capture_string, 1, 1) == "@" then --remove leading "@" @@ -262,50 +340,12 @@ end -- @param root the root node -- @param root the root node lang, if known function M.iter_group_results(bufnr, query_group, root, root_lang) - local buf_lang = parsers.get_buf_lang(bufnr) - - if not buf_lang then - return EMPTY_ITER - end - - local parser = parsers.get_parser(bufnr, buf_lang) - if not parser then - return EMPTY_ITER - end - - if not root then - local first_tree = parser:trees()[1] - - if first_tree then - root = first_tree:root() - end - end - - if not root then - return EMPTY_ITER - end - - local range = { root:range() } - - if not root_lang then - local lang_tree = parser:language_for_range(range) - - if lang_tree then - root_lang = lang_tree:lang() - end - end - - if not root_lang then - return EMPTY_ITER - end - - local query = M.get_query(root_lang, query_group) + local query, params = prepare_query(bufnr, query_group, root, root_lang) if not query then return EMPTY_ITER end - -- The end row is exclusive so we need to add 1 to it. - return M.iter_prepared_matches(query, root, bufnr, range[1], range[3] + 1) + return M.iter_prepared_matches(query, params.root, params.source, params.start, params.stop) end function M.collect_group_results(bufnr, query_group, root, lang) |
