From 69cabc69be49bc37c8a9bbb7def1415504b489b2 Mon Sep 17 00:00:00 2001 From: Stephan Seitz Date: Sun, 21 Jun 2020 20:38:00 +0200 Subject: Add textobjects module --- lua/nvim-treesitter/configs.lua | 37 +++++++++-- lua/nvim-treesitter/incremental_selection.lua | 38 +++++------ lua/nvim-treesitter/query.lua | 1 + lua/nvim-treesitter/textobjects.lua | 96 +++++++++++++++++++++++++++ lua/nvim-treesitter/ts_utils.lua | 55 ++++++++++++++- lua/nvim-treesitter/utils.lua | 2 +- 6 files changed, 203 insertions(+), 26 deletions(-) create mode 100644 lua/nvim-treesitter/textobjects.lua (limited to 'lua') diff --git a/lua/nvim-treesitter/configs.lua b/lua/nvim-treesitter/configs.lua index abed55a7d..f09c8ac99 100644 --- a/lua/nvim-treesitter/configs.lua +++ b/lua/nvim-treesitter/configs.lua @@ -4,6 +4,19 @@ local queries = require'nvim-treesitter.query' local parsers = require'nvim-treesitter.parsers' local utils = require'nvim-treesitter.utils' +local M = {} + +local function has_some_textobject_mapping(lang) + for _, v in pairs(M.get_module('textobjects').keymaps) do + if type(v) == 'table' then + if v[lang] then + return true + end + end + end + return false +end + local config = { modules = {}, ensure_installed = nil @@ -65,13 +78,22 @@ local builtin_modules = { list_definitions = "gnD" } } + }, + textobjects = { + module_path = 'nvim-treesitter.textobjects', + enable = false, + disable = {}, + is_supported = function(lang) + return has_some_textobject_mapping(lang) or queries.has_textobjects(lang) + end, + keymaps = { + inverse_mappings = true + } } } local special_config_keys = {'enable', 'disable', 'keymaps'} -local M = {} - -- Resolves a module by requiring the `module_path` or using the module definition. local function resolve_module(mod_name) local config_mod = M.get_module(mod_name) @@ -279,12 +301,17 @@ function M.setup_module(mod, data, mod_name) mod.disable = data.disable end if mod.keymaps and type(data.keymaps) == 'table' then - for f, map in pairs(data.keymaps) do - if mod.keymaps[f] then - mod.keymaps[f] = map + if mod.keymaps.inverse_mappings then + mod.keymaps = data.keymaps + else + for f, map in pairs(data.keymaps) do + if mod.keymaps[f] then + mod.keymaps[f] = map + end end end end + for k, v in pairs(data) do -- Just copy all non-special configuration keys if not vim.tbl_contains(special_config_keys, k) then diff --git a/lua/nvim-treesitter/incremental_selection.lua b/lua/nvim-treesitter/incremental_selection.lua index 3b76b25f2..ecb02330d 100644 --- a/lua/nvim-treesitter/incremental_selection.lua +++ b/lua/nvim-treesitter/incremental_selection.lua @@ -8,32 +8,32 @@ local M = {} local selections = {} -local function update_selection(buf, node) - local start_row, start_col, end_row, end_col = node:range() - - if end_row == vim.fn.line('$') then - end_col = #vim.fn.getline('$') - end - - vim.fn.setpos(".", { buf, start_row+1, start_col+1, 0 }) - vim.fn.nvim_exec("normal v", false) - vim.fn.setpos(".", { buf, end_row+1, end_col+1, 0 }) -end - function M.init_selection() local buf = api.nvim_get_current_buf() local node = ts_utils.get_node_at_cursor() selections[buf] = { [1] = node } - update_selection(buf, node) + ts_utils.update_selection(buf, node) +end + +-- moves 0-based node position by one character +local function inclusive_pos_to_exclusive(row, col) + local line = vim.fn.getline(row + 1) + + -- move by one character changes row? + if #line == col + 1 then + return row + 1, 0 + else + return row, col + 1 + end end local function visual_selection_range() local _, csrow, cscol, _ = unpack(vim.fn.getpos("'<")) local _, cerow, cecol, _ = unpack(vim.fn.getpos("'>")) - if csrow < cerow then - return csrow-1, cscol-1, cerow-1, cecol-1 + if csrow < cerow or (csrow == cerow and cscol <= cecol) then + return csrow-1, cscol-1, inclusive_pos_to_exclusive(cerow-1, cecol-1) else - return cerow-1, cecol-1, csrow-1, cscol-1 + return cerow-1, cecol-1, inclusive_pos_to_exclusive(csrow-1, cscol-1) end end @@ -53,7 +53,7 @@ local function select_incremental(get_parent) local csrow, cscol, cerow, cecol = visual_selection_range() local root = parsers.get_parser().tree:root() local node = root:named_descendant_for_range(csrow, cscol, cerow, cecol) - update_selection(buf, node) + ts_utils.update_selection(buf, node) selections[buf] = { [1] = node } return end @@ -65,7 +65,7 @@ local function select_incremental(get_parent) table.insert(nodes, node) end - update_selection(buf, node) + ts_utils.update_selection(buf, node) end end @@ -84,7 +84,7 @@ function M.node_decremental() table.remove(selections[buf]) local node = nodes[#nodes] - update_selection(buf, node) + ts_utils.update_selection(buf, node) end function M.attach(bufnr) diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 72cead9dd..7316c79cb 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -35,6 +35,7 @@ M.query_extensions = { } M.has_locals = get_query_guard('locals') +M.has_textobjects = get_query_guard('textobjects') M.has_highlights = get_query_guard('highlights') function M.get_query(lang, query_name) diff --git a/lua/nvim-treesitter/textobjects.lua b/lua/nvim-treesitter/textobjects.lua new file mode 100644 index 000000000..cf81b5963 --- /dev/null +++ b/lua/nvim-treesitter/textobjects.lua @@ -0,0 +1,96 @@ +local api = vim.api +local ts = vim.treesitter + +local configs = require "nvim-treesitter.configs" +local parsers = require "nvim-treesitter.parsers" +local queries = require'nvim-treesitter.query' +local locals = require'nvim-treesitter.locals' +local ts_utils = require'nvim-treesitter.ts_utils' + +local M = {} + +function M.select_textobject(query_string) + local bufnr = vim.api.nvim_get_current_buf() + local ft = api.nvim_buf_get_option(bufnr, "ft") + if not ft then return end + local lang = parsers.ft_to_lang(ft) + + local row, col = unpack(vim.api.nvim_win_get_cursor(0)) + row = row - 1 + + local matches = {} + + if string.match(query_string, '^@.*') then + matches = locals.get_capture_matches(bufnr, query_string, 'textobjects') + else + local parser = parsers.get_parser(bufnr, lang) + local root = parser:parse():root() + local start_row, _, end_row, _ = root:range() + + local nested = {} + local query = ts.parse_query(lang, query_string) + for m in queries.iter_prepared_matches(query, root, bufnr, start_row, end_row) do + for _, n in pairs(m) do + if n.node then + table.insert(matches, n.node) + end + end + end + end + + local match_length + local smallest_range + + for _, m in pairs(matches) do + if ts_utils.is_in_node_range(m, row, col) then + local length = ts_utils.node_length(m) + if not match_length or length < match_length then + smallest_range = m + match_length = length + end + end + end + + if smallest_range then + ts_utils.update_selection(bufnr, smallest_range) + end +end + +function M.attach(bufnr, lang) + local buf = bufnr or api.nvim_get_current_buf() + local config = configs.get_module("textobjects") + local lang = lang or parsers.ft_to_lang(api.nvim_buf_get_option(bufnr, "ft")) + + for mapping, query in pairs(config.keymaps) do + if type(query) == 'table' then + query = query[lang] + elseif not queries.get_query(lang, 'textobjects') then + query = nil + end + if query then + local cmd = ":lua require'nvim-treesitter.textobjects'.select_textobject('"..query.."')" + api.nvim_buf_set_keymap(buf, "o", mapping, cmd, {silent = true}) + api.nvim_buf_set_keymap(buf, "v", mapping, cmd, {silent = true}) + end + end +end + +function M.detach(bufnr) + local buf = bufnr or api.nvim_get_current_buf() + local config = configs.get_module("textobjects") + local lang = parsers.ft_to_lang(api.nvim_buf_get_option(bufnr, "ft")) + + for mapping, query in pairs(config.keymaps) do + if type(query) == 'table' then + query = query[lang] + elseif not queries.get_query(lang, 'textobjects') then + query = nil + end + if query then + api.nvim_buf_del_keymap(buf, "o", mapping) + api.nvim_buf_del_keymap(buf, "v", mapping) + end + end +end + +return M diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua index 39c75365d..f9c3de104 100644 --- a/lua/nvim-treesitter/ts_utils.lua +++ b/lua/nvim-treesitter/ts_utils.lua @@ -28,7 +28,7 @@ function M.get_node_text(node, bufnr) end end ---- Determines wether a node is the parent of another +--- Determines whether a node is the parent of another -- @param dest the possible parent -- @param source the possible child node function M.is_parent(dest, source) @@ -323,4 +323,57 @@ function M.highlight_range(range, buf, hl_namespace, hl_group) vim.highlight.range(buf, hl_namespace, hl_group, {start_row, start_col}, {end_row, end_col}) end +-- Set visual selection to node +function M.update_selection(buf, node) + local start_row, start_col, end_row, end_col = node:range() + + if end_row == vim.fn.line('$') then + end_col = #vim.fn.getline('$') + end + + -- Convert to 1-based indices + start_row = start_row + 1 + start_col = start_col + 1 + end_row = end_row + 1 + end_col = end_col + 1 + + vim.fn.setpos(".", { buf, start_row, start_col, 0 }) + vim.fn.nvim_exec("normal v", false) + + -- Convert exclusive end position to inclusive + if end_col == 1 then + vim.fn.setpos(".", { buf, end_row - 1, -1, 0 }) + else + vim.fn.setpos(".", { buf, end_row, end_col - 1, 0 }) + end +end + +-- Byte length of node range +function M.node_length(node) + local _, _, start_byte = node:start() + local _, _, end_byte = node:end_() + return end_byte - start_byte +end + +--- Determines whether (line, col) position is in node range +-- @param node Node defining the range +-- @param line A line (0-based) +-- @param col A column (0-based) +function M.is_in_node_range(node, line, col) + local start_line, start_col, end_line, end_col = node:range() + if line >= start_line and line <= end_line then + if line == start_line and line == end_line then + return col >= start_col and col < end_col + elseif line == start_line then + return col >= start_col + elseif line == end_line then + return col < end_col + else + return true + end + else + return false + end +end + return M diff --git a/lua/nvim-treesitter/utils.lua b/lua/nvim-treesitter/utils.lua index c87cee2d8..416326e48 100644 --- a/lua/nvim-treesitter/utils.lua +++ b/lua/nvim-treesitter/utils.lua @@ -46,7 +46,7 @@ end -- Gets a property at path -- @param tbl the table to access --- @param path the '.' seperated path +-- @param path the '.' separated path -- @returns the value at path or nil function M.get_at_path(tbl, path) local segments = vim.split(path, '.', true) -- cgit v1.2.3-70-g09d2