From 43309c0bfaa308a6f4ff107bf662d9ecdf8f4c4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=B3n=C3=A1n=20Carrigan?= Date: Sat, 4 Mar 2023 13:40:02 +0000 Subject: [PATCH] feat: state consumer --- doc/neotest.txt | 53 ++++++- lua/neotest/client/events/init.lua | 4 +- lua/neotest/client/init.lua | 2 +- lua/neotest/client/state/init.lua | 1 + lua/neotest/config/init.lua | 7 + lua/neotest/consumers/init.lua | 3 +- lua/neotest/consumers/state/init.lua | 115 ++++++++++++++ lua/neotest/consumers/state/tracker.lua | 191 ++++++++++++++++++++++++ lua/neotest/init.lua | 1 + scripts/gendocs.lua | 1 + 10 files changed, 372 insertions(+), 6 deletions(-) create mode 100644 lua/neotest/consumers/state/init.lua create mode 100644 lua/neotest/consumers/state/tracker.lua diff --git a/doc/neotest.txt b/doc/neotest.txt index c02dba7d..6360da26 100644 --- a/doc/neotest.txt +++ b/doc/neotest.txt @@ -14,6 +14,7 @@ neotest *neotest* Summary Consumer...........................................|neotest.summary| Jump Consumer.................................................|neotest.jump| Quickfix Consumer.........................................|neotest.quickfix| + State Consumer...............................................|neotest.state| Neotest Client..............................................|neotest.Client| Library........................................................|neotest.lib| Library: Files...........................................|neotest.lib.files| @@ -140,6 +141,9 @@ Default values: running = { concurrent = true }, + state = { + enabled = true + }, status = { enabled = true, signs = true, @@ -214,6 +218,7 @@ Fields~ {output_panel} `(neotest.Config.output_panel)` {quickfix} `(neotest.Config.quickfix)` {status} `(neotest.Config.status)` +{state} `(neotest.Config.state)` {diagnostic} `(neotest.Config.diagnostic)` {projects} `(table)` Project specific settings, keys are project root directories (e.g "~/Dev/my_project") @@ -284,6 +289,10 @@ Fields~ {enabled} `(boolean)` {open_on_run} `(string|boolean)` Open nearest test result after running + *neotest.Config.state* +Fields~ +{enabled} `(boolean)` + *neotest.Config.output_panel* Fields~ {enabled} `(boolean)` @@ -326,7 +335,7 @@ The client interface provides methods for interacting with tests, fetching results as well as event listeners. To listen to an event, just assign the event listener to a function: >lua - client.listeners.discover_positions = function (adapter_id, path, tree) + client.listeners.discover_positions = function (adapter_id, tree) ... end < @@ -628,6 +637,46 @@ neotest.quickfix *neotest.quickfix* A consumer that sends results to the quickfix list. +============================================================================== +neotest.state *neotest.state* + + +A consumer that tracks various pieces of state in Neotest. +Most of the internals of Neotest are asynchronous so this consumer allows +tracking the state of the test suite and test results without needing to +write asynchronous code. + + *neotest.state.adapter_ids()* +`adapter_ids`() + +Get the list of all adapter IDs currently active +Return~ +`(string[])` + + *neotest.state.status_counts()* +`status_counts`({adapter_id}, {args}) + +Get the counts of the various states of tests for the entire suite or for a +buffer. +Parameters~ +{adapter_id} `(string)` +{args?} `(neotest.state.StatusCountsArgs)` +Return~ +`(neotest.state.StatusCounts)` | nil + + *neotest.state.StatusCountsArgs* +Fields~ +{buffer?} `(integer)` Returns statuses for this buffer + + *neotest.state.StatusCounts* +Fields~ +{total} `(integer)` +{passed} `(integer)` +{failed} `(integer)` +{skipped} `(integer)` +{running} `(integer)` + + ============================================================================== *neotest.Client* `Client` @@ -644,7 +693,7 @@ start because it can slow down startup. *neotest.ConsumerListeners* Fields~ -{discover_positions} `(fun(adapter_id: string, path: string, tree: neotest.Tree))` +{discover_positions} `(fun(adapter_id: string, tree: neotest.Tree))` {run} `(fun(adapter_id: string, root_id: string, position_ids: string[]))` {results} `(fun(adapter_id: string, results: table, partial: boolean))` {test_file_focused} `(fun(adapter_id: string, file_path: string)>)` diff --git a/lua/neotest/client/events/init.lua b/lua/neotest/client/events/init.lua index 21d37042..3e38e639 100644 --- a/lua/neotest/client/events/init.lua +++ b/lua/neotest/client/events/init.lua @@ -18,9 +18,9 @@ local NeotestEvents = { M.events = NeotestEvents ---@class neotest.InternalClientListeners ----@field discover_positions table +---@field discover_positions table ---@field run table ----@field results table)> +---@field results table, partial: boolean)> ---@field test_file_focused table> ---@field test_focused table> diff --git a/lua/neotest/client/init.lua b/lua/neotest/client/init.lua index a5ba5de4..e6fab85a 100644 --- a/lua/neotest/client/init.lua +++ b/lua/neotest/client/init.lua @@ -29,7 +29,7 @@ local neotest = {} neotest.Client = {} ---@class neotest.ConsumerListeners ----@field discover_positions fun(adapter_id: string, path: string, tree: neotest.Tree) +---@field discover_positions fun(adapter_id: string, tree: neotest.Tree) ---@field run fun(adapter_id: string, root_id: string, position_ids: string[]) ---@field results fun(adapter_id: string, results: table, partial: boolean) ---@field test_file_focused fun(adapter_id: string, file_path: string)> diff --git a/lua/neotest/client/state/init.lua b/lua/neotest/client/state/init.lua index 10977bfe..976c9e27 100644 --- a/lua/neotest/client/state/init.lua +++ b/lua/neotest/client/state/init.lua @@ -69,6 +69,7 @@ end function NeotestClientState:update_results(adapter_id, results, partial) logger.debug("New results for adapter", adapter_id) logger.trace(results) + local positions = self:positions(adapter_id) self._results[adapter_id] = vim.tbl_extend("force", self._results[adapter_id] or {}, results) if not self._running[adapter_id] then self._running[adapter_id] = {} diff --git a/lua/neotest/config/init.lua b/lua/neotest/config/init.lua index 9dba5b08..6dd716e7 100644 --- a/lua/neotest/config/init.lua +++ b/lua/neotest/config/init.lua @@ -44,6 +44,7 @@ define_highlights() ---@field output_panel neotest.Config.output_panel ---@field quickfix neotest.Config.quickfix ---@field status neotest.Config.status +---@field state neotest.Config.state ---@field diagnostic neotest.Config.diagnostic ---@field projects table Project specific settings, keys --- are project root directories (e.g "~/Dev/my_project") @@ -104,6 +105,9 @@ define_highlights() ---@field enabled boolean ---@field open_on_run string|boolean Open nearest test result after running +---@class neotest.Config.state +---@field enabled boolean + ---@class neotest.Config.output_panel ---@field enabled boolean ---@field open string|fun():integer A command or function to open a window for the output panel @@ -252,6 +256,9 @@ local default_config = { enabled = true, open = true, }, + state = { + enabled = true, + }, projects = {}, } diff --git a/lua/neotest/consumers/init.lua b/lua/neotest/consumers/init.lua index a192802c..fc5c1999 100644 --- a/lua/neotest/consumers/init.lua +++ b/lua/neotest/consumers/init.lua @@ -17,7 +17,7 @@ local neotest = {} --- results as well as event listeners. To listen to an event, just assign the event --- listener to a function: --- >lua ---- client.listeners.discover_positions = function (adapter_id, path, tree) +--- client.listeners.discover_positions = function (adapter_id, tree) --- ... --- end --- < @@ -37,6 +37,7 @@ neotest.consumers = { jump = require("neotest.consumers.jump"), benchmark = require("neotest.consumers.benchmark"), quickfix = require("neotest.consumers.quickfix"), + state = require("neotest.consumers.state"), } return neotest.consumers diff --git a/lua/neotest/consumers/state/init.lua b/lua/neotest/consumers/state/init.lua new file mode 100644 index 00000000..18fcb9d3 --- /dev/null +++ b/lua/neotest/consumers/state/init.lua @@ -0,0 +1,115 @@ +local async = require("neotest.async") +local logger = require("neotest.logging") +local neotest = {} +local StateTracker = require("neotest.consumers.state.tracker") + +---@type neotest.state.StateTracker +---@nodoc +local tracker + +---@param client neotest.Client +---@nodoc +local function init(client) + local updated_cond = async.control.Condvar.new() + local pending_update = false + tracker = StateTracker:new(client) + local function update_positions() + while true do + if not pending_update then + updated_cond:wait() + end + tracker:update_positions() + pending_update = false + async.util.sleep(50) + end + end + + vim.api.nvim_create_autocmd("BufAdd", { + callback = function(args) + tracker:register_buffer(args.buf) + pending_update = true + updated_cond:notify_all() + end, + }) + for _, buf in ipairs(async.api.nvim_list_bufs()) do + tracker:register_buffer(buf) + end + async.run(function() + xpcall(update_positions, function(msg) + logger.error("Error in state consumer", debug.traceback(msg, 2)) + end) + end) + client.listeners.discover_positions = function(adapter_id) + if not tracker:adapter_state(adapter_id) then + tracker:register_adapter(adapter_id) + end + pending_update = true + updated_cond:notify_all() + end + + client.listeners.run = function(adapter_id, _, position_ids) + tracker:update_running(adapter_id, position_ids) + end + + client.listeners.results = function(adapter_id, results, partial) + if partial then + return + end + tracker:update_results(adapter_id, results) + end +end + +---@param args? table +---@return neotest.state.State | nil +---@nodoc +local function state_from_args(adapter_id, args) + if args and args.buffer then + return tracker:buffer_state(adapter_id, args.buffer) + end + return tracker:adapter_state(adapter_id) +end + +---@toc_entry State Consumer +---@text +--- A consumer that tracks various pieces of state in Neotest. +--- Most of the internals of Neotest are asynchronous so this consumer allows +--- tracking the state of the test suite and test results without needing to +--- write asynchronous code. +---@class neotest.consumers.state +neotest.state = {} + +--- Get the list of all adapter IDs currently active +---@return string[] +function neotest.state.adapter_ids() + return tracker.adapter_ids +end + +--- Get the counts of the various states of tests for the entire suite or for a +--- buffer. +---@param adapter_id string +---@param args? neotest.state.StatusCountsArgs +---@return neotest.state.StatusCounts | nil +function neotest.state.status_counts(adapter_id, args) + local state = state_from_args(adapter_id, args) + + return state and state.status +end + +---@class neotest.state.StatusCountsArgs +---@field buffer? integer Returns statuses for this buffer + +---@class neotest.state.StatusCounts +---@field total integer +---@field passed integer +---@field failed integer +---@field skipped integer +---@field running integer + +neotest.summary = setmetatable(neotest.state, { + __call = function(_, client) + init(client) + return neotest.state + end, +}) + +return neotest.state diff --git a/lua/neotest/consumers/state/tracker.lua b/lua/neotest/consumers/state/tracker.lua new file mode 100644 index 00000000..d53eb0a2 --- /dev/null +++ b/lua/neotest/consumers/state/tracker.lua @@ -0,0 +1,191 @@ +---@class neotest.state.State +---@field positions neotest.Tree +---@field running table +---@field status neotest.state.StatusCounts + +---@class neotest.state.AdapterState : neotest.state.State +---@field buffers table + +---@class neotest.state.StateTracker +---@field adapter_states table +---@field adapter_ids string[] +---@field path_buffers table +---@field client neotest.Client +local StateTracker = {} + +function StateTracker:new(client) + local tracker = { + adapter_states = {}, + client = client, + adapter_ids = {}, + path_buffers = {}, + } + self.__index = self + return setmetatable(tracker, self) +end + +---@param buffer integer +---@return neotest.state.State | nil +function StateTracker:buffer_state(adapter_id, buffer) + local path = vim.fn.fnamemodify(vim.fn.bufname(buffer), ":p") + local state = self.adapter_states[adapter_id] + return state.buffers[path] +end + +function StateTracker:count_tests(tree) + local count = 0 + for _, pos in tree:iter() do + if pos.type == "test" then + count = count + 1 + end + end + + return count +end + +function StateTracker:is_test(pos_id, tree) + local node = tree:get_key(pos_id) + if node and node:data().type == "test" then + return true + end + return false +end + +function StateTracker:update_counts(adapter_id) + local state = self.adapter_states[adapter_id] + local status = state.status + local running = state.running + local tree = state.positions + status.running = 0 + for _ in pairs(state.running) do + status.running = status.running + 1 + end + for _, buf_state in pairs(state.buffers) do + buf_state.status.running = 0 + for _ in pairs(buf_state.running) do + buf_state.status.running = buf_state.status.running + 1 + end + end + + local adapter_results = self.client:get_results(adapter_id) + status.failed = 0 + status.passed = 0 + status.skipped = 0 + for pos_id, result in pairs(adapter_results) do + if not running[pos_id] and self:is_test(pos_id, tree) then + state.status[result.status] = state.status[result.status] + 1 + end + end + for _, buf_state in pairs(state.buffers) do + buf_state.status.failed = 0 + buf_state.status.passed = 0 + buf_state.status.skipped = 0 + for _, pos in buf_state.positions:iter() do + local result = adapter_results[pos.id] + if not running[pos.id] and result and self:is_test(pos.id, tree) then + buf_state.status[result.status] = buf_state.status[result.status] + 1 + end + end + end +end + +function StateTracker:update_positions() + for adapter_id, state in pairs(self.adapter_states) do + state.positions = assert(self.client:get_position(nil, { adapter = adapter_id })) + state.status.total = self:count_tests(state.positions) + for _, node in state.positions:iter_nodes() do + local pos = node:data() + if pos.type == "file" and self.path_buffers[pos.path] then + if not state.buffers[pos.path] then + state.buffers[pos.path] = { + positions = node, + running = {}, + status = { + failed = 0, + passed = 0, + skipped = 0, + total = 0, + running = 0, + }, + } + end + end + end + + for path, buf_state in pairs(state.buffers) do + local new_tree = state.positions:get_key(path) + if not new_tree then + state.buffers[path] = nil + else + buf_state.positions = new_tree + buf_state.status.total = self:count_tests(new_tree) + end + end + end +end + +function StateTracker:adapter_state(adapter_id) + return self.adapter_states[adapter_id] +end + +function StateTracker:register_adapter(adapter_id) + self.adapter_ids[#self.adapter_ids + 1] = adapter_id + self.adapter_states[adapter_id] = { + running = {}, + buffers = {}, + status = { + failed = 0, + passed = 0, + skipped = 0, + total = 0, + running = 0, + }, + } +end + +function StateTracker:register_buffer(buffer) + local path = vim.fn.fnamemodify(vim.fn.bufname(buffer), ":p") + self.path_buffers[path] = buffer +end + +function StateTracker:update_running(adapter_id, position_ids) + local state = self:adapter_state(adapter_id) + local running = state.running + local tree = state.positions + for _, pos_id in ipairs(position_ids) do + if self:is_test(pos_id, tree) then + running[pos_id] = (running[pos_id] or 0) + 1 + for _, buf_state in pairs(self:adapter_state(adapter_id).buffers) do + if buf_state.positions:get_key(pos_id) then + buf_state.running[pos_id] = (buf_state.running[pos_id] or 0) + 1 + end + end + end + end + self:update_counts(adapter_id) +end + +function StateTracker:update_results(adapter_id, results) + local state = self:adapter_state(adapter_id) + local running = state.running + + for pos_id, _ in pairs(results) do + if running[pos_id] then + running[pos_id] = running[pos_id] - 1 + if running[pos_id] == 0 then + running[pos_id] = nil + end + for _, buf_state in pairs(self:adapter_state(adapter_id).buffers) do + if buf_state.running[pos_id] then + buf_state.running[pos_id] = buf_state.running[pos_id] - 1 + if buf_state.running[pos_id] == 0 then + buf_state.running[pos_id] = nil + end + end + end + end + end + self:update_counts(adapter_id) +end + +return StateTracker diff --git a/lua/neotest/init.lua b/lua/neotest/init.lua index bc7071b5..d25c8bf8 100644 --- a/lua/neotest/init.lua +++ b/lua/neotest/init.lua @@ -36,6 +36,7 @@ ---@field status neotest.consumers.status ---@field diagnostic neotest.consumers.diagnostic ---@field jump neotest.consumers.jump +---@field state neotest.consumers.state ---@nodoc local neotest = {} diff --git a/scripts/gendocs.lua b/scripts/gendocs.lua index 918ecea7..8f41da0e 100644 --- a/scripts/gendocs.lua +++ b/scripts/gendocs.lua @@ -668,6 +668,7 @@ minidoc.generate( "./lua/neotest/consumers/summary/init.lua", "./lua/neotest/consumers/jump.lua", "./lua/neotest/consumers/quickfix.lua", + "./lua/neotest/consumers/state/init.lua", "./lua/neotest/client/init.lua", "./lua/neotest/lib/init.lua", "./lua/neotest/lib/file/init.lua",