diff options
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/nvim-treesitter/fold.lua | 58 |
1 files changed, 45 insertions, 13 deletions
diff --git a/lua/nvim-treesitter/fold.lua b/lua/nvim-treesitter/fold.lua index 401f52606..46320b14a 100644 --- a/lua/nvim-treesitter/fold.lua +++ b/lua/nvim-treesitter/fold.lua @@ -1,28 +1,60 @@ +local api = vim.api +local utils = require'nvim-treesitter.ts_utils' +local query = require'nvim-treesitter.query' local parsers = require'nvim-treesitter.parsers' local M = {} -function M.get_fold_indic(lnum) - if not parsers.has_parser() or not lnum then return '0' end +local folds_levels = utils.memoize_by_buf_tick(function(bufnr) + local lang = parsers.get_buf_lang(bufnr) + + local matches + if query.has_fold(lang) then + matches = query.get_capture_matches(bufnr, "@fold", "fold") + elseif query.has_locals(lang) then + matches = query.get_capture_matches(bufnr, "@scope", "locals") + else + return {} + end - local function smallest_multiline_containing(node, level) - for index = 0,(node:named_child_count() -1) do - local child = node:named_child(index) - local start, _, stop, _ = child:range() + local levels_tmp = {} - if start ~= stop and start <= (lnum -1) and stop >= (lnum -1) then - return smallest_multiline_containing(child, level + 1) - end + for _, node in ipairs(matches) do + local start, _, stop, stop_col = node.node:range() + + if stop_col > 0 then + stop = stop + 1 + end + + -- This can be folded + -- Fold only multiline nodes that are not exactly the same as prevsiously met folds + if start ~= stop and not (levels_tmp[start] and levels_tmp[stop]) then + levels_tmp[start] = (levels_tmp[start] or 0) + 1 + levels_tmp[stop] = (levels_tmp[stop] or 0) - 1 end - return node, level end - local parser = parsers.get_parser() + local levels = {} + local current_level = 0 + + for lnum=0,api.nvim_buf_line_count(bufnr) do + current_level = current_level + (levels_tmp[lnum] or 0) + levels[lnum + 1] = current_level + end + + return levels +end) + +function M.get_fold_indic(lnum) + if not parsers.has_parser() or not lnum then return '0' end + + local buf = api.nvim_get_current_buf() + + local levels = folds_levels(buf) or {} - local _, level = smallest_multiline_containing(parser:parse():root(), 0) + return tostring(levels[lnum] or 0) - return tostring(level) end return M |
