aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/nvim-treesitter/incremental_selection.lua39
1 files changed, 35 insertions, 4 deletions
diff --git a/lua/nvim-treesitter/incremental_selection.lua b/lua/nvim-treesitter/incremental_selection.lua
index fa833f90a..ac2c98386 100644
--- a/lua/nvim-treesitter/incremental_selection.lua
+++ b/lua/nvim-treesitter/incremental_selection.lua
@@ -22,11 +22,32 @@ end
local function visual_selection_range()
local _, csrow, cscol, _ = unpack(vim.fn.getpos("'<"))
local _, cerow, cecol, _ = unpack(vim.fn.getpos("'>"))
+
+ local start_row, start_col, end_row, end_col
+
if csrow < cerow or (csrow == cerow and cscol <= cecol) then
- return csrow - 1, cscol - 1, cerow - 1, cecol
+ start_row = csrow - 1
+ start_col = cscol - 1
+ end_row = cerow - 1
+ end_col = cecol
else
- return cerow - 1, cecol - 1, csrow - 1, cscol
+ start_row = cerow - 1
+ start_col = cecol - 1
+ end_row = csrow - 1
+ end_col = cscol
+ end
+
+ -- The last char in ts is equivalent to the EOF in another line.
+ local last_row = vim.fn.line("$")
+ local last_col = vim.fn.col({last_row, "$"})
+ last_row = last_row - 1
+ last_col = last_col - 1
+ if end_row == last_row and end_col == last_col then
+ end_row = end_row + 1
+ end_col = 0
end
+
+ return start_row, start_col, end_row, end_col
end
local function range_matches(node)
@@ -57,8 +78,18 @@ local function select_incremental(get_parent)
-- Find a node that changes the current selection.
local node = nodes[#nodes]
while true do
- node = get_parent(node)
- if not node then return end
+ local parent = get_parent(node)
+ if not parent or parent == node then
+ -- Keep searching in the main tree
+ -- TODO: we should search on the parent tree of the current node.
+ local root = parsers.get_parser():parse()[1]:root()
+ parent = root:named_descendant_for_range(csrow, cscol, cerow, cecol)
+ if not parent or parent == node then
+ ts_utils.update_selection(buf, node)
+ return
+ end
+ end
+ node = parent
local srow, scol, erow, ecol = node:range()
local same_range = (
srow == csrow