aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
authorSteven Sojka <steelsojka@users.noreply.github.com>2021-03-30 08:18:24 -0500
committerGitHub <noreply@github.com>2021-03-30 08:18:24 -0500
commit6863f79118d3cb331fd4e726cdb2384bbd8bf8f2 (patch)
treec703b490f9c3e7e601e673704984b42f9edfe6ab /lua
parentFix jsdoc: play nice with the comment parser (#1108) (diff)
downloadnvim-treesitter-6863f79118d3cb331fd4e726cdb2384bbd8bf8f2.tar
nvim-treesitter-6863f79118d3cb331fd4e726cdb2384bbd8bf8f2.tar.gz
nvim-treesitter-6863f79118d3cb331fd4e726cdb2384bbd8bf8f2.tar.bz2
nvim-treesitter-6863f79118d3cb331fd4e726cdb2384bbd8bf8f2.tar.lz
nvim-treesitter-6863f79118d3cb331fd4e726cdb2384bbd8bf8f2.tar.xz
nvim-treesitter-6863f79118d3cb331fd4e726cdb2384bbd8bf8f2.tar.zst
nvim-treesitter-6863f79118d3cb331fd4e726cdb2384bbd8bf8f2.zip
refactor(all): language tree adaption (#1105)
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/fold.lua24
-rw-r--r--lua/nvim-treesitter/indent.lua21
-rw-r--r--lua/nvim-treesitter/locals.lua8
-rw-r--r--lua/nvim-treesitter/parsers.lua2
-rw-r--r--lua/nvim-treesitter/query.lua85
-rw-r--r--lua/nvim-treesitter/ts_utils.lua87
-rw-r--r--lua/nvim-treesitter/tsrange.lua8
-rw-r--r--lua/nvim-treesitter/utils.lua12
8 files changed, 190 insertions, 57 deletions
diff --git a/lua/nvim-treesitter/fold.lua b/lua/nvim-treesitter/fold.lua
index faaf5f542..d1416ef4b 100644
--- a/lua/nvim-treesitter/fold.lua
+++ b/lua/nvim-treesitter/fold.lua
@@ -1,5 +1,5 @@
local api = vim.api
-local utils = require'nvim-treesitter.ts_utils'
+local tsutils = require'nvim-treesitter.ts_utils'
local query = require'nvim-treesitter.query'
local parsers = require'nvim-treesitter.parsers'
@@ -7,18 +7,19 @@ local M = {}
-- This is cached on buf tick to avoid computing that multiple times
-- Especially not for every line in the file when `zx` is hit
-local folds_levels = utils.memoize_by_buf_tick(function(bufnr)
- local lang = parsers.get_buf_lang(bufnr)
+local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr)
local max_fold_level = api.nvim_win_get_option(0, 'foldnestmax')
+ local parser = parsers.get_parser(bufnr)
- local matches
- if query.has_folds(lang) then
- matches = query.get_capture_matches(bufnr, "@fold", "folds")
- elseif query.has_locals(lang) then
- matches = query.get_capture_matches(bufnr, "@scope", "locals")
- else
- return {}
- end
+ if not parser then return {} end
+
+ local matches = query.get_capture_matches_recursively(bufnr, function(lang)
+ if query.has_folds(lang) then
+ return "@fold", "folds"
+ elseif query.has_locals(lang) then
+ return "@scope", "locals"
+ end
+ end)
local levels_tmp = {}
@@ -35,7 +36,6 @@ local folds_levels = utils.memoize_by_buf_tick(function(bufnr)
levels_tmp[start] = (levels_tmp[start] or 0) + 1
levels_tmp[stop] = (levels_tmp[stop] or 0) - 1
end
-
end
local levels = {}
diff --git a/lua/nvim-treesitter/indent.lua b/lua/nvim-treesitter/indent.lua
index cdbf66489..d0e71b4c6 100644
--- a/lua/nvim-treesitter/indent.lua
+++ b/lua/nvim-treesitter/indent.lua
@@ -1,6 +1,6 @@
local parsers = require'nvim-treesitter.parsers'
local queries = require'nvim-treesitter.query'
-local utils = require'nvim-treesitter.ts_utils'
+local tsutils = require'nvim-treesitter.ts_utils'
local M = {}
@@ -21,9 +21,9 @@ local function node_fmt(node)
return tostring(node)
end
-local get_indents = utils.memoize_by_buf_tick(function(bufnr)
+local get_indents = tsutils.memoize_by_buf_tick(function(bufnr, root, lang)
local get_map = function(capture)
- local matches = queries.get_capture_matches(bufnr, capture, 'indents') or {}
+ local matches = queries.get_capture_matches(bufnr, capture, 'indents', root, lang) or {}
local map = {}
for _, node in ipairs(matches) do
map[tostring(node)] = true
@@ -37,14 +37,23 @@ local get_indents = utils.memoize_by_buf_tick(function(bufnr)
returns = get_map('@return.node'),
ignores = get_map('@ignore.node'),
}
-end)
+end, {
+ -- Memoize by bufnr and lang together.
+ key = function(bufnr, _, lang)
+ return tostring(bufnr) .. '_' .. lang
+ end
+})
function M.get_indent(lnum)
local parser = parsers.get_parser()
if not parser or not lnum then return -1 end
- local q = get_indents(vim.api.nvim_get_current_buf())
- local root = parser:parse()[1]:root()
+ local root, _, lang_tree = tsutils.get_root_for_position(lnum, 0, parser)
+
+ -- Not likely, but just in case...
+ if not root then return 0 end
+
+ local q = get_indents(vim.api.nvim_get_current_buf(), root, lang_tree:lang())
local node = get_node_at_line(root, lnum-1)
local indent = 0
diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua
index 01cacc6e5..17982340d 100644
--- a/lua/nvim-treesitter/locals.lua
+++ b/lua/nvim-treesitter/locals.lua
@@ -3,7 +3,6 @@
-- its the way nvim-treesitter uses to "understand" the code
local queries = require'nvim-treesitter.query'
-local parsers = require'nvim-treesitter.parsers'
local ts_utils = require'nvim-treesitter.ts_utils'
local api = vim.api
@@ -91,13 +90,12 @@ end
--- Iterates over a nodes scopes moving from the bottom up
function M.iter_scope_tree(node, bufnr)
local last_node = node
-
return function()
if not last_node then
return
end
- local scope = M.containing_scope(last_node, bufnr, false) or parsers.get_tree_root(bufnr)
+ local scope = M.containing_scope(last_node, bufnr, false) or ts_utils.get_root_for_node(node)
last_node = scope:parent()
@@ -222,7 +220,7 @@ function M.find_definition(node, bufnr)
end
end
- return node, parsers.get_tree_root(bufnr), nil
+ return node, ts_utils.get_root_for_node(node), nil
end
-- Finds usages of a node in a given scope.
@@ -235,7 +233,7 @@ function M.find_usages(node, scope_node, bufnr)
if not node_text or #node_text < 1 then return {} end
- local scope_node = scope_node or parsers.get_parser(bufnr):parse()[1]:root()
+ local scope_node = scope_node or ts_utils.get_root_for_node(node)
local usages = {}
for match in M.iter_locals(bufnr, scope_node) do
diff --git a/lua/nvim-treesitter/parsers.lua b/lua/nvim-treesitter/parsers.lua
index 9bac6a8f3..a9d779ef2 100644
--- a/lua/nvim-treesitter/parsers.lua
+++ b/lua/nvim-treesitter/parsers.lua
@@ -584,6 +584,8 @@ function M.get_parser(bufnr, lang)
end
end
+-- @deprecated This is only kept for legacy purposes.
+-- All root nodes should be accounted for.
function M.get_tree_root(bufnr)
local bufnr = bufnr or api.nvim_get_current_buf()
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua
index df3e70805..e6683139e 100644
--- a/lua/nvim-treesitter/query.lua
+++ b/lua/nvim-treesitter/query.lua
@@ -7,6 +7,8 @@ local caching = require'nvim-treesitter.caching'
local M = {}
+local EMPTY_ITER = function() end
+
M.built_in_query_groups = {'highlights', 'locals', 'folds', 'indents'}
-- Creates a function that checks whether a given query exists
@@ -166,7 +168,7 @@ end
--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
-- Works like M.get_references or M.get_scopes except you can choose the capture
-- Can also be a nested capture like @definition.function to get all nodes defining a function
-function M.get_capture_matches(bufnr, capture_string, query_group)
+function M.get_capture_matches(bufnr, capture_string, query_group, root, lang)
if not string.sub(capture_string, 1,2) == '@' then
print('capture_string must start with "@"')
return
@@ -176,7 +178,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group)
capture_string = string.sub(capture_string, 2)
local matches = {}
- for match in M.iter_group_results(bufnr, query_group) do
+ for match in M.iter_group_results(bufnr, query_group, root, lang) do
local insert = utils.get_at_path(match, capture_string)
if insert then
@@ -186,7 +188,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group)
return matches
end
-function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function)
+function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root)
if not string.sub(capture_string, 1,2) == '@' then
api.nvim_err_writeln('capture_string must start with "@"')
return
@@ -198,7 +200,7 @@ function M.find_best_match(bufnr, capture_string, query_group, filter_predicate,
local best
local best_score
- for maybe_match in M.iter_group_results(bufnr, query_group) do
+ for maybe_match in M.iter_group_results(bufnr, query_group, root) do
local match = utils.get_at_path(maybe_match, capture_string)
if match and filter_predicate(match) then
@@ -220,31 +222,82 @@ end
-- @param bufnr the buffer
-- @param query_group the query file to use
-- @param root the root node
-function M.iter_group_results(bufnr, query_group, root)
- local lang = parsers.get_buf_lang(bufnr)
- if not lang then return function() end end
+-- @param root the root node lang, if known
+function M.iter_group_results(bufnr, query_group, root, root_lang)
+ local buf_lang = parsers.get_buf_lang(bufnr)
+
+ if not buf_lang then return EMPTY_ITER end
+
+ local parser = parsers.get_parser(bufnr, buf_lang)
+ if not parser then return EMPTY_ITER end
- local query = M.get_query(lang, query_group)
- if not query then return function() end end
+ if not root then
+ local first_tree = parser:trees()[1]
+
+ if first_tree then
+ root = first_tree:root()
+ end
+ end
+
+ if not root then return EMPTY_ITER end
+
+ local range = {root:range()}
+
+ if not root_lang then
+ local lang_tree = parser:language_for_range(range)
+
+ if lang_tree then
+ root_lang = lang_tree:lang()
+ end
+ end
- local parser = parsers.get_parser(bufnr, lang)
- if not parser then return function() end end
+ if not root_lang then return EMPTY_ITER end
- local root = root or parser:parse()[1]:root()
- local start_row, _, end_row, _ = root:range()
+ local query = M.get_query(root_lang, query_group)
+ if not query then return EMPTY_ITER end
-- The end row is exclusive so we need to add 1 to it.
- return M.iter_prepared_matches(query, root, bufnr, start_row, end_row + 1)
+ return M.iter_prepared_matches(query, root, bufnr, range[1], range[3] + 1)
end
-function M.collect_group_results(bufnr, query_group, root)
+function M.collect_group_results(bufnr, query_group, root, lang)
local matches = {}
- for prepared_match in M.iter_group_results(bufnr, query_group, root) do
+ for prepared_match in M.iter_group_results(bufnr, query_group, root, lang) do
table.insert(matches, prepared_match)
end
return matches
end
+--- Same as get_capture_matches except this will recursively get matches for every language in the tree.
+-- @param bufnr The bufnr
+-- @param capture_or_fn The capture to get. If a function is provided then that
+-- function will be used to resolve both the capture and query argument.
+-- The function can return `nil` to ignore that tree.
+-- @param query_type The query to get the capture from. This is ignore if a function is provided
+-- for the captuer argument.
+function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type)
+ local type_fn = type(capture_or_fn) == 'function'
+ and capture_or_fn
+ or function()
+ return capture_or_fn, query_type
+ end
+ local parser = parsers.get_parser(bufnr)
+ local matches = {}
+
+ if parser then
+ parser:for_each_tree(function(tree, lang_tree)
+ local lang = lang_tree:lang()
+ local capture, type_ = type_fn(lang, tree, lang_tree)
+
+ if capture then
+ vim.list_extend(matches, M.get_capture_matches(bufnr, capture, type_, tree:root(), lang))
+ end
+ end)
+ end
+
+ return matches
+end
+
return M
diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua