aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/indent.lua39
1 files changed, 20 insertions, 19 deletions
diff --git a/lua/nvim-treesitter/indent.lua b/lua/nvim-treesitter/indent.lua
index 333bdb4eb..5f346d643 100644
--- a/lua/nvim-treesitter/indent.lua
+++ b/lua/nvim-treesitter/indent.lua
@@ -22,20 +22,20 @@ local function node_fmt(node)
end
local get_indents = utils.memoize_by_buf_tick(function(bufnr)
- local indents = queries.get_capture_matches(bufnr, '@indent.node', 'indents') or {}
- local branches = queries.get_capture_matches(bufnr, '@branch.node', 'indents') or {}
-
- local indents_map = {}
- for _, node in ipairs(indents) do
- indents_map[tostring(node)] = true
- end
-
- local branches_map = {}
- for _, node in ipairs(branches) do
- branches_map[tostring(node)] = true
+ local get_map = function(capture)
+ local matches = queries.get_capture_matches(bufnr, capture, 'indents') or {}
+ local map = {}
+ for _, node in ipairs(matches) do
+ map[tostring(node)] = true
+ end
+ return map
end
- return { indents = indents_map, branches = branches_map }
+ return {
+ indents = get_map('@indent.node'),
+ branches = get_map('@branch.node'),
+ returns = get_map('@return.node'),
+ }
end)
local function get_indent_size()
@@ -46,9 +46,7 @@ function M.get_indent(lnum)
local parser = parsers.get_parser()
if not parser or not lnum then return -1 end
- local indent_queries = get_indents(vim.api.nvim_get_current_buf())
- local indents = indent_queries.indents
- local branches = indent_queries.branches
+ local q = get_indents(vim.api.nvim_get_current_buf())
local root = parser:parse()[1]:root()
local node = get_node_at_line(root, lnum-1)
@@ -64,7 +62,10 @@ function M.get_indent(lnum)
local prev_node = get_node_at_line(root, prevnonblank-1)
-- we take that node only if ends before lnum, or else we would get incorrect indent
-- on <cr> in positions like e.g. `{|}` in C (| denotes cursor position)
- if prev_node and (prev_node:end_() < lnum-1) then
+ local use_prev = prev_node and (prev_node:end_() < lnum-1)
+ -- nodes can be marked @return to prevent using them
+ use_prev = use_prev and not q.returns[node_fmt(prev_node)]
+ if use_prev then
node = prev_node
end
end
@@ -75,12 +76,12 @@ function M.get_indent(lnum)
if not node then
local wrapper = root:descendant_for_range(lnum-1, 0, lnum-1, -1)
node = wrapper:child(0) or wrapper
- if indents[node_fmt(wrapper)] ~= nil and wrapper ~= root then
+ if q.indents[node_fmt(wrapper)] ~= nil and wrapper ~= root then
indent = indent_size
end
end
- while node and branches[node_fmt(node)] do
+ while node and q.branches[node_fmt(node)] do
node = node:parent()
end
@@ -89,7 +90,7 @@ function M.get_indent(lnum)
while node do
node = node:parent()
local row = node and node:start() or prev_row
- if indents[node_fmt(node)] and prev_row ~= row then
+ if q.indents[node_fmt(node)] and prev_row ~= row then
indent = indent + indent_size
prev_row = row
end