aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/highlight.lua2
-rw-r--r--lua/nvim-treesitter/incremental_selection.lua3
-rw-r--r--lua/nvim-treesitter/locals.lua246
-rw-r--r--lua/nvim-treesitter/parsers.lua2
-rw-r--r--lua/nvim-treesitter/query.lua113
-rw-r--r--lua/nvim-treesitter/refactor/highlight_current_scope.lua3
-rw-r--r--lua/nvim-treesitter/refactor/highlight_definitions.lua4
-rw-r--r--lua/nvim-treesitter/refactor/navigation.lua4
-rw-r--r--lua/nvim-treesitter/refactor/smart_rename.lua5
-rw-r--r--lua/nvim-treesitter/textobjects.lua6
-rw-r--r--lua/nvim-treesitter/ts_utils.lua195
11 files changed, 289 insertions, 294 deletions
diff --git a/lua/nvim-treesitter/highlight.lua b/lua/nvim-treesitter/highlight.lua
index 14f1425aa..867c70caf 100644
--- a/lua/nvim-treesitter/highlight.lua
+++ b/lua/nvim-treesitter/highlight.lua
@@ -54,7 +54,7 @@ hlmap["include"] = "TSInclude"
function M.attach(bufnr, lang)
local bufnr = bufnr or api.nvim_get_current_buf()
- local lang = parsers.get_buf_lang(bufnr, lang)
+ local lang = lang or parsers.get_buf_lang(bufnr)
local config = configs.get_module('highlight')
for k, v in pairs(config.custom_captures) do
diff --git a/lua/nvim-treesitter/incremental_selection.lua b/lua/nvim-treesitter/incremental_selection.lua
index ecb02330d..045e7db23 100644
--- a/lua/nvim-treesitter/incremental_selection.lua
+++ b/lua/nvim-treesitter/incremental_selection.lua
@@ -2,6 +2,7 @@ local api = vim.api
local configs = require'nvim-treesitter.configs'
local ts_utils = require'nvim-treesitter.ts_utils'
+local locals = require'nvim-treesitter.locals'
local parsers = require'nvim-treesitter.parsers'
local M = {}
@@ -74,7 +75,7 @@ M.node_incremental = select_incremental(function(node)
end)
M.scope_incremental = select_incremental(function(node)
- return ts_utils.containing_scope(node:parent() or node)
+ return locals.containing_scope(node:parent() or node)
end)
function M.node_decremental()
diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua
index fe99cd811..947187d5e 100644
--- a/lua/nvim-treesitter/locals.lua
+++ b/lua/nvim-treesitter/locals.lua
@@ -1,73 +1,27 @@
-- Functions to handle locals
-- Locals are a generalization of definition and scopes
-- its the way nvim-treesitter uses to "understand" the code
-local api = vim.api
local queries = require'nvim-treesitter.query'
local parsers = require'nvim-treesitter.parsers'
-local utils = require'nvim-treesitter.utils'
-
-local default_dict = {
- __index = function(table, key)
- local exists = rawget(table, key)
- if not exists then
- table[key] = {}
- end
- return rawget(table, key)
- end
-}
-
-local query_cache = {}
-setmetatable(query_cache, default_dict)
+local ts_utils = require'nvim-treesitter.ts_utils'
+local api = vim.api
local M = {}
-function M.collect_locals(bufnr, query_kind)
- local locals = {}
-
- for prepared_match in M.iter_locals(bufnr, nil, query_kind) do
- table.insert(locals, prepared_match)
- end
-
- return locals
-end
-
-local function update_cached_locals(bufnr, changed_tick, query_kind)
- query_cache[query_kind][bufnr] = {tick=changed_tick, cache=( M.collect_locals(bufnr, query_kind) or {} )}
+function M.collect_locals(bufnr)
+ return queries.collect_group_results(bufnr, 'locals')
end
-- Iterates matches from a locals query file.
-- @param bufnr the buffer
-- @param root the root node
--- @param query_kind the query file to use
-function M.iter_locals(bufnr, root, query_kind)
- query_kind = query_kind or 'locals'
-
- local lang = parsers.get_buf_lang(bufnr)
- if not lang then return end
-
- local query = queries.get_query(lang, query_kind)
- if not query then return end
-
- local parser = parsers.get_parser(bufnr, lang)
- if not parser then return end
-
- local root = root or parser:parse():root()
- local start_row, _, end_row, _ = root:range()
-
- return queries.iter_prepared_matches(query, root, bufnr, start_row, end_row)
+function M.iter_locals(bufnr, root)
+ return queries.iter_group_results(bufnr, 'locals', root)
end
-function M.get_locals(bufnr, query_kind)
- query_kind = query_kind or 'locals'
-
- local bufnr = bufnr or api.nvim_get_current_buf()
- local cached_local = query_cache[query_kind][bufnr]
- if not cached_local or api.nvim_buf_get_changedtick(bufnr) > cached_local.tick then
- update_cached_locals(bufnr,api.nvim_buf_get_changedtick(bufnr), query_kind)
- end
-
- return query_cache[query_kind][bufnr].cache
+function M.get_locals(bufnr)
+ return queries.get_matches(bufnr, 'locals')
end
function M.get_definitions(bufnr)
@@ -112,27 +66,179 @@ function M.get_references(bufnr)
return refs
end
---- Return all nodes in locals corresponding to a specific capture (like @scope, @reference)
--- Works like M.get_references or M.get_scopes except you can choose the capture
--- Can also be a nested capture like @definition.function to get all nodes defining a function
-function M.get_capture_matches(bufnr, capture_string, query_kind)
- if not string.sub(capture_string, 1,2) == '@' then
- print('capture_string must start with "@"')
- return
- end
+-- Finds the definition node and it's scope node of a node
+-- @param node starting node
+-- @param bufnr buffer
+-- @returns the definition node and the definition nodes scope node
+function M.find_definition(node, bufnr)
+ local bufnr = bufnr or api.nvim_get_current_buf()
+ local node_text = ts_utils.get_node_text(node)[1]
+ local current_scope = M.containing_scope(node)
+ local matching_def_nodes = {}
- --remove leading "@"
- capture_string = string.sub(capture_string, 2)
+ -- If a scope wasn't found then use the root node
+ if current_scope == node then
+ current_scope = parsers.get_parser(bufnr).tree:root()
+ end
- local matches = {}
- for _, match in pairs(M.get_locals(bufnr, query_kind)) do
- local insert = utils.get_at_path(match, capture_string)
+ -- Get all definitions that match the node text
+ for _, def in ipairs(M.get_definitions(bufnr)) do
+ for _, def_node in ipairs(M.get_local_nodes(def)) do
+ if ts_utils.get_node_text(def_node)[1] == node_text then
+ table.insert(matching_def_nodes, def_node)
+ end
+ end
+ end
- if insert then
- table.insert(matches, insert)
+ -- Continue up each scope until we find the scope that contains the definition
+ while current_scope do
+ for _, def_node in ipairs(matching_def_nodes) do
+ if ts_utils.is_parent(current_scope, def_node) then
+ return def_node, current_scope
end
end
- return matches
+ current_scope = M.containing_scope(current_scope:parent())
+ end
+
+ return node, parsers.get_parser(bufnr).tree:root()
+end
+
+-- Gets all nodes from a local list result.
+-- @param local_def the local list result
+-- @returns a list of nodes
+function M.get_local_nodes(local_def)
+ local result = {}
+
+ M.recurse_local_nodes(local_def, function(_, node)
+ table.insert(result, node)
+ end)
+
+ return result
+end
+
+-- Recurse locals results until a node is found.
+-- The accumulator function is given
+-- * The table of the node
+-- * The node
+-- * The full definition match `@definition.var.something` -> 'var.something'
+-- * The last definition match `@definition.var.something` -> 'something'
+-- @param The locals result
+-- @param The accumulator function
+-- @param The full match path to append to
+-- @param The last match
+function M.recurse_local_nodes(local_def, accumulator, full_match, last_match)
+ if local_def.node then
+ accumulator(local_def, local_def.node, full_match, last_match)
+ else
+ for match_key, def in pairs(local_def) do
+ M.recurse_local_nodes(
+ def,
+ accumulator,
+ full_match and (full_match..'.'..match_key) or match_key,
+ match_key)
+ end
+ end
+end
+
+-- Finds usages of a node in a given scope
+-- @param node the node to find usages for
+-- @param scope_node the node to look within
+-- @returns a list of nodes
+function M.find_usages(node, scope_node, bufnr)
+ local bufnr = bufnr or api.nvim_get_current_buf()
+ local node_text = ts_utils.get_node_text(node)[1]
+
+ if not node_text or #node_text < 1 then return {} end
+
+ local scope_node = scope_node or parsers.get_parser(bufnr).tree:root()
+ local usages = {}
+
+ for match in M.iter_locals(bufnr, scope_node) do
+ if match.reference
+ and match.reference.node
+ and ts_utils.get_node_text(match.reference.node)[1] == node_text
+ then
+ table.insert(usages, match.reference.node)
+ end
+ end
+
+ return usages
+end
+
+function M.containing_scope(node, bufnr)
+ local bufnr = bufnr or api.nvim_get_current_buf()
+
+ local scopes = M.get_scopes(bufnr)
+ if not node or not scopes then return end
+
+ local iter_node = node
+
+ while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do
+ iter_node = iter_node:parent()
+ end
+
+ return iter_node or node
+end
+
+function M.nested_scope(node, cursor_pos)
+ local bufnr = api.nvim_get_current_buf()
+
+ local scopes = M.get_scopes(bufnr)
+ if not node or not scopes then return end
+
+ local row = cursor_pos.row
+ local col = cursor_pos.col
+ local scope = M.containing_scope(node)
+
+ for _, child in ipairs(ts_utils.get_named_children(scope)) do
+ local row_, col_ = child:start()
+ if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then
+ return child
+ end
+ end
+end
+
+function M.next_scope(node)
+ local bufnr = api.nvim_get_current_buf()
+
+ local scopes = M.get_scopes(bufnr)
+ if not node or not scopes then return end
+
+ local scope = M.containing_scope(node)
+
+ local parent = scope:parent()
+ if not parent then return end
+
+ local is_prev = true
+ for _, child in ipairs(ts_utils.get_named_children(parent)) do
+ if child == scope then
+ is_prev = false
+ elseif not is_prev and vim.tbl_contains(scopes, child) then
+ return child
+ end
+ end
+end
+
+function M.previous_scope(node)
+ local bufnr = api.nvim_get_current_buf()
+
+ local scopes = M.get_scopes(bufnr)
+ if not node or not scopes then return end
+
+ local scope = M.containing_scope(node)
+
+ local parent = scope:parent()
+ if not parent then return end
+
+ local is_prev = true
+ local children = ts_utils.get_named_children(parent)
+ for i=#children,1,-1 do
+ if children[i] == scope then
+ is_prev = false
+ elseif not is_prev and vim.tbl_contains(scopes, children[i]) then
+ return children[i]
+ end
+ end
end
return M
diff --git a/lua/nvim-treesitter/parsers.lua b/lua/nvim-treesitter/parsers.lua
index dd4f14c7f..f4244ddd1 100644
--- a/lua/nvim-treesitter/parsers.lua
+++ b/lua/nvim-treesitter/parsers.lua
@@ -257,7 +257,7 @@ end
function M.has_parser(lang)
local buf = api.nvim_get_current_buf()
- local lang = M.get_buf_lang(buf) or lang
+ local lang = lang or M.get_buf_lang(buf)
if not lang or #lang == 0 then return false end
return #api.nvim_get_runtime_file('parser/' .. lang .. '.*', false) > 0
end
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua
index 69e52e1e6..896b11dad 100644
--- a/lua/nvim-treesitter/query.lua
+++ b/lua/nvim-treesitter/query.lua
@@ -1,25 +1,21 @@
local api = vim.api
local ts = vim.treesitter
+local utils = require'nvim-treesitter.utils'
+local parsers = require'nvim-treesitter.parsers'
local M = {}
-local function read_query_files(filenames)
- local contents = {}
-
- for _,filename in ipairs(filenames) do
- vim.list_extend(contents, vim.fn.readfile(filename))
+local default_dict = {
+ __index = function(table, key)
+ local exists = rawget(table, key)
+ if not exists then
+ table[key] = {}
+ end
+ return rawget(table, key)
end
+}
- return table.concat(contents, '\n')
-end
-
--- Creates a function that checks whether a certain query exists
--- for a specific language.
-local function get_query_guard(query)
- return function(lang)
- return M.get_query(lang, query) ~= nil
- end
-end
+local query_cache = setmetatable({}, default_dict)
-- Some treesitter grammars extend others.
-- We can use that to import the queries of the base language
@@ -36,10 +32,42 @@ M.query_extensions = {
M.built_in_query_groups = {'highlights', 'locals', 'textobjects'}
+-- Creates a function that checks whether a certain query exists
+-- for a specific language.
+local function get_query_guard(query)
+ return function(lang)
+ return M.get_query(lang, query) ~= nil
+ end
+end
+
for _, query in ipairs(M.built_in_query_groups) do
M["has_" .. query] = get_query_guard(query)
end
+local function read_query_files(filenames)
+ local contents = {}
+
+ for _,filename in ipairs(filenames) do
+ vim.list_extend(contents, vim.fn.readfile(filename))
+ end
+
+ return table.concat(contents, '\n')
+end
+
+local function update_cached_matches(bufnr, changed_tick, query_group)
+ query_cache[query_group][bufnr] = {tick=changed_tick, cache=( M.collect_group_results(bufnr, query_group) or {} )}
+end
+
+function M.get_matches(bufnr, query_group)
+ local bufnr = bufnr or api.nvim_get_current_buf()
+ local cached_local = query_cache[query_group][bufnr]
+ if not cached_local or api.nvim_buf_get_changedtick(bufnr) > cached_local.tick then
+ update_cached_matches(bufnr,api.nvim_buf_get_changedtick(bufnr), query_group)
+ end
+
+ return query_cache[query_group][bufnr].cache
+end
+
function M.get_query(lang, query_name)
local query_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', lang, query_name), true)
local query_string = ''
@@ -84,7 +112,6 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
return t
end
-
-- Given a path (i.e. a List(String)) this functions inserts value at path
local function insert_to_path(object, path, value)
local curr_obj = object
@@ -131,4 +158,58 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
end
end
+--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
+-- Works like M.get_references or M.get_scopes except you can choose the capture
+-- Can also be a nested capture like @definition.function to get all nodes defining a function
+function M.get_capture_matches(bufnr, capture_string, query_group)
+ if not string.sub(capture_string, 1,2) == '@' then
+ print('capture_string must start with "@"')
+ return
+ end
+
+ --remove leading "@"
+ capture_string = string.sub(capture_string, 2)
+
+ local matches = {}
+ for match in M.iter_group_results(bufnr, query_group) do
+ local insert = utils.get_at_path(match, capture_string)
+
+ if insert then
+ table.insert(matches, insert)
+ end
+ end
+ return matches
+end
+
+-- Iterates matches from a query file.
+-- @param bufnr the buffer
+-- @param query_group the query file to use
+-- @param root the root node
+function M.iter_group_results(bufnr, query_group, root)
+
+ local lang = parsers.get_buf_lang(bufnr)
+ if not lang then return end
+
+ local query = M.get_query(lang, query_group)
+ if not query then return end
+
+ local parser = parsers.get_parser(bufnr, lang)
+ if not parser then return end
+
+ local root = root or parser:parse():root()
+ local start_row, _, end_row, _ = root:range()
+
+ return M.iter_prepared_matches(query, root, bufnr, start_row, end_row)
+end
+
+function M.collect_group_results(bufnr, query_group, root)
+ local matches = {}
+
+ for prepared_match in M.iter_group_results(bufnr, query_group, root) do
+ table.insert(matches, prepared_match)
+ end
+
+ return matches
+end
+
return M
diff --git a/lua/nvim-treesitter/refactor/highlight_current_scope.lua b/lua/nvim-treesitter/refactor/highlight_current_scope.lua
index b32ccaf3b..8b785ef8d 100644
--- a/lua/nvim-treesitter/refactor/highlight_current_scope.lua
+++ b/lua/nvim-treesitter/refactor/highlight_current_scope.lua
@@ -1,6 +1,7 @@
-- This module highlights the current scope of at the cursor position
local ts_utils = require'nvim-treesitter.ts_utils'
+local locals = require'nvim-treesitter.locals'
local api = vim.api
local cmd = api.nvim_command
@@ -12,7 +13,7 @@ function M.highlight_current_scope(bufnr)
M.clear_highlights(bufnr)
local node_at_point = ts_utils.get_node_at_cursor()
- local current_scope = ts_utils.containing_scope(node_at_point, bufnr)
+ local current_scope = locals.containing_scope(node_at_point, bufnr)
local start_line = current_scope:start()
diff --git a/lua/nvim-treesitter/refactor/highlight_definitions.lua b/lua/nvim-treesitter/refactor/highlight_definitions.lua
index d1eb4712d..710814427 100644
--- a/lua/nvim-treesitter/refactor/highlight_definitions.lua
+++ b/lua/nvim-treesitter/refactor/highlight_definitions.lua
@@ -20,8 +20,8 @@ function M.highlight_usages(bufnr)
return
end
- local def_node, scope = ts_utils.find_definition(node_at_point, bufnr)
- local usages = ts_utils.find_usages(node_at_point, scope)
+ local def_node, scope = locals.find_definition(node_at_point, bufnr)
+ local usages = locals.find_usages(node_at_point, scope)
for _, usage_node in ipairs(usages) do
if usage_node ~= node_at_point then
diff --git a/lua/nvim-treesitter/refactor/navigation.lua b/lua/nvim-treesitter/refactor/navigation.lua
index ce653d6a7..160fbd2e3 100644
--- a/lua/nvim-treesitter/refactor/navigation.lua
+++ b/lua/nvim-treesitter/refactor/navigation.lua
@@ -16,7 +16,7 @@ function M.goto_definition(bufnr)
if not node_at_point then return end
- local definition, _ = ts_utils.find_definition(node_at_point, bufnr)
+ local definition, _ = locals.find_definition(node_at_point, bufnr)
local start_row, start_col, _ = definition:start()
api.nvim_win_set_cursor(0, { start_row + 1, start_col })
@@ -31,7 +31,7 @@ function M.list_definitions(bufnr)
local qf_list = {}
for _, def in ipairs(definitions) do
- ts_utils.recurse_local_nodes(def, function(_, node, _, match)
+ locals.recurse_local_nodes(def, function(_, node, _, match)
local lnum, col, _ = node:start()
table.insert(qf_list, {
diff --git a/lua/nvim-treesitter/refactor/smart_rename.lua b/lua/nvim-treesitter/refactor/smart_rename.lua
index e5ee37ac1..ad26085c8 100644
--- a/lua/nvim-treesitter/refactor/smart_rename.lua
+++ b/lua/nvim-treesitter/refactor/smart_rename.lua
@@ -2,6 +2,7 @@
-- Can be used directly using the `smart_rename` function.
local ts_utils = require'nvim-treesitter.ts_utils'
+local locals = require'nvim-treesitter.locals'
local configs = require'nvim-treesitter.configs'
local utils = require'nvim-treesitter.utils'
local api = vim.api
@@ -23,8 +24,8 @@ function M.smart_rename(bufnr)
-- Empty name cancels the interaction or ESC
if not new_name or #new_name < 1 then return end
- local definition, scope = ts_utils.find_definition(node_at_point, bufnr)
- local nodes_to_rename = ts_utils.find_usages(node_at_point, scope)
+ local definition, scope = locals.find_definition(node_at_point, bufnr)
+ local nodes_to_rename = locals.find_usages(node_at_point, scope)
if not vim.tbl_contains(nodes_to_rename, node_at_point) then
table.insert(nodes_to_rename, node_at_point)
diff --git a/lua/nvim-treesitter/textobjects.lua b/lua/nvim-treesitter/textobjects.lua
index b6c5d7356..2d6a25ff5 100644
--- a/lua/nvim-treesitter/textobjects.lua
+++ b/lua/nvim-treesitter/textobjects.lua
@@ -4,7 +4,6 @@ local ts = vim.treesitter
local configs = require "nvim-treesitter.configs"
local parsers = require "nvim-treesitter.parsers"
local queries = require'nvim-treesitter.query'
-local locals = require'nvim-treesitter.locals'
local ts_utils = require'nvim-treesitter.ts_utils'
local M = {}
@@ -20,10 +19,11 @@ function M.select_textobject(query_string)
local matches = {}
if string.match(query_string, '^@.*') then
- matches = locals.get_capture_matches(bufnr, query_string, 'textobjects')
+ matches = queries.get_capture_matches(bufnr, query_string, 'textobjects')
else
local parser = parsers.get_parser(bufnr, lang)
local root = parser:parse():root()
+
local start_row, _, end_row, _ = root:range()
local query = ts.parse_query(lang, query_string)
@@ -41,7 +41,7 @@ function M.select_textobject(query_string)
local earliest_start
for _, m in pairs(matches) do
- if ts_utils.is_in_node_range(m.node, row, col) then
+ if m.node and ts_utils.is_in_node_range(m.node, row, col) then
local length = ts_utils.node_length(m.node)
if not match_length or length < match_length then
smallest_range = m
diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua
index 0882155c4..b01783d12 100644
--- a/lua/nvim-treesitter/ts_utils.lua
+++ b/lua/nvim-treesitter/ts_utils.lua
@@ -1,6 +1,5 @@
local api = vim.api
-local locals = require'nvim-treesitter.locals'
local parsers = require'nvim-treesitter.parsers'
local M = {}
@@ -105,40 +104,6 @@ function M.get_previous_node(node, allow_switch_parents, allow_previous_parent)
return destination_node
end
-function M.parent_scope(node, cursor_pos)
- local bufnr = api.nvim_get_current_buf()
-
- local scopes = locals.get_scopes(bufnr)
- if not node or not scopes then return end
-
- local row = cursor_pos.row
- local col = cursor_pos.col
- local iter_node = node
-
- while iter_node ~= nil do
- local row_, col_ = iter_node:start()
- if vim.tbl_contains(scopes, iter_node) and (row_+1 ~= row or col_ ~= col) then
- return iter_node
- end
- iter_node = iter_node:parent()
- end
-end
-
-function M.containing_scope(node, bufnr)
- local bufnr = bufnr or api.nvim_get_current_buf()
-
- local scopes = locals.get_scopes(bufnr)
- if not node or not scopes then return end
-
- local iter_node = node
-
- while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do
- iter_node = iter_node:parent()
- end
-
- return iter_node or node
-end
-
function M.get_named_children(node)
local nodes = {}
for i=0,node:named_child_count() - 1,1 do
@@ -147,172 +112,12 @@ function M.get_named_children(node)
return nodes
end
-function M.nested_scope(node, cursor_pos)
- local bufnr = api.nvim_get_current_buf()
-
- local scopes = locals.get_scopes(bufnr)
- if not node or not scopes then return end
-
- local row = cursor_pos.row
- local col = cursor_pos.col
- local scope = M.containing_scope(node)
-
- for _, child in ipairs(M.get_named_children(scope)) do
- local row_, col_ = child:start()
- if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then
- return child
- end
- end
-end
-
-function M.next_scope(node)
- local bufnr = api.nvim_get_current_buf()
-
- local scopes = locals.get_scopes(bufnr)
- if not node or not scopes then return end
-
- local scope = M.containing_scope(node)
-
- local parent = scope:parent()
- if not parent then return end
-
- local is_prev = true
- for _, child in ipairs(M.get_named_children(parent)) do
- if child == scope then
- is_prev = false
- elseif not is_prev and vim.tbl_contains(scopes, child) then
- return child
- end
- end
-end
-
-function M.previous_scope(node)
- local bufnr = api.nvim_get_current_buf()
-
- local scopes = locals.get_scopes(bufnr)
- if not node or not scopes then return end
-
- local scope = M.containing_scope(node)
-
- local parent = scope:parent()
- if not parent then return end
-
- local is_prev = true
- local children = M.get_named_children(parent)
- for i=#children,1,-1 do
- if children[i] == scope then
- is_prev = false
- elseif not is_prev and vim.tbl_contains(scopes, children[i]) then
- return children[i]
- end
- end
-end
-
function M.get_node_at_cursor(winnr)
local cursor = api.nvim_win_get_cursor(winnr or 0)
local root = parsers.get_parser():parse():root()
return root:named_descendant_for_range(cursor[1]-1,cursor[2],cursor[1]-1,cursor[2])
end
--- Finds the definition node and it's scope node of a node
--- @param node starting node
--- @param bufnr buffer
--- @returns the definition node and the definition nodes scope node
-function M.find_definition(node, bufnr)
- local bufnr = bufnr or api.nvim_get_current_buf()
- local node_text = M.get_node_text(node)[1]
- local current_scope = M.containing_scope(node)
- local matching_def_nodes = {}
-
- -- If a scope wasn't found then use the root node
- if current_scope == node then
- current_scope = parsers.get_parser(bufnr).tree:root()
- end
-
- -- Get all definitions that match the node text
- for _, def in ipairs(locals.get_definitions(bufnr)) do
- for _, def_node in ipairs(M.get_local_nodes(def)) do
- if M.get_node_text(def_node)[1] == node_text then
- table.insert(matching_def_nodes, def_node)
- end
- end
- end
-
- -- Continue up each scope until we find the scope that contains the definition
- while current_scope do
- for _, def_node in ipairs(matching_def_nodes) do
- if M.is_parent(current_scope, def_node) then
- return def_node, current_scope
- end
- end
- current_scope = M.containing_scope(current_scope:parent())
- end
-
- return node, parsers.get_parser(bufnr).tree:root()
-end
-
--- Gets all nodes from a local list result.
--- @param local_def the local list result
--- @returns a list of nodes
-function M.get_local_nodes(local_def)
- local result = {}
-
- M.recurse_local_nodes(local_def, function(_, node)
- table.insert(result, node)
- end)
-
- return result
-end
-
--- Recurse locals results until a node is found.
--- The accumulator function is given
--- * The table of the node
--- * The node
--- * The full definition match `@definition.var.something` -> 'var.something'
--- * The last definition match `@definition.var.something` -> 'something'
--- @param The locals result
--- @param The accumulator function
--- @param The full match path to append to
--- @param The last match
-function M.recurse_local_nodes(local_def, accumulator, full_match, last_match)
- if local_def.node then
- accumulator(local_def, local_def.node, full_match, last_match)
- else
- for match_key, def in pairs(local_def) do
- M.recurse_local_nodes(
- def,
- accumulator,
- full_match and (full_match..'.'..match_key) or match_key,
- match_key)
- end
- end
-end
-
--- Finds usages of a node in a given scope
--- @param node the node to find usages for
--- @param scope_node the node to look within
--- @returns a list of nodes
-function M.find_usages(node, scope_node, bufnr)
- local bufnr = bufnr or api.nvim_get_current_buf()
- local node_text = M.get_node_text(node)[1]
-
- if not node_text or #node_text < 1 then return {} end
-
- local scope_node = scope_node or parsers.get_parser(bufnr).tree:root()
- local usages = {}
-
- for match in locals.iter_locals(bufnr, scope_node) do
- if match.reference
- and match.reference.node
- and M.get_node_text(match.reference.node)[1] == node_text
- then
- table.insert(usages, match.reference.node)
- end
- end
-
- return usages
-end
-
function M.highlight_node(node, buf, hl_namespace, hl_group)
if not node then return end
M.highlight_range({node:range()}, buf, hl_namespace, hl_group)