diff options
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/nvim-treesitter/indent.lua | 45 |
1 files changed, 34 insertions, 11 deletions
diff --git a/lua/nvim-treesitter/indent.lua b/lua/nvim-treesitter/indent.lua index ffbb8c0b7..1f49cc26b 100644 --- a/lua/nvim-treesitter/indent.lua +++ b/lua/nvim-treesitter/indent.lua @@ -4,6 +4,7 @@ local utils = require'nvim-treesitter.ts_utils' local M = {} +-- TODO(kiyan): move this in tsutils and document it local function get_node_at_line(root, lnum) for node in root:iter_children() do local srow, _, erow = node:range() @@ -13,10 +14,11 @@ local function get_node_at_line(root, lnum) return get_node_at_line(node, lnum) end end +end - local wrapper = root:descendant_for_range(lnum, 0, lnum, -1) - local child = wrapper:child(0) - return child or wrapper +local function node_fmt(node) + if not node then return nil end + return tostring(node) end local get_indents = utils.memoize_by_buf_tick(function(bufnr) @@ -36,30 +38,51 @@ local get_indents = utils.memoize_by_buf_tick(function(bufnr) return { indents = indents_map, branches = branches_map } end) +local function get_indent_size() + return vim.bo.softtabstop < 0 and vim.bo.shiftwidth or vim.bo.tabstop +end + function M.get_indent(lnum) local parser = parsers.get_parser() if not parser or not lnum then return -1 end - local node = get_node_at_line(parser:parse()[1]:root(), lnum-1) local indent_queries = get_indents(vim.api.nvim_get_current_buf()) local indents = indent_queries.indents local branches = indent_queries.branches - if not indents then return 0 end + local root = parser:parse()[1]:root() + local node = get_node_at_line(root, lnum-1) + + local indent = 0 + local indent_size = get_indent_size() - while node and branches[tostring(node)] do + -- if we are on a new line (for instance by typing `o` or `O`) + -- we should get the node that wraps the line our cursor sits in + -- and if the node is an indent node, we should set the indent level as the indent_size + -- and we set the node as the first child of this wrapper node or the wrapper itself + if not node then + local wrapper = root:descendant_for_range(lnum, 0, lnum, -1) + node = wrapper:child(0) or wrapper + if indents[node_fmt(wrapper)] ~= nil and wrapper ~= root then + indent = indent_size + end + end + + while node and branches[node_fmt(node)] do node = node:parent() end - local ind_size = vim.bo.softtabstop < 0 and vim.bo.shiftwidth or vim.bo.tabstop - local ind = 0 + local prev_row = node:start() + while node do node = node:parent() - if indents[tostring(node)] then - ind = ind + ind_size + local row = node and node:start() or prev_row + if indents[node_fmt(node)] and prev_row ~= row then + indent = indent + indent_size + prev_row = row end end - return ind + return indent end local indent_funcs = {} |
