aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/query.lua30
-rw-r--r--lua/nvim-treesitter/textobjects.lua70
2 files changed, 60 insertions, 40 deletions
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua
index c630b366c..034df223d 100644
--- a/lua/nvim-treesitter/query.lua
+++ b/lua/nvim-treesitter/query.lua
@@ -212,6 +212,36 @@ function M.get_capture_matches(bufnr, capture_string, query_group)
return matches
end
+function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function)
+ if not string.sub(capture_string, 1,2) == '@' then
+ api.nvim_err_writeln('capture_string must start with "@"')
+ return
+ end
+
+ --remove leading "@"
+ capture_string = string.sub(capture_string, 2)
+
+ local best
+ local best_score
+
+ for maybe_match in M.iter_group_results(bufnr, query_group) do
+ local match = utils.get_at_path(maybe_match, capture_string)
+
+ if match and filter_predicate(match) then
+ local current_score = scoring_function(match)
+ if not best then
+ best = match
+ best_score = current_score
+ end
+ if current_score > best_score then
+ best = match
+ best_score = current_score
+ end
+ end
+ end
+ return best
+end
+
-- Iterates matches from a query file.
-- @param bufnr the buffer
-- @param query_group the query file to use
diff --git a/lua/nvim-treesitter/textobjects.lua b/lua/nvim-treesitter/textobjects.lua
index af2c1fff8..0646a8ef0 100644
--- a/lua/nvim-treesitter/textobjects.lua
+++ b/lua/nvim-treesitter/textobjects.lua
@@ -145,56 +145,46 @@ end
function M.next_textobject(node, query_string, same_parent, bufnr)
local node = node or ts_utils.get_node_at_cursor()
+ if not node then return end
+ local _, _, node_end = node:end_()
local bufnr = bufnr or api.nvim_get_current_buf()
- local matches = queries.get_capture_matches(bufnr, query_string, 'textobjects')
- local _, _ , node_end = node:end_()
- local next_node
- local next_node_start
-
- for _, m in pairs(matches) do
- local _, _, other_end = m.node:start()
- if other_end > node_end then
- if not same_parent or node:parent() == m.node:parent() then
- if not next_node then
- next_node = m
- _, _, next_node_start = next_node.node:start()
- end
- if other_end < next_node_start then
- next_node = m
- _, _, next_node_start = next_node.node:start()
- end
- end
- end
- end
+ local next_node = queries.find_best_match(bufnr,
+ query_string,
+ 'textobjects',
+ function(match)
+ if not same_parent or node:parent() == match.node:parent() then
+ local _, _, start = match.node:start()
+ return start > node_end
+ end
+ end,
+ function(match)
+ local _, _, node_start = match.node:start()
+ return -node_start
+ end)
return next_node and next_node.node
end
function M.previous_textobject(node, query_string, same_parent, bufnr)
local node = node or ts_utils.get_node_at_cursor()
+ if not node then return end
+ local _, _, node_start = node:start()
local bufnr = bufnr or api.nvim_get_current_buf()
- local matches = queries.get_capture_matches(bufnr, query_string, 'textobjects')
- local _, _ , node_start = node:start()
- local previous_node
- local previous_node_end
-
- for _, m in pairs(matches) do
- local _, _, other_end = m.node:end_()
- if other_end < node_start then
- if not same_parent or node:parent() == m.node:parent() then
- if not previous_node then
- previous_node = m
- _, _, previous_node_end = previous_node.node:end_()
- end
- if other_end > previous_node_end then
- previous_node = m
- _, _, previous_node_end = previous_node.node:end_()
- end
- end
- end
- end
+ local previous_node = queries.find_best_match(bufnr,
+ query_string,
+ 'textobjects',
+ function(match)
+ if not same_parent or node:parent() == match.node:parent() then
+ local _, _, end_ = match.node:end_()
+ return end_ < node_start
+ end
+ end,
+ function(match)
+ local _, _, node_end = match.node:end_()
+ return node_end
+ end)
return previous_node and previous_node.node
end