Skip to content

Commit

Permalink
feat(treesitter): allow custom build function
Browse files Browse the repository at this point in the history
See #68
  • Loading branch information
rcarriga committed Sep 10, 2022
1 parent a52052c commit 3f8246e
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 115 deletions.
68 changes: 68 additions & 0 deletions lua/neotest/lib/positions/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,72 @@ M.merge = function(orig, new)
return orig
end

local function build_structure(positions, namespaces, opts)
---@type neotest.Position
local parent = table.remove(positions, 1)
if not parent then
return nil
end
parent.id = parent.type == "file" and parent.path or opts.position_id(parent, namespaces)
local current_level = { parent }
local child_namespaces = vim.list_extend({}, namespaces)
if parent.type == "namespace" or (opts.nested_tests and parent.type == "test") then
child_namespaces[#child_namespaces + 1] = parent
end
while true do
local next_pos = positions[1]
if not next_pos or not M.contains(parent, next_pos) then
-- Don't preserve empty namespaces
if #current_level == 1 and parent.type == "namespace" then
return nil
end
if opts.require_namespaces and parent.type == "test" and #namespaces == 0 then
return nil
end
return current_level
end

local sub_tree = build_structure(positions, child_namespaces, opts)
if opts.nested_tests or parent.type ~= "test" then
current_level[#current_level + 1] = sub_tree
end
end
end

---@class neotest.positions.ParseOptions
---@field nested_tests boolean Allow nested tests
---@field require_namespaces boolean Require tests to be within namespaces
---@field position_id fun(position: neotest.Position, parents: neotest.Position[]): string Position ID constructor

--- Convert a flat list of sorted positions to a tree. Positions ID fields can be nil as they will be assigned.
--- NOTE: This mutates the positions given by assigning the `id` field.
---@param positions neotest.Position[]
---@param opts neotest.positions.ParseOptions
---@return neotest.Tree
function M.parse_tree(positions, opts)
opts = vim.tbl_extend("force", {
nested_tests = false, -- Allow nested tests
require_namespaces = false, -- Only allow tests within namespaces
---@param position neotest.Position The position to return an ID for
---@param parents neotest.Position[] Parent positions for the position
position_id = function(position, parents)
return table.concat(
vim.tbl_flatten({
position.path,
vim.tbl_map(function(pos)
return pos.name
end, parents),
position.name,
}),
"::"
)
end,
}, opts or {})
local structure = assert(build_structure(positions, {}, opts))

return Tree.from_list(structure, function(pos)
return pos.id
end)
end

return M
176 changes: 61 additions & 115 deletions lua/neotest/lib/treesitter/init.lua
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
local async = require("neotest.async")
local types = require("neotest.types")
local Tree = types.Tree

local M = {}

local function get_query_type(query, match)
for id, _ in pairs(match) do
local name = query.captures[id]
if name == "test.name" or name == "test.definition" then
return "test"
end
if name == "namespace.name" or name == "namespace.definition" then
return "namespace"
end
local function get_match_type(captured_nodes)
if captured_nodes["test.name"] then
return "test"
end
if captured_nodes["namespace.name"] then
return "namespace"
end
end

local function build_position(file_path, source, captured_nodes)
local match_type = get_match_type(captured_nodes)
if match_type then
---@type string
local name = vim.treesitter.get_node_text(captured_nodes[match_type .. ".name"], source)
local definition = captured_nodes[match_type .. ".definition"]

return {
type = match_type,
path = file_path,
name = name,
range = { definition:range() },
}
end
return nil
end

---@param file_path string
---@param query table
---@param source string
---@param root table
---@param opts neotest.treesitter.ParseOptions
---@return table[]
local function collect(file_path, query, source, root)
local function collect(file_path, query, source, root, opts)
local sep = require("neotest.lib").files.sep
local path_elems = vim.split(file_path, sep, { plain = true })
local nodes = {
Expand All @@ -33,85 +44,31 @@ local function collect(file_path, query, source, root)
range = { root:range() },
},
}
pcall(vim.tbl_add_reverse_lookup, query.captures)
for _, match in query:iter_matches(root, source) do
local type = get_query_type(query, match)
if type then
---@type string
local name = vim.treesitter.get_node_text(match[query.captures[type .. ".name"]], source)
local definition = match[query.captures[type .. ".definition"]]

nodes[#nodes + 1] = {
type = type,
path = file_path,
name = name,
range = { definition:range() },
}
local captured_nodes = {}
for i, capture in ipairs(query.captures) do
captured_nodes[capture] = match[i]
end
end
return nodes
end

---@param pos_a neotest.Position
---@param pos_b neotest.Position
local function contains(pos_a, pos_b)
local a_s_r, a_s_c, a_e_r, a_e_c = unpack(pos_a.range)
local b_s_r, b_s_c, b_e_r, b_e_c = unpack(pos_b.range)
if a_s_r > b_s_r or a_e_r < b_e_r then
return false
end
if a_s_r == b_s_r and a_s_c > b_s_c then
return false
end
if a_e_r == b_e_r and a_e_c < b_e_c then
return false
end
return true
end

---@param positions table[]
---@return table[]? Nested lists to be parsed as a tree object
local function parse_tree(positions, namespaces, opts)
---@type neotest.Position
local parent = table.remove(positions, 1)
if not parent then
return nil
end
parent.id = parent.type == "file" and parent.path or opts.position_id(parent, namespaces)
local current_level = { parent }
local child_namespaces = vim.list_extend({}, namespaces)
if parent.type == "namespace" or (opts.nested_tests and parent.type == "test") then
child_namespaces[#child_namespaces + 1] = parent
end
while true do
local next_pos = positions[1]
if not next_pos or not contains(parent, next_pos) then
-- Don't preserve empty namespaces
if #current_level == 1 and parent.type == "namespace" then
return nil
local res = opts.build_position(file_path, source, captured_nodes)
if res then
if res[1] then
for _, pos in ipairs(res) do
nodes[#nodes + 1] = pos
end
else
nodes[#nodes + 1] = res
end
if opts.require_namespaces and parent.type == "test" and #namespaces == 0 then
return nil
end
return current_level
end

local sub_tree = parse_tree(positions, child_namespaces, opts)
if opts.nested_tests or parent.type ~= "test" then
current_level[#current_level + 1] = sub_tree
end
end
end

---@param pos neotest.Position
---@return string
local function position_key(pos)
return pos.id
return nodes
end

--- Injections take a long time to run and are not needed.
--- This does only the required parsing
local function fast_parse(lang_tree)
--- Replaces `LanguageTree:parse`
--- https://github.com/neovim/neovim/blob/master/runtime/lua/vim/treesitter/languagetree.lua
function M.fast_parse(lang_tree)
if lang_tree._valid then
return lang_tree._trees
end
Expand All @@ -121,13 +78,27 @@ local function fast_parse(lang_tree)
return parser:parse(old_trees[1], lang_tree._source)
end

---@class neotest.treesitter.ParseOptions : neotest.positions.ParseOptions
---@field fast boolean Use faster parsing (Should be unchanged unless injections are needed)
local ParseOptions = {}
---Builds one or more positions from the captured nodes from a query match.
---@param file_path string Path to file being parsed
---@param source string Contents of file being parsed
---@param captured_nodes table<string, userdata> Captured nodes, indexed by capture name (e.g. `test.name`)
---@return neotest.Position | neotest.Position[] | nil
function ParseOptions.build_position(file_path, source, captured_nodes) end

---Same as `parse_positions` but uses the provided content instead of reading file.
---@parma file_path string
---@param query table | string
---@param content string
---@param query table | string
---@param opts neotest.treesitter.ParseOptions
---@return neotest.Tree
local function parse_positions(file_path, query, content, opts)
function M.parse_positions_from_string(file_path, content, query, opts)
opts = vim.tbl_extend("force", { build_position = build_position }, opts or {})
local lib = require("neotest.lib")
local fast = opts.fast ~= false
local ft = require("neotest.lib").files.detect_filetype(file_path)
local ft = lib.files.detect_filetype(file_path)
local lang = require("nvim-treesitter.parsers").ft_to_lang(ft)
async.util.scheduler()
local lang_tree = vim.treesitter.get_string_parser(
Expand All @@ -143,50 +114,25 @@ local function parse_positions(file_path, query, content, opts)

local root
if fast then
root = fast_parse(lang_tree):root()
root = M.fast_parse(lang_tree):root()
else
root = lang_tree:parse()[1]:root()
end
local positions = collect(file_path, query, content, root)
local structure = parse_tree(positions, {}, opts)
local tree = Tree.from_list(structure, position_key)
return tree
local positions = collect(file_path, query, content, root, opts)
return lib.positions.parse_tree(positions, opts)
end

---Read a file's contents from disk and parse test positions using the given query.
---See lib.positions.parse_tree for more options options
---@async
---@param file_path string
---@param query string | vim.treesitter.Query
---@param opts table
---@param opts neotest.treesitter.ParseOptions
---@return neotest.Tree
function M.parse_positions(file_path, query, opts)
async.util.sleep(10) -- Prevent completely hogging main thread
local content = require("neotest.lib").files.read(file_path)
return M.parse_positions_from_string(file_path, content, query, opts)
end

---Same as `parse_positions` but uses the provided content instead of reading file.
function M.parse_positions_from_string(file_path, content, query, opts)
opts = vim.tbl_extend("force", {
nested_tests = false, -- Allow nested tests
require_namespaces = false, -- Only allow tests within namespaces
---@param position neotest.Position The position to return an ID for
---@param parents neotest.Position[] Parent positions for the position
position_id = function(position, namespaces)
return table.concat(
vim.tbl_flatten({
position.path,
vim.tbl_map(function(pos)
return pos.name
end, namespaces),
position.name,
}),
"::"
)
end,
}, opts or {})
local results = parse_positions(file_path, query, content, opts)
return results
end

return M
38 changes: 38 additions & 0 deletions tests/unit/lib/treesitter/init_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,42 @@ describe("treesitter parsing", function()
})
assert.Not.Nil(tree:get_key('test_spec.lua__"test 3"'))
end)

a.it("uses custom build function", function()
local tree = ts.parse_positions_from_string("test_spec.lua", test_file, plenary_queries, {
build_position = function(file_path)
return {
type = "test",
path = file_path,
name = "same_name",
range = { 0, 0, 0, 0 },
}
end,
})
for _, position in tree:iter() do
if position.type ~= "file" then
assert.are.same("same_name", position.name)
end
end
end)

a.it("allows custom build function to return list", function()
local tree = ts.parse_positions_from_string("test_spec.lua", test_file, plenary_queries, {
build_position = function(file_path)
return {
{
type = "test",
path = file_path,
name = "same_name",
range = { 0, 0, 0, 0 },
},
}
end,
})
for _, position in tree:iter() do
if position.type ~= "file" then
assert.are.same("same_name", position.name)
end
end
end)
end)

0 comments on commit 3f8246e

Please sign in to comment.