aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/locals.lua218
-rw-r--r--lua/nvim-treesitter/query.lua165
2 files changed, 133 insertions, 250 deletions
diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua
index fa8496469..c8ebe0b0d 100644
--- a/lua/nvim-treesitter/locals.lua
+++ b/lua/nvim-treesitter/locals.lua
@@ -2,7 +2,6 @@
-- Locals are a generalization of definition and scopes
-- it's the way nvim-treesitter uses to "understand" the code
-local query = require('nvim-treesitter.query')
local api = vim.api
local ts = vim.treesitter
@@ -30,19 +29,6 @@ local function get_root_for_node(node)
return result
end
--- Iterates matches from a locals query file.
--- @param bufnr the buffer
--- @param root the root node
-function M.iter_locals(bufnr, root)
- return query.iter_group_results(bufnr, 'locals', root)
-end
-
----@param bufnr integer
----@return any
-function M.collect_locals(bufnr)
- return query.collect_group_results(bufnr, 'locals')
-end
-
-- Creates unique id for a node based on text and range
---@param scope TSNode: the scope node of the definition
---@param node_text string: the node text to use
@@ -52,48 +38,6 @@ function M.get_definition_id(scope, node_text)
return table.concat({ 'k', node_text or '', scope:range() }, '_')
end
-function M.get_definitions(bufnr)
- local locals = M.collect_locals(bufnr)
-
- local defs = {}
-
- for _, loc in ipairs(locals) do
- if loc['local.definition'] then
- table.insert(defs, loc['local.definition'])
- end
- end
-
- return defs
-end
-
-function M.get_scopes(bufnr)
- local locals = M.collect_locals(bufnr)
-
- local scopes = {}
-
- for _, loc in ipairs(locals) do
- if loc['local.scope'] and loc['local.scope'].node then
- table.insert(scopes, loc['local.scope'].node)
- end
- end
-
- return scopes
-end
-
-function M.get_references(bufnr)
- local locals = M.collect_locals(bufnr)
-
- local refs = {}
-
- for _, loc in ipairs(locals) do
- if loc['local.reference'] and loc['local.reference'].node then
- table.insert(refs, loc['local.reference'].node)
- end
- end
-
- return refs
-end
-
-- Gets a table with all the scopes containing a node
-- The order is from most specific to least (bottom up)
---@param node TSNode
@@ -129,8 +73,8 @@ function M.iter_scope_tree(node, bufnr)
end
-- Gets a table of all nodes and their 'kinds' from a locals list
----@param local_def any: the local list result
----@return table: a list of node entries
+---@param local_def TSLocal[] the local list result
+---@return TSLocal[] a list of node entries
function M.get_local_nodes(local_def)
local result = {}
@@ -145,9 +89,9 @@ end
-- The accumulator function is given
-- * The table of the node
-- * The node
--- * The full definition match `@definition.var.something` -> 'var.something'
--- * The last definition match `@definition.var.something` -> 'something'
----@param local_def any The locals result
+-- * The full definition match `@local.definition.var.something` -> 'var.something'
+-- * The last definition match `@local.definition.var.something` -> 'something'
+---@param local_def TSLocal The locals result
---@param accumulator function The accumulator function
---@param full_match? string The full match path to append to
---@param last_match? string The last match
@@ -181,14 +125,108 @@ local function memoize(fn, hash_fn)
return function(...)
local key = hash_fn(...)
if cache[key] == nil then
- local v = fn(...) ---@type any
- cache[key] = v ~= nil and v or vim.NIL
+ local v = { fn(...) } ---@type any
+
+ for k, value in pairs(v) do
+ if value == nil then
+ value[k] = vim.NIL
+ end
+ end
+
+ cache[key] = v
end
local v = cache[key]
- return v ~= vim.NIL and v or nil
+
+ for k, value in pairs(v) do
+ if value == vim.NIL then
+ value[k] = nil
+ end
+ end
+
+ return unpack(v)
end
end
+---@param bufnr integer: the buffer
+---@return TSNode|nil root: root node of the buffer
+local function get_root(bufnr)
+ local parser = ts.get_parser(bufnr)
+ if not parser then
+ return
+ end
+ parser:parse()
+ return parser:trees()[1]:root()
+end
+
+---@param bufnr integer: the buffer
+---@return Query|nil query: `locals` query
+---@return TSNode|nil root: root node of the bufferocal function get_query(bufnr)
+local function get_query(bufnr)
+ local root = get_root(bufnr)
+
+ local ft = vim.bo[bufnr].filetype
+ local lang = ts.language.get_lang(ft) or ft
+
+ local query = (ts.query.get(lang, 'locals'))
+
+ return query, root
+end
+
+---@alias TSScope "parent"|"local"|"global"
+
+---@class TSLocal
+---@field kind string
+---@field node TSNode
+---@field scope TSScope
+
+-- Return all locals for the buffer
+--
+-- memoized by buffer tick
+--
+---@param bufnr integer buffer
+---@return TSLocal[] definitions
+---@return TSLocal[] references
+---@return TSNode[] scopes
+M.get = memoize(function(bufnr)
+ local query, root = get_query(bufnr)
+ if not query or not root then
+ return {}, {}, {}
+ end
+
+ local definitions = {}
+ local scopes = {}
+ local references = {}
+ for id, node, metadata in query:iter_captures(root, bufnr) do
+ local kind = query.captures[id]
+
+ local scope = 'local' ---@type string
+ for k, v in pairs(metadata) do
+ if type(k) == 'string' and vim.endswith(k, 'local.scope') then
+ scope = v
+ end
+ end
+
+ if node and vim.startswith(kind, 'local.definition') then
+ table.insert(definitions, { kind = kind, node = node, scope = scope })
+ end
+
+ if node and kind == 'local.scope' then
+ table.insert(scopes, node)
+ end
+
+ if node and kind == 'local.reference' then
+ table.insert(references, { kind = kind, node = node, scope = scope })
+ end
+ end
+
+ return definitions, references, scopes
+end, function(bufnr)
+ local root = get_root(bufnr)
+ if not root then
+ return tostring(bufnr)
+ end
+ return tostring(root:id())
+end)
-- Get a single dimension table to look definition nodes.
-- Keys are generated by using the range of the containing scope and the text of the definition node.
@@ -202,11 +240,14 @@ end
-- is called very frequently, which is why this lookup must be fast as possible.
--
---@param bufnr integer: the buffer
----@return table result: a table for looking up definitions
+---@return TSLocal[] result: a table for looking up definitions
M.get_definitions_lookup_table = memoize(function(bufnr)
- local definitions = M.get_definitions(bufnr)
- local result = {}
+ local definitions, _, _ = M.get(bufnr)
+ if not definitions then
+ return {}
+ end
+ local result = {}
for _, definition in ipairs(definitions) do
for _, node_entry in ipairs(M.get_local_nodes(definition)) do
local scopes = M.get_definition_scopes(node_entry.node, bufnr, node_entry.scope)
@@ -221,7 +262,11 @@ M.get_definitions_lookup_table = memoize(function(bufnr)
return result
end, function(bufnr)
- return tostring(bufnr)
+ local root = get_root(bufnr)
+ if not root then
+ return tostring(bufnr)
+ end
+ return tostring(root:id())
end)
-- Gets all the scopes of a definition based on the scope type
@@ -233,7 +278,7 @@ end)
--
---@param node TSNode: the definition node
---@param bufnr integer: the buffer
----@param scope_type string: the scope type
+---@param scope_type TSScope: the scope type
function M.get_definition_scopes(node, bufnr, scope_type)
local scopes = {}
local scope_count = 1 ---@type integer|nil
@@ -248,8 +293,8 @@ function M.get_definition_scopes(node, bufnr, scope_type)
end
local i = 0
- for scope in M.iter_scope_tree(node, bufnr) do
- table.insert(scopes, scope)
+ for scope_node in M.iter_scope_tree(node, bufnr) do
+ table.insert(scopes, scope_node)
i = i + 1
if scope_count and i >= scope_count then
@@ -284,7 +329,8 @@ end
-- Finds usages of a node in a given scope.
---@param node TSNode the node to find usages for
----@param scope_node TSNode the node to look within
+---@param scope_node TSNode|nil the node to look within
+---@param bufnr integer|nil the bufnr to look into
---@return TSNode[]: a list of nodes
function M.find_usages(node, scope_node, bufnr)
bufnr = bufnr or api.nvim_get_current_buf()
@@ -297,17 +343,19 @@ function M.find_usages(node, scope_node, bufnr)
scope_node = scope_node or get_root_for_node(node)
local usages = {}
- for match in M.iter_locals(bufnr, scope_node) do
+ local query, _ = get_query(bufnr)
+ if not query then
+ return {}
+ end
+
+ for id, node_capture in query:iter_captures(scope_node, bufnr) do
+ local kind = query.captures[id]
if
- match.reference
- and match.reference.node
- and ts.get_node_text(match.reference.node, bufnr) == node_text
+ node_capture
+ and kind == 'local.reference'
+ and ts.get_node_text(node_capture, bufnr) == node_text
then
- local def_node, _, kind = M.find_definition(match.reference.node, bufnr)
-
- if kind == nil or def_node == node then
- table.insert(usages, match.reference.node)
- end
+ table.insert(usages, node_capture)
end
end
@@ -322,7 +370,7 @@ function M.containing_scope(node, bufnr, allow_scope)
bufnr = bufnr or api.nvim_get_current_buf()
allow_scope = allow_scope == nil or allow_scope == true
- local scopes = M.get_scopes(bufnr)
+ local _, _, scopes = M.get(bufnr)
if not node or not scopes then
return
end
@@ -339,7 +387,7 @@ end
function M.nested_scope(node, cursor_pos)
local bufnr = api.nvim_get_current_buf()
- local scopes = M.get_scopes(bufnr)
+ local _, _, scopes = M.get(bufnr)
if not node or not scopes then
return
end
@@ -359,7 +407,7 @@ end
function M.next_scope(node)
local bufnr = api.nvim_get_current_buf()
- local scopes = M.get_scopes(bufnr)
+ local _, _, scopes = M.get(bufnr)
if not node or not scopes then
return
end
@@ -389,7 +437,7 @@ end
function M.previous_scope(node)
local bufnr = api.nvim_get_current_buf()
- local scopes = M.get_scopes(bufnr)
+ local _, _, scopes = M.get(bufnr)
if not node or not scopes then
return
end
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua
deleted file mode 100644
index f7e5205f7..000000000
--- a/lua/nvim-treesitter/query.lua
+++ /dev/null
@@ -1,165 +0,0 @@
-local M = {}
-
-local EMPTY_ITER = function() end
-
----@class QueryInfo
----@field root TSNode
----@field source integer
----@field start integer
----@field stop integer
-
----@param bufnr integer
----@param query_name string
----@param root TSNode
----@param root_lang string|nil
----@return Query|nil, QueryInfo|nil
-local function prepare_query(bufnr, query_name, root, root_lang)
- local ft = vim.bo[bufnr].filetype
- local buf_lang = vim.treesitter.language.get_lang(ft) or ft
- if not buf_lang then
- return
- end
-
- local parser = vim.treesitter.get_parser(bufnr, buf_lang)
- if not parser then
- return
- end
-
- if not root then
- local first_tree = parser:trees()[1]
-
- if first_tree then
- root = first_tree:root()
- end
- end
-
- if not root then
- return
- end
-
- local range = { root:range() }
-
- if not root_lang then
- local lang_tree = parser:language_for_range(range)
-
- if lang_tree then
- root_lang = lang_tree:lang()
- end
- end
-
- if not root_lang then
- return
- end
-
- local query = vim.treesitter.query.get(root_lang, query_name)
- if not query then
- return
- end
-
- return query,
- {
- root = root,
- source = bufnr,
- start = range[1],
- -- The end row is exclusive so we need to add 1 to it.
- stop = range[3] + 1,
- }
-end
-
--- Given a path (i.e. a List(String)) this functions inserts value at path
----@param object any
----@param path string[]
----@param value any
-function M.insert_to_path(object, path, value)
- local curr_obj = object
-
- for index = 1, (#path - 1) do
- if curr_obj[path[index]] == nil then
- curr_obj[path[index]] = {}
- end
-
- curr_obj = curr_obj[path[index]]
- end
-
- curr_obj[path[#path]] = value
-end
-
----@param query Query
----@param bufnr integer
----@param start_row integer
----@param end_row integer
-function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
- -- A function that splits a string on '.'
- ---@param to_split string
- ---@return string[]
- local function split(to_split)
- local t = {}
- for str in string.gmatch(to_split, '([^.]+)') do
- table.insert(t, str)
- end
-
- return t
- end
-
- local matches = query:iter_matches(qnode, bufnr, start_row, end_row, { all = false })
-
- local function iterator()
- local pattern, match, metadata = matches()
- if pattern ~= nil then
- local prepared_match = {}
-
- -- Extract capture names from each match
- for id, node in pairs(match) do
- local name = query.captures[id] -- name of the capture in the query
- if name ~= nil then
- local path = split(name .. '.node')
- M.insert_to_path(prepared_match, path, node)
- local metadata_path = split(name .. '.metadata')
- M.insert_to_path(prepared_match, metadata_path, metadata[id])
- end
- end
-
- -- Add some predicates for testing
- ---@type string[][] ( TODO: make pred type so this can be pred[])
- local preds = query.info.patterns[pattern]
- if preds then
- for _, pred in pairs(preds) do
- -- functions
- if pred[1] == 'set!' and type(pred[2]) == 'string' then
- M.insert_to_path(prepared_match, split(pred[2]), pred[3])
- end
- end
- end
-
- return prepared_match
- end
- end
- return iterator
-end
-
----Iterates matches from a query file.
----@param bufnr integer the buffer
----@param query_group string the query file to use
----@param root TSNode the root node
----@param root_lang? string the root node lang, if known
-function M.iter_group_results(bufnr, query_group, root, root_lang)
- local query, params = prepare_query(bufnr, query_group, root, root_lang)
- if not query then
- return EMPTY_ITER
- end
- assert(params)
-
- return M.iter_prepared_matches(query, params.root, params.source, params.start, params.stop)
-end
-
-function M.collect_group_results(bufnr, query_group, root, lang)
- local matches = {}
-
- for prepared_match in M.iter_group_results(bufnr, query_group, root, lang) do
- table.insert(matches, prepared_match)
- end
-
- return matches
-end
-
-return M