From 282e33ad9c96a44a092ab8a356bba7626c838b68 Mon Sep 17 00:00:00 2001 From: Steven Sojka Date: Wed, 12 Aug 2020 07:38:15 -0500 Subject: fix(definitions): optimize and fix definition highlighting --- lua/nvim-treesitter/locals.lua | 117 +++++++++++++-------- lua/nvim-treesitter/query.lua | 4 +- .../refactor/highlight_definitions.lua | 2 +- lua/nvim-treesitter/ts_utils.lua | 30 ++++++ 4 files changed, 109 insertions(+), 44 deletions(-) (limited to 'lua') diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua index f43a8e24b..18d355c35 100644 --- a/lua/nvim-treesitter/locals.lua +++ b/lua/nvim-treesitter/locals.lua @@ -24,6 +24,16 @@ function M.get_locals(bufnr) return queries.get_matches(bufnr, 'locals') end +--- Creates unique id for a node based on text and range +-- @param scope: the scope node of the definition +-- @param bufnr: the buffer +-- @param node_text: the node text to use +-- @returns a string id +function M.get_definition_id(scope, node_text) + -- Add a vaild starting character in case node text doesn't start with a valid one. + return table.concat({ 'k', node_text, scope:range() }, '_') +end + function M.get_definitions(bufnr) local locals = M.get_locals(bufnr) @@ -69,56 +79,30 @@ end --- Gets a table with all the scopes containing a node -- The order is from most specific to least (bottom up) function M.get_scope_tree(node, bufnr) - local current_scope = M.containing_scope(node, bufnr, false) or parsers.get_tree_root(bufnr) local scopes = {} - while current_scope do - table.insert(scopes, current_scope) - current_scope = current_scope:parent() - and (M.containing_scope(current_scope:parent(), bufnr, false) or parsers.get_tree_root(bufnr)) - or nil + for scope in M.iter_scope_tree(node, bufnr) do + table.insert(scopes, scope) end return scopes end --- Finds the definition node and it's scope node of a node --- @param node starting node --- @param bufnr buffer --- @returns the definition node and the definition nodes scope node -function M.find_definition(node, bufnr) - local bufnr = bufnr or api.nvim_get_current_buf() - local node_text = ts_utils.get_node_text(node, bufnr)[1] - local scope_tree = M.get_scope_tree(node, bufnr) - local match - local last_scope_index +--- Iterates over a nodes scopes moving from the bottom up +function M.iter_scope_tree(node, bufnr) + local last_node = node - -- Loop over every definition - for _, definition in ipairs(M.get_definitions(bufnr)) do - for _, node_entry in ipairs(M.get_local_nodes(definition)) do - local def_scope = M.containing_scope(node_entry.node, bufnr, false) or parsers.get_tree_root(bufnr) - - -- Only match definitions that match the text of the node - -- Look for the most specific definition in the tree - -- The lower the index, the more specific the definition is - if ts_utils.get_node_text(node_entry.node, bufnr)[1] == node_text then - for i, scope_node in ipairs(scope_tree) do - -- If we already found a close definition in scope, just skip checking - if last_scope_index and i >= last_scope_index then break end - if scope_node == def_scope then - last_scope_index = i - match = node_entry - end - end - end + return function() + if not last_node then + return end - end - if match and last_scope_index then - return match.node, scope_tree[last_scope_index], match.kind - end + local scope = M.containing_scope(last_node, bufnr, false) or parsers.get_tree_root(bufnr) + + last_node = scope:parent() - return node, parsers.get_parser(bufnr).tree:root(), nil + return scope + end end -- Gets a table of all nodes and their 'kinds' from a locals list @@ -158,7 +142,54 @@ function M.recurse_local_nodes(local_def, accumulator, full_match, last_match) end end --- Finds usages of a node in a given scope +--- Get a single dimension table to look definition nodes. +-- Keys are generated by using the range of the containing scope and the text of the definition node. +-- This makes looking up a definition for a given scope a simple key lookup. +-- +-- This is memoized by buffer tick. If the function is called in succession +-- without the buffer tick changing, then the previous result will be used +-- since the syntax tree hasn't changed. +-- +-- Usage lookups require finding the definition of the node, so `find_definition` +-- is called very frequently, which is why this lookup must be fast as possible. +-- +-- @param bufnr: the buffer +-- @returns a table for looking up definitions +M.get_definitions_lookup_table = ts_utils.memoize_by_buf_tick(function(bufnr) + local definitions = M.get_definitions(bufnr) + local result = {} + + for _, definition in ipairs(definitions) do + for _, node_entry in ipairs(M.get_local_nodes(definition)) do + local scope = M.containing_scope(node_entry.node, bufnr, false) or parsers.get_tree_root(bufnr) + local node_text = ts_utils.get_node_text(node_entry.node, bufnr)[1] + local id = M.get_definition_id(scope, node_text) + + result[id] = node_entry + end + end + + return result +end) + +function M.find_definition(node, bufnr) + local def_lookup = M.get_definitions_lookup_table(bufnr) + local node_text = ts_utils.get_node_text(node, bufnr)[1] + + for scope in M.iter_scope_tree(node, bufnr) do + local id = M.get_definition_id(scope, node_text) + + if def_lookup[id] then + local entry = def_lookup[id] + + return entry.node, scope, entry.kind + end + end + + return node, parsers.get_tree_root(bufnr), nil +end + +-- Finds usages of a node in a given scope. -- @param node the node to find usages for -- @param scope_node the node to look within -- @returns a list of nodes @@ -176,7 +207,11 @@ function M.find_usages(node, scope_node, bufnr) and match.reference.node and ts_utils.get_node_text(match.reference.node, bufnr)[1] == node_text then - table.insert(usages, match.reference.node) + local def_node, _, kind = M.find_definition(match.reference.node, bufnr) + + if kind == nil or def_node == node then + table.insert(usages, match.reference.node) + end end end diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index cd2495af1..33637be3f 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -203,7 +203,6 @@ end -- @param query_group the query file to use -- @param root the root node function M.iter_group_results(bufnr, query_group, root) - local lang = parsers.get_buf_lang(bufnr) if not lang then return end @@ -216,7 +215,8 @@ function M.iter_group_results(bufnr, query_group, root) local root = root or parser:parse():root() local start_row, _, end_row, _ = root:range() - return M.iter_prepared_matches(query, root, bufnr, start_row, end_row) + -- The end row is exclusive so we need to add 1 to it. + return M.iter_prepared_matches(query, root, bufnr, start_row, end_row + 1) end function M.collect_group_results(bufnr, query_group, root) diff --git a/lua/nvim-treesitter/refactor/highlight_definitions.lua b/lua/nvim-treesitter/refactor/highlight_definitions.lua index 710814427..8a75980d6 100644 --- a/lua/nvim-treesitter/refactor/highlight_definitions.lua +++ b/lua/nvim-treesitter/refactor/highlight_definitions.lua @@ -21,7 +21,7 @@ function M.highlight_usages(bufnr) end local def_node, scope = locals.find_definition(node_at_point, bufnr) - local usages = locals.find_usages(node_at_point, scope) + local usages = locals.find_usages(def_node, scope, bufnr) for _, usage_node in ipairs(usages) do if usage_node ~= node_at_point then diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua index bf5b22081..0a5cbc608 100644 --- a/lua/nvim-treesitter/ts_utils.lua +++ b/lua/nvim-treesitter/ts_utils.lua @@ -195,4 +195,34 @@ function M.node_to_lsp_range(node) return rtn end +--- Memoizes a function based on the buffer tick of the provided bufnr. +-- The cache entry is cleared when the buffer is detached to avoid memory leaks. +-- @param fn: the fn to memoize +-- @param bufnr_fn: a function that receives all arguments passed to the function +-- and returns the bufnr from the arguments +-- @returns a memoized function +function M.memoize_by_buf_tick(fn, bufnr_fn) + local bufnr_fn = bufnr_fn or function(a) return a end + local cache = {} + + return function(...) + local bufnr = bufnr_fn(...) + local tick = api.nvim_buf_get_changedtick(bufnr) + + if cache[bufnr] then + if cache[bufnr].last_tick == tick then + return cache[bufnr].result + end + else + cache[bufnr] = {} + api.nvim_buf_attach(bufnr, false, { on_detach = function() cache[bufnr] = nil end }) + end + + cache[bufnr].last_tick = tick + cache[bufnr].result = fn(...) + + return cache[bufnr].result + end +end + return M -- cgit v1.2.3-70-g09d2