diff options
| -rw-r--r-- | doc/nvim-treesitter.txt | 10 | ||||
| -rw-r--r-- | lua/nvim-treesitter/fold.lua | 24 | ||||
| -rw-r--r-- | lua/nvim-treesitter/indent.lua | 21 | ||||
| -rw-r--r-- | lua/nvim-treesitter/locals.lua | 8 | ||||
| -rw-r--r-- | lua/nvim-treesitter/parsers.lua | 2 | ||||
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 85 | ||||
| -rw-r--r-- | lua/nvim-treesitter/ts_utils.lua | 87 | ||||
| -rw-r--r-- | lua/nvim-treesitter/tsrange.lua | 8 | ||||
| -rw-r--r-- | lua/nvim-treesitter/utils.lua | 12 |
9 files changed, 197 insertions, 60 deletions
diff --git a/doc/nvim-treesitter.txt b/doc/nvim-treesitter.txt index 3d084e5e6..4bc47bff7 100644 --- a/doc/nvim-treesitter.txt +++ b/doc/nvim-treesitter.txt @@ -333,12 +333,16 @@ Swaps the nodes or ranges. set `cursor_to_second` to true to move the cursor to the second node *ts_utils.memoize_by_buf_tick* -memoize_by_buf_tick(fn)~ +memoize_by_buf_tick(fn, options)~ -Cache values by bufnr tick change +Caches the return value for a function and returns the cache value if the tick +of the buffer has not changed from the previous. - `fn`: a function that takes a bufnr as argument + `fn`: a function that takes any arguments and returns a value to store. + `options?`: <table> + - `bufnr`: a function/value that extracts the bufnr from the given arguments. + - `key`: a function/value that extracts the cache key from the given arguments. `returns`: a function to call with bufnr as argument to retrieve the value from the cache diff --git a/lua/nvim-treesitter/fold.lua b/lua/nvim-treesitter/fold.lua index faaf5f542..d1416ef4b 100644 --- a/lua/nvim-treesitter/fold.lua +++ b/lua/nvim-treesitter/fold.lua @@ -1,5 +1,5 @@ local api = vim.api -local utils = require'nvim-treesitter.ts_utils' +local tsutils = require'nvim-treesitter.ts_utils' local query = require'nvim-treesitter.query' local parsers = require'nvim-treesitter.parsers' @@ -7,18 +7,19 @@ local M = {} -- This is cached on buf tick to avoid computing that multiple times -- Especially not for every line in the file when `zx` is hit -local folds_levels = utils.memoize_by_buf_tick(function(bufnr) - local lang = parsers.get_buf_lang(bufnr) +local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr) local max_fold_level = api.nvim_win_get_option(0, 'foldnestmax') + local parser = parsers.get_parser(bufnr) - local matches - if query.has_folds(lang) then - matches = query.get_capture_matches(bufnr, "@fold", "folds") - elseif query.has_locals(lang) then - matches = query.get_capture_matches(bufnr, "@scope", "locals") - else - return {} - end + if not parser then return {} end + + local matches = query.get_capture_matches_recursively(bufnr, function(lang) + if query.has_folds(lang) then + return "@fold", "folds" + elseif query.has_locals(lang) then + return "@scope", "locals" + end + end) local levels_tmp = {} @@ -35,7 +36,6 @@ local folds_levels = utils.memoize_by_buf_tick(function(bufnr) levels_tmp[start] = (levels_tmp[start] or 0) + 1 levels_tmp[stop] = (levels_tmp[stop] or 0) - 1 end - end local levels = {} diff --git a/lua/nvim-treesitter/indent.lua b/lua/nvim-treesitter/indent.lua index cdbf66489..d0e71b4c6 100644 --- a/lua/nvim-treesitter/indent.lua +++ b/lua/nvim-treesitter/indent.lua @@ -1,6 +1,6 @@ local parsers = require'nvim-treesitter.parsers' local queries = require'nvim-treesitter.query' -local utils = require'nvim-treesitter.ts_utils' +local tsutils = require'nvim-treesitter.ts_utils' local M = {} @@ -21,9 +21,9 @@ local function node_fmt(node) return tostring(node) end -local get_indents = utils.memoize_by_buf_tick(function(bufnr) +local get_indents = tsutils.memoize_by_buf_tick(function(bufnr, root, lang) local get_map = function(capture) - local matches = queries.get_capture_matches(bufnr, capture, 'indents') or {} + local matches = queries.get_capture_matches(bufnr, capture, 'indents', root, lang) or {} local map = {} for _, node in ipairs(matches) do map[tostring(node)] = true @@ -37,14 +37,23 @@ local get_indents = utils.memoize_by_buf_tick(function(bufnr) returns = get_map('@return.node'), ignores = get_map('@ignore.node'), } -end) +end, { + -- Memoize by bufnr and lang together. + key = function(bufnr, _, lang) + return tostring(bufnr) .. '_' .. lang + end +}) function M.get_indent(lnum) local parser = parsers.get_parser() if not parser or not lnum then return -1 end - local q = get_indents(vim.api.nvim_get_current_buf()) - local root = parser:parse()[1]:root() + local root, _, lang_tree = tsutils.get_root_for_position(lnum, 0, parser) + + -- Not likely, but just in case... + if not root then return 0 end + + local q = get_indents(vim.api.nvim_get_current_buf(), root, lang_tree:lang()) local node = get_node_at_line(root, lnum-1) local indent = 0 diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua index 01cacc6e5..17982340d 100644 --- a/lua/nvim-treesitter/locals.lua +++ b/lua/nvim-treesitter/locals.lua @@ -3,7 +3,6 @@ -- its the way nvim-treesitter uses to "understand" the code local queries = require'nvim-treesitter.query' -local parsers = require'nvim-treesitter.parsers' local ts_utils = require'nvim-treesitter.ts_utils' local api = vim.api @@ -91,13 +90,12 @@ end --- Iterates over a nodes scopes moving from the bottom up function M.iter_scope_tree(node, bufnr) local last_node = node - return function() if not last_node then return end - local scope = M.containing_scope(last_node, bufnr, false) or parsers.get_tree_root(bufnr) + local scope = M.containing_scope(last_node, bufnr, false) or ts_utils.get_root_for_node(node) last_node = scope:parent() @@ -222,7 +220,7 @@ function M.find_definition(node, bufnr) end end - return node, parsers.get_tree_root(bufnr), nil + return node, ts_utils.get_root_for_node(node), nil end -- Finds usages of a node in a given scope. @@ -235,7 +233,7 @@ function M.find_usages(node, scope_node, bufnr) if not node_text or #node_text < 1 then return {} end - local scope_node = scope_node or parsers.get_parser(bufnr):parse()[1]:root() + local scope_node = scope_node or ts_utils.get_root_for_node(node) local usages = {} for match in M.iter_locals(bufnr, scope_node) do diff --git a/lua/nvim-treesitter/parsers.lua b/lua/nvim-treesitter/parsers.lua index 9bac6a8f3..a9d779ef2 100644 --- a/lua/nvim-treesitter/parsers.lua +++ b/lua/nvim-treesitter/parsers.lua @@ -584,6 +584,8 @@ function M.get_parser(bufnr, lang) end end +-- @deprecated This is only kept for legacy purposes. +-- All root nodes should be accounted for. function M.get_tree_root(bufnr) local bufnr = bufnr or api.nvim_get_current_buf() diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index df3e70805..e6683139e 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -7,6 +7,8 @@ local caching = require'nvim-treesitter.caching' local M = {} +local EMPTY_ITER = function() end + M.built_in_query_groups = {'highlights', 'locals', 'folds', 'indents'} -- Creates a function that checks whether a given query exists @@ -166,7 +168,7 @@ end --- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type) -- Works like M.get_references or M.get_scopes except you can choose the capture -- Can also be a nested capture like @definition.function to get all nodes defining a function -function M.get_capture_matches(bufnr, capture_string, query_group) +function M.get_capture_matches(bufnr, capture_string, query_group, root, lang) if not string.sub(capture_string, 1,2) == '@' then print('capture_string must start with "@"') return @@ -176,7 +178,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group) capture_string = string.sub(capture_string, 2) local matches = {} - for match in M.iter_group_results(bufnr, query_group) do + for match in M.iter_group_results(bufnr, query_group, root, lang) do local insert = utils.get_at_path(match, capture_string) if insert then @@ -186,7 +188,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group) return matches end -function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function) +function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root) if not string.sub(capture_string, 1,2) == '@' then api.nvim_err_writeln('capture_string must start with "@"') return @@ -198,7 +200,7 @@ function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, local best local best_score - for maybe_match in M.iter_group_results(bufnr, query_group) do + for maybe_match in M.iter_group_results(bufnr, query_group, root) do local match = utils.get_at_path(maybe_match, capture_string) if match and filter_predicate(match) then @@ -220,31 +222,82 @@ end -- @param bufnr the buffer -- @param query_group the query file to use -- @param root the root node -function M.iter_group_results(bufnr, query_group, root) - local lang = parsers.get_buf_lang(bufnr) - if not lang then return function() end end +-- @param root the root node lang, if known +function M.iter_group_results(bufnr, query_group, root, root_lang) + local buf_lang = parsers.get_buf_lang(bufnr) + + if not buf_lang then return EMPTY_ITER end + + local parser = parsers.get_parser(bufnr, buf_lang) + if not parser then return EMPTY_ITER end - local query = M.get_query(lang, query_group) - if not query then return function() end end + if not root then + local first_tree = parser:trees()[1] + + if first_tree then + root = first_tree:root() + end + end + + if not root then return EMPTY_ITER end + + local range = {root:range()} + + if not root_lang then + local lang_tree = parser:language_for_range(range) + + if lang_tree then + root_lang = lang_tree:lang() + end + end - local parser = parsers.get_parser(bufnr, lang) - if not parser then return function() end end + if not root_lang then return EMPTY_ITER end - local root = root or parser:parse()[1]:root() - local start_row, _, end_row, _ = root:range() + local query = M.get_query(root_lang, query_group) + if not query then return EMPTY_ITER end -- The end row is exclusive so we need to add 1 to it. - return M.iter_prepared_matches(query, root, bufnr, start_row, end_row + 1) + return M.iter_prepared_matches(query, root, bufnr, range[1], range[3] + 1) end -function M.collect_group_results(bufnr, query_group, root) +function M.collect_group_results(bufnr, query_group, root, lang) local matches = {} - for prepared_match in M.iter_group_results(bufnr, query_group, root) do + for prepared_match in M.iter_group_results(bufnr, query_group, root, lang) do table.insert(matches, prepared_match) end return matches end +--- Same as get_capture_matches except this will recursively get matches for every language in the tree. +-- @param bufnr The bufnr +-- @param capture_or_fn The capture to get. If a function is provided then that +-- function will be used to resolve both the capture and query argument. +-- The function can return `nil` to ignore that tree. +-- @param query_type The query to get the capture from. This is ignore if a function is provided +-- for the captuer argument. +function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type) + local type_fn = type(capture_or_fn) == 'function' + and capture_or_fn + or function() + return capture_or_fn, query_type + end + local parser = parsers.get_parser(bufnr) + local matches = {} + + if parser then + parser:for_each_tree(function(tree, lang_tree) + local lang = lang_tree:lang() + local capture, type_ = type_fn(lang, tree, lang_tree) + + if capture then + vim.list_extend(matches, M.get_capture_matches(bufnr, capture, type_, tree:root(), lang)) + end + end) + end + + return matches +end + return M diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua index c5475d829..1f7f20933 100644 --- a/lua/nvim-treesitter/ts_utils.lua +++ b/lua/nvim-treesitter/ts_utils.lua @@ -114,10 +114,46 @@ function M.get_named_children(node) end function M.get_node_at_cursor(winnr) - if not parsers.has_parser() then return end local cursor = api.nvim_win_get_cursor(winnr or 0) - local root = parsers.get_parser():parse()[1]:root() - return root:named_descendant_for_range(cursor[1]-1,cursor[2],cursor[1]-1,cursor[2]) + local cursor_range = { cursor[1] - 1, cursor[2] } + local root = M.get_root_for_position(unpack(cursor_range)) + + if not root then return end + + return root:named_descendant_for_range(cursor_range[1], cursor_range[2], cursor_range[1], cursor_range[2]) +end + +function M.get_root_for_position(line, col, root_lang_tree) + if not root_lang_tree then + if not parsers.has_parser() then return end + + root_lang_tree = parsers.get_parser() + end + + local lang_tree = root_lang_tree:language_for_range({ line, col, line, col }) + + for _, tree in ipairs(lang_tree:trees()) do + local root = tree:root() + + if root and M.is_in_node_range(root, line, col) then + return root, tree, lang_tree + end + end + + -- This isn't a likely scenario, since the position must belong to a tree somewhere. + return nil, nil, lang_tree +end + +function M.get_root_for_node(node) + local parent = node + local result = node + + while parent ~= nil do + result = parent + parent = result:parent() + end + + return result end function M.highlight_node(node, buf, hl_namespace, hl_group) @@ -213,25 +249,44 @@ end --- Memoizes a function based on the buffer tick of the provided bufnr. -- The cache entry is cleared when the buffer is detached to avoid memory leaks. -- @param fn: the fn to memoize, taking the bufnr as first argument +-- @param options: +-- - bufnr: extracts a bufnr from the given arguments. +-- - key: extracts the cache key from the given arguments. -- @returns a memoized function -function M.memoize_by_buf_tick(fn) +function M.memoize_by_buf_tick(fn, options) + options = options or {} + local cache = {} + local bufnr_fn = utils.to_func(options.bufnr or utils.identity) + local key_fn = utils.to_func(options.key or utils.identity) - return function(bufnr) - if cache[bufnr] then - return cache[bufnr] + return function(...) + local bufnr = bufnr_fn(...) + local key = key_fn(...) + local tick = api.nvim_buf_get_changedtick(bufnr) + + if cache[key] then + if cache[key].last_tick == tick then + return cache[key].result + end else - cache[bufnr] = {} - api.nvim_buf_attach(bufnr, false, - { - on_changedtick = function() cache[bufnr] = fn(bufnr) end, - on_detach = function() cache[bufnr] = nil end - } - ) + local function detach_handler() + cache[key] = nil + end + + -- Clean up logic only! + api.nvim_buf_attach(bufnr, false, { + on_detach = detach_handler, + on_reload = detach_handler + }) end - cache[bufnr] = fn(bufnr) - return cache[bufnr] + cache[key] = { + result = fn(...), + last_tick = tick + } + + return cache[key].result end end diff --git a/lua/nvim-treesitter/tsrange.lua b/lua/nvim-treesitter/tsrange.lua index 309d21b85..a8ce26f31 100644 --- a/lua/nvim-treesitter/tsrange.lua +++ b/lua/nvim-treesitter/tsrange.lua @@ -4,6 +4,7 @@ TSRange.__index = TSRange local api = vim.api local parsers = require'nvim-treesitter.parsers' +local ts_utils = require'nvim-treesitter.ts_utils' local function get_byte_offset(buf, row, col) return api.nvim_buf_get_offset(buf, row) @@ -57,8 +58,11 @@ end function TSRange:parent(range) local parser = parsers.get_parser(self.buf, parsers.get_buf_lang(range)) - local root = parser:parse()[1]:root() - return root:named_descendant_for_range(self.start_pos[1], self.start_pos[2], self.end_pos[1], self.end_pos[2]) + local root = ts_utils.get_root_for_position(range[1], range[2], parser) + + return root + and root:named_descendant_for_range(self.start_pos[1], self.start_pos[2], self.end_pos[1], self.end_pos[2]) + or nil end function TSRange:field() diff --git a/lua/nvim-treesitter/utils.lua b/lua/nvim-treesitter/utils.lua index f82c2d4a7..ab8498b0e 100644 --- a/lua/nvim-treesitter/utils.lua +++ b/lua/nvim-treesitter/utils.lua @@ -166,4 +166,16 @@ function M.difference(tbl1, tbl2) end) end +function M.identity(a) + return a +end + +function M.constant(a) + return function() return a end +end + +function M.to_func(a) + return type(a) == 'function' and a or M.constant(a) +end + return M |
