diff options
| -rw-r--r-- | lua/nvim-treesitter/configs.lua | 131 | ||||
| -rw-r--r-- | lua/nvim-treesitter/info.lua | 2 | ||||
| -rw-r--r-- | lua/nvim-treesitter/locals.lua | 4 | ||||
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 9 | ||||
| -rw-r--r-- | lua/nvim-treesitter/refactor/highlight_definitions.lua | 93 | ||||
| -rw-r--r-- | lua/nvim-treesitter/ts_utils.lua | 51 | ||||
| -rw-r--r-- | lua/nvim-treesitter/utils.lua | 17 | ||||
| -rw-r--r-- | queries/javascript/locals.scm | 3 |
8 files changed, 274 insertions, 36 deletions
diff --git a/lua/nvim-treesitter/configs.lua b/lua/nvim-treesitter/configs.lua index 13eb93675..30c2114ce 100644 --- a/lua/nvim-treesitter/configs.lua +++ b/lua/nvim-treesitter/configs.lua @@ -2,6 +2,7 @@ local api = vim.api local queries = require'nvim-treesitter.query' local parsers = require'nvim-treesitter.parsers' +local utils = require'nvim-treesitter.utils' -- @enable can be true or false -- @disable is a list of languages, only relevant if enable is true @@ -12,9 +13,7 @@ local config = { highlight = { enable = false, disable = {}, - is_supported = function(lang) - return queries.get_query(lang, 'highlights') ~= nil - end + is_supported = queries.has_highlights }, incremental_selection = { enable = false, @@ -25,9 +24,14 @@ local config = { scope_incremental="grc", node_decremental="grm" }, - is_supported = function(lang) - return queries.get_query(lang, 'locals') - end + is_supported = queries.has_locals + }, + refactor = { + highlight_definitions = { + enable = false, + disable = {}, + is_supported = queries.has_locals + } } }, ensure_installed = nil @@ -38,7 +42,8 @@ local M = {} local function enable_module(mod, bufnr, lang) local bufnr = bufnr or api.nvim_get_current_buf() local lang = lang or parsers.ft_to_lang(api.nvim_buf_get_option(bufnr, 'ft')) - if not parsers.list[lang] or not config.modules[mod] then + + if not parsers.list[lang] or not M.get_module(mod) then return end @@ -47,22 +52,26 @@ local function enable_module(mod, bufnr, lang) end local function enable_mod_conf_autocmd(mod, lang) - if not config.modules[mod] or M.is_enabled(mod, lang) then return end + local config_mod = M.get_module(mod) + + if not config_mod or M.is_enabled(mod, lang) then return end local cmd = string.format("lua require'nvim-treesitter.%s'.attach()", mod) for _, ft in pairs(parsers.lang_to_ft(lang)) do api.nvim_command(string.format("autocmd NvimTreesitter FileType %s %s", ft, cmd)) end - for i, parser in pairs(config.modules[mod].disable) do + for i, parser in pairs(config_mod.disable) do if parser == lang then - table.remove(config.modules[mod].disable, i) + table.remove(config_mod.disable, i) break end end end local function enable_all(mod, lang) - if not config.modules[mod] then return end + local config_mod = M.get_module(mod) + + if not config_mod then return end for _, bufnr in pairs(api.nvim_list_bufs()) do local ft = api.nvim_buf_get_option(bufnr, 'ft') @@ -81,7 +90,7 @@ local function enable_all(mod, lang) end end end - config.modules[mod].enable = true + config_mod.enable = true end local function disable_module(mod, bufnr, lang) @@ -91,7 +100,7 @@ local function disable_module(mod, bufnr, lang) return end - if not parsers.list[lang] or not config.modules[mod] then + if not parsers.list[lang] or not M.get_module(mod) then return end @@ -100,14 +109,16 @@ local function disable_module(mod, bufnr, lang) end local function disable_mod_conf_autocmd(mod, lang) - if not config.modules[mod] or not M.is_enabled(mod, lang) then return end + local config_mod = M.get_module(mod) + + if not config_mod or not M.is_enabled(mod, lang) then return end local cmd = string.format("lua require'nvim-treesitter.%s'.attach()", mod) -- TODO(kyazdani): detach the correct autocmd... doesn't work when using %s, cmd for _, ft in pairs(parsers.lang_to_ft(lang)) do api.nvim_command(string.format("autocmd! NvimTreesitter FileType %s", ft)) end - table.insert(config.modules[mod].disable, lang) + table.insert(config_mod.disable, lang) end local function disable_all(mod, lang) @@ -123,7 +134,30 @@ local function disable_all(mod, lang) for _, lang in pairs(parsers.available_parsers()) do disable_mod_conf_autocmd(mod, lang) end - config.modules[mod].enable = false + + local config_mod = M.get_module(mod) + + if config_mod then + config_mod.enable = false + end + end +end + +-- Recurses trough all modules including submodules +-- @param accumulator function called for each module +-- @param root root configuration table to start at +-- @param path prefix path +local function recurse_modules(accumulator, root, path) + local root = root or config.modules + + for name, module in pairs(root) do + local new_path = path and (path..'.'..name) or name + + if M.is_module(module) then + accumulator(name, module, new_path) + elseif type(module) == 'table' then + recurse_modules(accumulator, module, new_path) + end end end @@ -169,7 +203,7 @@ function M.is_enabled(mod, lang) return false end - local module_config = config.modules[mod] + local module_config = M.get_module(mod) if not module_config then return false end if not module_config.enable or not module_config.is_supported(lang) then @@ -188,19 +222,7 @@ function M.setup(user_data) for mod, data in pairs(user_data) do if config.modules[mod] then - if type(data.enable) == 'boolean' then - config.modules[mod].enable = data.enable - end - if type(data.disable) == 'table' then - config.modules[mod].disable = data.disable - end - if config.modules[mod].keymaps and type(data.keymaps) == 'table' then - for f, map in pairs(data.keymaps) do - if config.modules[mod].keymaps[f] then - config.modules[mod].keymaps[f] = map - end - end - end + M.setup_module(config.modules[mod], data) elseif mod == 'ensure_installed' then config.ensure_installed = data require'nvim-treesitter.install'.ensure_installed(data) @@ -208,12 +230,55 @@ function M.setup(user_data) end end +--- Sets up a single module or all submodules of a group +-- @param mod the module or group of modules +-- @param data user defined configuration for the module +function M.setup_module(mod, data) + if M.is_module(mod) then + if type(data.enable) == 'boolean' then + mod.enable = data.enable + end + if type(data.disable) == 'table' then + mod.disable = data.disable + end + if mod.keymaps and type(data.keymaps) == 'table' then + for f, map in pairs(data.keymaps) do + if mod.keymaps[f] then + mod.keymaps[f] = map + end + end + end + elseif type(data) == 'table' and type(mod) == 'table' then + for key, value in pairs(data) do + M.setup_module(mod[key], value) + end + end +end + function M.available_modules() - return vim.tbl_keys(config.modules) + local modules = {} + + recurse_modules(function(_, _, path) + table.insert(modules, path) + end) + + return modules +end + +-- Gets a module config by path +-- @param mod_path path to the module +-- @returns the module or nil +function M.get_module(mod_path) + local mod = utils.get_at_path(config.modules, mod_path) + + return M.is_module(mod) and mod or nil end -function M.get_module(mod) - return config.modules[mod] +-- Determines whether the provided table is a module. +-- A module should contain an 'is_supported' function. +-- @param mod the module table +function M.is_module(mod) + return type(mod) == 'table' and type(mod.is_supported) == 'function' end return M diff --git a/lua/nvim-treesitter/info.lua b/lua/nvim-treesitter/info.lua index c768595c4..0c99f3ff2 100644 --- a/lua/nvim-treesitter/info.lua +++ b/lua/nvim-treesitter/info.lua @@ -65,7 +65,7 @@ local function print_info_modules(sorted_languages) end local function module_info(mod) - if mod and not configs.get_config()[mod] then return end + if mod and not configs.get_module(mod) then return end local parserlist = parsers.available_parsers() table.sort(parserlist, function(a, b) return #a > #b end) diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua index 8c7d400eb..fae8195ea 100644 --- a/lua/nvim-treesitter/locals.lua +++ b/lua/nvim-treesitter/locals.lua @@ -11,7 +11,7 @@ local M = { locals = {} } -function M.collect_locals(bufnr) +function M.collect_locals(bufnr, root) local lang = parsers.ft_to_lang(api.nvim_buf_get_option(bufnr, "ft")) if not lang then return end @@ -21,7 +21,7 @@ function M.collect_locals(bufnr) local parser = parsers.get_parser(bufnr, lang) if not parser then return end - local root = parser:parse():root() + local root = root or parser:parse():root() local start_row, _, end_row, _ = root:range() local locals = {} diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 42e835824..d04e368a6 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -13,6 +13,12 @@ local function read_query_files(filenames) return table.concat(contents, '\n') end +local function get_query_gaurd(query) + return function(lang) + return M.get_query(lang, query) ~= nil + end +end + -- Some treesitter grammars extend others. -- We can use that to import the queries of the base language M.base_language_map = { @@ -21,6 +27,9 @@ M.base_language_map = { tsx = {'typescript', 'javascript'}, } +M.has_locals = get_query_gaurd('locals') +M.has_highlights = get_query_gaurd('highlights') + function M.get_query(lang, query_name) local query_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', lang, query_name), true) local query_string = '' diff --git a/lua/nvim-treesitter/refactor/highlight_definitions.lua b/lua/nvim-treesitter/refactor/highlight_definitions.lua new file mode 100644 index 000000000..ef415bb77 --- /dev/null +++ b/lua/nvim-treesitter/refactor/highlight_definitions.lua @@ -0,0 +1,93 @@ +-- This module highlights reference usages and the corresponding +-- definition on cursor hold. + +local parsers = require'nvim-treesitter.parsers' +local ts_utils = require'nvim-treesitter.ts_utils' +local locals = require'nvim-treesitter.locals' +local api = vim.api +local cmd = api.nvim_command + +local M = {} + +local usage_namespace = api.nvim_create_namespace('nvim-treesitter-usages') + +local function find_usages(node, scope_node) + local usages = {} + local node_text = ts_utils.get_node_text(node)[1] + + if not node_text or #node_text < 1 then return end + + for _, def in ipairs(locals.collect_locals(bufnr, scope_node)) do + if def.reference + and def.reference.node + and ts_utils.get_node_text(def.reference.node)[1] == node_text then + + table.insert(usages, def.reference.node) + end + end + + return usages +end + +function M.highlight_usages(bufnr) + M.clear_usage_highlights(bufnr) + + local node_at_point = ts_utils.get_node_at_cursor() + + if not node_at_point then return end + + local def_node, scope = ts_utils.find_definition(node_at_point, bufnr) + local usages = find_usages(node_at_point, scope) + + for _, usage_node in ipairs(usages) do + local start_row, start_col, _, end_col = usage_node:range() + + if usage_node ~= node_at_point then + api.nvim_buf_add_highlight( + bufnr, + usage_namespace, + 'Visual', + start_row, + start_col, + end_col) + end + end + + if def_node then + local start_row, start_col, _, end_col = def_node:range() + + if def_node ~= node_at_point then + api.nvim_buf_add_highlight( + bufnr, + usage_namespace, + 'Search', + start_row, + start_col, + end_col) + end + end +end + +function M.clear_usage_highlights(bufnr) + api.nvim_buf_clear_namespace(bufnr, usage_namespace, 0, -1) +end + +function M.attach(bufnr) + local bufnr = bufnr or api.nvim_get_current_buf() + + cmd(string.format('augroup NvimTreesitterUsages_%d', bufnr)) + cmd 'au!' + cmd(string.format([[autocmd CursorHold <buffer=%d> lua require'nvim-treesitter.refactor.highlight_definitions'.highlight_usages(%d)]], bufnr, bufnr)) + cmd(string.format([[autocmd CursorMoved <buffer=%d> lua require'nvim-treesitter.refactor.highlight_definitions'.clear_usage_highlights(%d)]], bufnr, bufnr)) + cmd(string.format([[autocmd InsertEnter <buffer=%d> lua require'nvim-treesitter.refactor.highlight_definitions'.clear_usage_highlights(%d)]], bufnr, bufnr)) + cmd 'augroup END' +end + +function M.detach(bufnr) + M.clear_usage_highlights(bufnr) + cmd(string.format('autocmd! NvimTreesitterUsages_%d CursorHold', bufnr)) + cmd(string.format('autocmd! NvimTreesitterUsages_%d CursorMoved', bufnr)) + cmd(string.format('autocmd! NvimTreesitterUsages_%d InsertEnter', bufnr)) +end + +return M diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua index 687a3c539..363d26374 100644 --- a/lua/nvim-treesitter/ts_utils.lua +++ b/lua/nvim-treesitter/ts_utils.lua @@ -212,4 +212,55 @@ function M.get_node_at_cursor(winnr) return root:named_descendant_for_range(cursor[1]-1,cursor[2],cursor[1]-1,cursor[2]) end +-- Finds the definition node and it's scope node of a node +-- @param node starting node +-- @param bufnr buffer +-- @returns the definition node and the definition nodes scope node +function M.find_definition(node, bufnr) + local bufnr = bufnr or api.nvim_get_current_buf() + local node_text = M.get_node_text(node)[1] + local current_scope = M.containing_scope(node) + local _, _, node_start = node:start() + + -- If a scope wasn't found then use the root node + if current_scope == node then + current_scope = parsers.get_parser(bufnr).tree:root() + end + + while current_scope ~= nil and current_scope ~= node do + for _, def in ipairs(locals.collect_locals(bufnr, current_scope)) do + if def.definition then + for _, def_node in ipairs(M.get_local_nodes(def.definition)) do + local _, _, def_start = def_node:start() + + if M.get_node_text(def_node)[1] == node_text and def_start < node_start then + return def_node, current_scope + end + end + end + end + + current_scope = M.containing_scope(current_scope:parent()) + end + + return nil, nil +end + +-- Gets all nodes from a local list result. +-- @param local_def the local list result +-- @returns a list of nodes +function M.get_local_nodes(local_def) + if local_def.node then + return { local_def.node } + else + local result = {} + + for _, def in pairs(local_def) do + vim.list_extend(result, M.get_local_nodes(def)) + end + + return result + end +end + return M diff --git a/lua/nvim-treesitter/utils.lua b/lua/nvim-treesitter/utils.lua index 6fa159817..41833f8cf 100644 --- a/lua/nvim-treesitter/utils.lua +++ b/lua/nvim-treesitter/utils.lua @@ -45,4 +45,21 @@ function M.get_cache_dir() return nil, 'Invalid cache rights, $XDG_CACHE_HOME or /tmp should be read/write' end +--- Gets a property at path +-- @param tbl the table to access +-- @param path the '.' seperated path +-- @returns the value at path or nil +function M.get_at_path(tbl, path) + local segments = vim.split(path, '.', true) + local result = tbl + + for _, segment in ipairs(segments) do + if type(result) == 'table' then + result = result[segment] + end + end + + return result +end + return M diff --git a/queries/javascript/locals.scm b/queries/javascript/locals.scm index 165adfed9..d56000d5a 100644 --- a/queries/javascript/locals.scm +++ b/queries/javascript/locals.scm @@ -28,6 +28,9 @@ (variable_declarator name: (identifier) @definition) +(import_specifier + (identifier) @definition) + ; References ;------------ |
