diff options
| author | Christian Clason <c.clason@uni-graz.at> | 2023-05-20 17:29:03 +0200 |
|---|---|---|
| committer | Christian Clason <c.clason@uni-graz.at> | 2025-05-12 18:43:40 +0200 |
| commit | 5817ff01b523a0dce5e7a42374ac137cb6490577 (patch) | |
| tree | bf2bcb2d3add6285bb0b6a4c72065d6ce20d0718 /lua | |
| parent | feat!: drop modules, general refactor and cleanup (diff) | |
| download | nvim-treesitter-5817ff01b523a0dce5e7a42374ac137cb6490577.tar nvim-treesitter-5817ff01b523a0dce5e7a42374ac137cb6490577.tar.gz nvim-treesitter-5817ff01b523a0dce5e7a42374ac137cb6490577.tar.bz2 nvim-treesitter-5817ff01b523a0dce5e7a42374ac137cb6490577.tar.lz nvim-treesitter-5817ff01b523a0dce5e7a42374ac137cb6490577.tar.xz nvim-treesitter-5817ff01b523a0dce5e7a42374ac137cb6490577.tar.zst nvim-treesitter-5817ff01b523a0dce5e7a42374ac137cb6490577.zip | |
feat(locals)!: refactor `locals.lua` into standalone
Co-authored-by: TheLeoP <eugenio2305@hotmail.com>
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/nvim-treesitter/locals.lua | 218 | ||||
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 165 |
2 files changed, 133 insertions, 250 deletions
diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua index fa8496469..c8ebe0b0d 100644 --- a/lua/nvim-treesitter/locals.lua +++ b/lua/nvim-treesitter/locals.lua @@ -2,7 +2,6 @@ -- Locals are a generalization of definition and scopes -- it's the way nvim-treesitter uses to "understand" the code -local query = require('nvim-treesitter.query') local api = vim.api local ts = vim.treesitter @@ -30,19 +29,6 @@ local function get_root_for_node(node) return result end --- Iterates matches from a locals query file. --- @param bufnr the buffer --- @param root the root node -function M.iter_locals(bufnr, root) - return query.iter_group_results(bufnr, 'locals', root) -end - ----@param bufnr integer ----@return any -function M.collect_locals(bufnr) - return query.collect_group_results(bufnr, 'locals') -end - -- Creates unique id for a node based on text and range ---@param scope TSNode: the scope node of the definition ---@param node_text string: the node text to use @@ -52,48 +38,6 @@ function M.get_definition_id(scope, node_text) return table.concat({ 'k', node_text or '', scope:range() }, '_') end -function M.get_definitions(bufnr) - local locals = M.collect_locals(bufnr) - - local defs = {} - - for _, loc in ipairs(locals) do - if loc['local.definition'] then - table.insert(defs, loc['local.definition']) - end - end - - return defs -end - -function M.get_scopes(bufnr) - local locals = M.collect_locals(bufnr) - - local scopes = {} - - for _, loc in ipairs(locals) do - if loc['local.scope'] and loc['local.scope'].node then - table.insert(scopes, loc['local.scope'].node) - end - end - - return scopes -end - -function M.get_references(bufnr) - local locals = M.collect_locals(bufnr) - - local refs = {} - - for _, loc in ipairs(locals) do - if loc['local.reference'] and loc['local.reference'].node then - table.insert(refs, loc['local.reference'].node) - end - end - - return refs -end - -- Gets a table with all the scopes containing a node -- The order is from most specific to least (bottom up) ---@param node TSNode @@ -129,8 +73,8 @@ function M.iter_scope_tree(node, bufnr) end -- Gets a table of all nodes and their 'kinds' from a locals list ----@param local_def any: the local list result ----@return table: a list of node entries +---@param local_def TSLocal[] the local list result +---@return TSLocal[] a list of node entries function M.get_local_nodes(local_def) local result = {} @@ -145,9 +89,9 @@ end -- The accumulator function is given -- * The table of the node -- * The node --- * The full definition match `@definition.var.something` -> 'var.something' --- * The last definition match `@definition.var.something` -> 'something' ----@param local_def any The locals result +-- * The full definition match `@local.definition.var.something` -> 'var.something' +-- * The last definition match `@local.definition.var.something` -> 'something' +---@param local_def TSLocal The locals result ---@param accumulator function The accumulator function ---@param full_match? string The full match path to append to ---@param last_match? string The last match @@ -181,14 +125,108 @@ local function memoize(fn, hash_fn) return function(...) local key = hash_fn(...) if cache[key] == nil then - local v = fn(...) ---@type any - cache[key] = v ~= nil and v or vim.NIL + local v = { fn(...) } ---@type any + + for k, value in pairs(v) do + if value == nil then + value[k] = vim.NIL + end + end + + cache[key] = v end local v = cache[key] - return v ~= vim.NIL and v or nil + + for k, value in pairs(v) do + if value == vim.NIL then + value[k] = nil + end + end + + return unpack(v) end end +---@param bufnr integer: the buffer +---@return TSNode|nil root: root node of the buffer +local function get_root(bufnr) + local parser = ts.get_parser(bufnr) + if not parser then + return + end + parser:parse() + return parser:trees()[1]:root() +end + +---@param bufnr integer: the buffer +---@return Query|nil query: `locals` query +---@return TSNode|nil root: root node of the bufferocal function get_query(bufnr) +local function get_query(bufnr) + local root = get_root(bufnr) + + local ft = vim.bo[bufnr].filetype + local lang = ts.language.get_lang(ft) or ft + + local query = (ts.query.get(lang, 'locals')) + + return query, root +end + +---@alias TSScope "parent"|"local"|"global" + +---@class TSLocal +---@field kind string +---@field node TSNode +---@field scope TSScope + +-- Return all locals for the buffer +-- +-- memoized by buffer tick +-- +---@param bufnr integer buffer +---@return TSLocal[] definitions +---@return TSLocal[] references +---@return TSNode[] scopes +M.get = memoize(function(bufnr) + local query, root = get_query(bufnr) + if not query or not root then + return {}, {}, {} + end + + local definitions = {} + local scopes = {} + local references = {} + for id, node, metadata in query:iter_captures(root, bufnr) do + local kind = query.captures[id] + + local scope = 'local' ---@type string + for k, v in pairs(metadata) do + if type(k) == 'string' and vim.endswith(k, 'local.scope') then + scope = v + end + end + + if node and vim.startswith(kind, 'local.definition') then + table.insert(definitions, { kind = kind, node = node, scope = scope }) + end + + if node and kind == 'local.scope' then + table.insert(scopes, node) + end + + if node and kind == 'local.reference' then + table.insert(references, { kind = kind, node = node, scope = scope }) + end + end + + return definitions, references, scopes +end, function(bufnr) + local root = get_root(bufnr) + if not root then + return tostring(bufnr) + end + return tostring(root:id()) +end) -- Get a single dimension table to look definition nodes. -- Keys are generated by using the range of the containing scope and the text of the definition node. @@ -202,11 +240,14 @@ end -- is called very frequently, which is why this lookup must be fast as possible. -- ---@param bufnr integer: the buffer ----@return table result: a table for looking up definitions +---@return TSLocal[] result: a table for looking up definitions M.get_definitions_lookup_table = memoize(function(bufnr) - local definitions = M.get_definitions(bufnr) - local result = {} + local definitions, _, _ = M.get(bufnr) + if not definitions then + return {} + end + local result = {} for _, definition in ipairs(definitions) do for _, node_entry in ipairs(M.get_local_nodes(definition)) do local scopes = M.get_definition_scopes(node_entry.node, bufnr, node_entry.scope) @@ -221,7 +262,11 @@ M.get_definitions_lookup_table = memoize(function(bufnr) return result end, function(bufnr) - return tostring(bufnr) + local root = get_root(bufnr) + if not root then + return tostring(bufnr) + end + return tostring(root:id()) end) -- Gets all the scopes of a definition based on the scope type @@ -233,7 +278,7 @@ end) -- ---@param node TSNode: the definition node ---@param bufnr integer: the buffer ----@param scope_type string: the scope type +---@param scope_type TSScope: the scope type function M.get_definition_scopes(node, bufnr, scope_type) local scopes = {} local scope_count = 1 ---@type integer|nil @@ -248,8 +293,8 @@ function M.get_definition_scopes(node, bufnr, scope_type) end local i = 0 - for scope in M.iter_scope_tree(node, bufnr) do - table.insert(scopes, scope) + for scope_node in M.iter_scope_tree(node, bufnr) do + table.insert(scopes, scope_node) i = i + 1 if scope_count and i >= scope_count then @@ -284,7 +329,8 @@ end -- Finds usages of a node in a given scope. ---@param node TSNode the node to find usages for ----@param scope_node TSNode the node to look within +---@param scope_node TSNode|nil the node to look within +---@param bufnr integer|nil the bufnr to look into ---@return TSNode[]: a list of nodes function M.find_usages(node, scope_node, bufnr) bufnr = bufnr or api.nvim_get_current_buf() @@ -297,17 +343,19 @@ function M.find_usages(node, scope_node, bufnr) scope_node = scope_node or get_root_for_node(node) local usages = {} - for match in M.iter_locals(bufnr, scope_node) do + local query, _ = get_query(bufnr) + if not query then + return {} + end + + for id, node_capture in query:iter_captures(scope_node, bufnr) do + local kind = query.captures[id] if - match.reference - and match.reference.node - and ts.get_node_text(match.reference.node, bufnr) == node_text + node_capture + and kind == 'local.reference' + and ts.get_node_text(node_capture, bufnr) == node_text then - local def_node, _, kind = M.find_definition(match.reference.node, bufnr) - - if kind == nil or def_node == node then - table.insert(usages, match.reference.node) - end + table.insert(usages, node_capture) end end @@ -322,7 +370,7 @@ function M.containing_scope(node, bufnr, allow_scope) bufnr = bufnr or api.nvim_get_current_buf() allow_scope = allow_scope == nil or allow_scope == true - local scopes = M.get_scopes(bufnr) + local _, _, scopes = M.get(bufnr) if not node or not scopes then return end @@ -339,7 +387,7 @@ end function M.nested_scope(node, cursor_pos) local bufnr = api.nvim_get_current_buf() - local scopes = M.get_scopes(bufnr) + local _, _, scopes = M.get(bufnr) if not node or not scopes then return end @@ -359,7 +407,7 @@ end function M.next_scope(node) local bufnr = api.nvim_get_current_buf() - local scopes = M.get_scopes(bufnr) + local _, _, scopes = M.get(bufnr) if not node or not scopes then return end @@ -389,7 +437,7 @@ end function M.previous_scope(node) local bufnr = api.nvim_get_current_buf() - local scopes = M.get_scopes(bufnr) + local _, _, scopes = M.get(bufnr) if not node or not scopes then return end diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua deleted file mode 100644 index f7e5205f7..000000000 --- a/lua/nvim-treesitter/query.lua +++ /dev/null @@ -1,165 +0,0 @@ -local M = {} - -local EMPTY_ITER = function() end - ----@class QueryInfo ----@field root TSNode ----@field source integer ----@field start integer ----@field stop integer - ----@param bufnr integer ----@param query_name string ----@param root TSNode ----@param root_lang string|nil ----@return Query|nil, QueryInfo|nil -local function prepare_query(bufnr, query_name, root, root_lang) - local ft = vim.bo[bufnr].filetype - local buf_lang = vim.treesitter.language.get_lang(ft) or ft - if not buf_lang then - return - end - - local parser = vim.treesitter.get_parser(bufnr, buf_lang) - if not parser then - return - end - - if not root then - local first_tree = parser:trees()[1] - - if first_tree then - root = first_tree:root() - end - end - - if not root then - return - end - - local range = { root:range() } - - if not root_lang then - local lang_tree = parser:language_for_range(range) - - if lang_tree then - root_lang = lang_tree:lang() - end - end - - if not root_lang then - return - end - - local query = vim.treesitter.query.get(root_lang, query_name) - if not query then - return - end - - return query, - { - root = root, - source = bufnr, - start = range[1], - -- The end row is exclusive so we need to add 1 to it. - stop = range[3] + 1, - } -end - --- Given a path (i.e. a List(String)) this functions inserts value at path ----@param object any ----@param path string[] ----@param value any -function M.insert_to_path(object, path, value) - local curr_obj = object - - for index = 1, (#path - 1) do - if curr_obj[path[index]] == nil then - curr_obj[path[index]] = {} - end - - curr_obj = curr_obj[path[index]] - end - - curr_obj[path[#path]] = value -end - ----@param query Query ----@param bufnr integer ----@param start_row integer ----@param end_row integer -function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) - -- A function that splits a string on '.' - ---@param to_split string - ---@return string[] - local function split(to_split) - local t = {} - for str in string.gmatch(to_split, '([^.]+)') do - table.insert(t, str) - end - - return t - end - - local matches = query:iter_matches(qnode, bufnr, start_row, end_row, { all = false }) - - local function iterator() - local pattern, match, metadata = matches() - if pattern ~= nil then - local prepared_match = {} - - -- Extract capture names from each match - for id, node in pairs(match) do - local name = query.captures[id] -- name of the capture in the query - if name ~= nil then - local path = split(name .. '.node') - M.insert_to_path(prepared_match, path, node) - local metadata_path = split(name .. '.metadata') - M.insert_to_path(prepared_match, metadata_path, metadata[id]) - end - end - - -- Add some predicates for testing - ---@type string[][] ( TODO: make pred type so this can be pred[]) - local preds = query.info.patterns[pattern] - if preds then - for _, pred in pairs(preds) do - -- functions - if pred[1] == 'set!' and type(pred[2]) == 'string' then - M.insert_to_path(prepared_match, split(pred[2]), pred[3]) - end - end - end - - return prepared_match - end - end - return iterator -end - ----Iterates matches from a query file. ----@param bufnr integer the buffer ----@param query_group string the query file to use ----@param root TSNode the root node ----@param root_lang? string the root node lang, if known -function M.iter_group_results(bufnr, query_group, root, root_lang) - local query, params = prepare_query(bufnr, query_group, root, root_lang) - if not query then - return EMPTY_ITER - end - assert(params) - - return M.iter_prepared_matches(query, params.root, params.source, params.start, params.stop) -end - -function M.collect_group_results(bufnr, query_group, root, lang) - local matches = {} - - for prepared_match in M.iter_group_results(bufnr, query_group, root, lang) do - table.insert(matches, prepared_match) - end - - return matches -end - -return M |
