summaryrefslogtreecommitdiffstats
path: root/scripts/check-queries.lua
blob: a84df85c4c47d164ca5d71d2d8d1e1ef27778b19 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env -S nvim -l

-- Drop-in replacement for print() that writes straight to stdout via io.write,
-- so output (including the trailing "\n") is identical on every platform.
-- Calling it without an argument (or with nil/false) emits an empty line.
local function io_print(text)
  io.write(text or "", "\n")
end

--- Collect the documented capture names from CONTRIBUTING.md, grouped by query type.
-- A line starting with "### " begins a new query type (the heading text,
-- lower-cased). Each later line starting with "@" adds its first space-delimited
-- word, without the "@", to that type's list. Keys are expected to match the
-- query group names checked by do_check. TODO confirm against CONTRIBUTING.md.
-- Every parser name known to nvim-treesitter is also accepted as an injection
-- capture.
-- @treturn table map from query type to a list of documented capture names
local function extract_captures()
  local lines = vim.fn.readfile "CONTRIBUTING.md"
  local captures = {}
  local current_query

  for _, line in ipairs(lines) do
    if vim.startswith(line, "### ") then
      current_query = vim.fn.tolower(line:sub(5))
    elseif vim.startswith(line, "@") and current_query then
      local list = captures[current_query]
      if not list then
        list = {}
        captures[current_query] = list
      end
      list[#list + 1] = vim.split(line:sub(2), " ", true)[1]
    end
  end

  -- Injection queries capture language names, so every parser name is valid.
  -- Create the list when CONTRIBUTING.md has no "### Injections" section.
  -- Otherwise inserting into a nil list would raise.
  local injections = captures.injections
  if not injections then
    injections = {}
    captures.injections = injections
  end
  for lang in pairs(require("nvim-treesitter.parsers").list) do
    injections[#injections + 1] = lang
  end

  return captures
end

--- Return true if `predicate` holds for at least one value of `list`.
-- A nil `list` is treated as empty (returns false) rather than raising.
-- This matters for query types that have no documented captures.
-- @tparam table|nil list values to test
-- @tparam function predicate called with each value; a truthy result stops the scan
-- @treturn boolean
local function list_any(list, predicate)
  if list == nil then
    return false
  end
  for _, v in pairs(list) do
    if predicate(v) then
      return true
    end
  end
  return false
end

--- Load every built-in query group for every installed parser and validate its captures.
-- Each query load is timed and logged. The following are printed and remembered:
--   * a query that fails to load, and
--   * a capture that neither starts with "_" (helper captures) nor starts with a
--     capture documented for that query type.
-- Once all checks have run, the last such message is raised with error().
-- @treturn table list of { duration = <vim.loop.hrtime delta>, lang = string, query_type = string }
-- @raise string the last error message, if any check failed
local function do_check()
  local timings = {}
  local parsers = require("nvim-treesitter.info").installed_parsers()
  local queries = require "nvim-treesitter.query"
  local query_types = queries.built_in_query_groups

  local captures = extract_captures()
  local last_error

  io_print "::group::Check parsers"

  for _, lang in pairs(parsers) do
    for _, query_type in pairs(query_types) do
      local before = vim.loop.hrtime()
      local ok, query = pcall(queries.get_query, lang, query_type)
      local duration = vim.loop.hrtime() - before
      timings[#timings + 1] = { duration = duration, lang = lang, query_type = query_type }
      io_print("Checking " .. lang .. " " .. query_type .. string.format(" (%.02fms)", duration * 1e-6))
      if not ok then
        -- pcall's error value is not guaranteed to be a string.
        local err_msg = lang .. " (" .. query_type .. "): " .. tostring(query)
        io_print(err_msg)
        last_error = err_msg
      elseif query then
        -- A query type without a documentation section only accepts helper captures.
        local documented = captures[query_type] or {}
        for _, capture in ipairs(query.captures) do
          local is_valid = vim.startswith(capture, "_") -- Helpers.
            or list_any(documented, function(documented_capture)
              return vim.startswith(capture, documented_capture)
            end)
          if not is_valid then
            local err_msg = string.format("(x) Invalid capture @%s in %s for %s.", capture, query_type, lang)
            io_print(err_msg)
            last_error = err_msg
          end
        end
      end
    end
  end

  io_print "::endgroup::"

  if last_error then
    io_print()
    io_print "Last error: "
    error(last_error)
  end
  return timings
end

-- Run the query checks now; their outcome is reported after the installation check.
local ok, result = pcall(do_check)
-- Comma-separated parser names (from the environment) that may be missing on CI.
local allowed_to_fail = vim.split(vim.env.ALLOWED_INSTALLATION_FAILURES or "", ",", true)

local ts_parsers = require "nvim-treesitter.parsers"
local parser_configs = ts_parsers.get_parser_configs()
-- Visit parsers in a stable order so CI logs are reproducible.
local parser_names = vim.tbl_keys(parser_configs)
table.sort(parser_names)

local missing_required = false
for _, name in ipairs(parser_names) do
  if not ts_parsers.has_parser(name) then
    local install_info = parser_configs[name].install_info or {}
    -- On CI all parsers that can be installed from C files should be installed
    if
      vim.env.CI
      and not install_info.requires_generate_from_grammar
      and not vim.tbl_contains(allowed_to_fail, name)
    then
      io_print("Error: parser for " .. name .. " is not installed")
      missing_required = true
    else
      io_print("Warning: parser for " .. name .. " is not installed")
    end
  end
end

-- Report every missing parser before exiting with an error status.
if missing_required then
  vim.cmd "cq"
end

-- Print the timing table and exit 0 on success; otherwise print the error and exit non-zero.
if ok then
  io_print "::group::Timings"
  -- Sort ascending by duration: the slowest query is printed last, ranked 1.
  table.sort(result, function(a, b)
    return a.duration < b.duration
  end)
  local count = #result
  for i, val in ipairs(result) do
    io_print(string.format("%i. %.02fms %s %s", count - i + 1, val.duration * 1e-6, val.lang, val.query_type))
  end
  io_print "::endgroup::"
  io_print "Check successful!"
  vim.cmd "q"
else
  io_print "Check failed:"
  -- pcall may yield a non-string error object, which io.write would reject.
  io_print(tostring(result))
  io_print "\n"
  vim.cmd "cq"
end