diff options
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 30 | ||||
| -rw-r--r-- | lua/nvim-treesitter/textobjects.lua | 70 |
2 files changed, 60 insertions, 40 deletions
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index c630b366c..034df223d 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -212,6 +212,36 @@ function M.get_capture_matches(bufnr, capture_string, query_group) return matches end +function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function) + if not string.sub(capture_string, 1,2) == '@' then + api.nvim_err_writeln('capture_string must start with "@"') + return + end + + --remove leading "@" + capture_string = string.sub(capture_string, 2) + + local best + local best_score + + for maybe_match in M.iter_group_results(bufnr, query_group) do + local match = utils.get_at_path(maybe_match, capture_string) + + if match and filter_predicate(match) then + local current_score = scoring_function(match) + if not best then + best = match + best_score = current_score + end + if current_score > best_score then + best = match + best_score = current_score + end + end + end + return best +end + -- Iterates matches from a query file. -- @param bufnr the buffer -- @param query_group the query file to use diff --git a/lua/nvim-treesitter/textobjects.lua b/lua/nvim-treesitter/textobjects.lua index af2c1fff8..0646a8ef0 100644 --- a/lua/nvim-treesitter/textobjects.lua +++ b/lua/nvim-treesitter/textobjects.lua @@ -145,56 +145,46 @@ end function M.next_textobject(node, query_string, same_parent, bufnr) local node = node or ts_utils.get_node_at_cursor() + if not node then return end + local _, _, node_end = node:end_() local bufnr = bufnr or api.nvim_get_current_buf() - local matches = queries.get_capture_matches(bufnr, query_string, 'textobjects') - local _, _ , node_end = node:end_() - local next_node - local next_node_start - - for _, m in pairs(matches) do - local _, _, other_end = m.node:start() - if other_end > node_end then - if not same_parent or node:parent() == m.node:parent() then - if not next_node then - next_node = m - _, _, next_node_start = next_node.node:start() - end - if other_end < next_node_start then - next_node = m - _, _, next_node_start = next_node.node:start() - end - end - end - end + local next_node = queries.find_best_match(bufnr, + query_string, + 'textobjects', + function(match) + if not same_parent or node:parent() == match.node:parent() then + local _, _, start = match.node:start() + return start > node_end + end + end, + function(match) + local _, _, node_start = match.node:start() + return -node_start + end) return next_node and next_node.node end function M.previous_textobject(node, query_string, same_parent, bufnr) local node = node or ts_utils.get_node_at_cursor() + if not node then return end + local _, _, node_start = node:start() local bufnr = bufnr or api.nvim_get_current_buf() - local matches = queries.get_capture_matches(bufnr, query_string, 'textobjects') - local _, _ , node_start = node:start() - local previous_node - local previous_node_end - - for _, m in pairs(matches) do - local _, _, other_end = m.node:end_() - if other_end < node_start then - if not same_parent or node:parent() == m.node:parent() then - if not previous_node then - previous_node = m - _, _, previous_node_end = previous_node.node:end_() - end - if other_end > previous_node_end then - previous_node = m - _, _, previous_node_end = previous_node.node:end_() - end - end - end - end + local previous_node = queries.find_best_match(bufnr, + query_string, + 'textobjects', + function(match) + if not same_parent or node:parent() == match.node:parent() then + local _, _, end_ = match.node:end_() + return end_ < node_start + end + end, + function(match) + local _, _, node_end = match.node:end_() + return node_end + end) return previous_node and previous_node.node end |
