diff --git a/lua/neotest/client/init.lua b/lua/neotest/client/init.lua index 60626627..c30aebd8 100644 --- a/lua/neotest/client/init.lua +++ b/lua/neotest/client/init.lua @@ -36,7 +36,7 @@ end ---Run the given tree ---@async ----@param tree? neotest.Tree +---@param tree neotest.Tree ---@param args table ---@field adapter string: Adapter ID ---@field strategy string: Strategy to run commands with @@ -48,28 +48,41 @@ function NeotestClient:run_tree(tree, args) table.insert(pos_ids, pos.id) end - local pos = tree:data() - local adapter_id, adapter = self:_get_adapter(pos.id, args.adapter) - if not adapter_id then - logger.error("Adapter not found for position", pos.id) + local root = tree:data() + local adapter_id, adapter = self:_get_adapter(root.id, args.adapter) + if not adapter_id or not adapter then + logger.error("Adapter not found for position", root.id) return end - self._state:update_running(adapter_id, pos.id, pos_ids) - local success, results = pcall(self._runner._run_tree, self._runner, tree, args, adapter) + self._state:update_running(adapter_id, root.id, pos_ids) + local all_results = {} + local success, error = pcall( + self._runner.run_tree, + self._runner, + tree, + args, + adapter, + function(results) + for pos_id, result in pairs(results) do + all_results[pos_id] = result + end + self._state:update_results(adapter_id, results) + end + ) if not success then - lib.notify(("%s: %s"):format(adapter.name, results), "warn") - results = {} + lib.notify(("%s: %s"):format(adapter.name, error), "warn") + all_results = {} for _, pos in tree:iter() do - results[pos.id] = { status = "skipped" } + all_results[pos.id] = { status = "skipped" } end end - if pos.type ~= "test" then - self._runner:collect_results(tree, results) + if root.type ~= "test" then + self._runner:fill_results(tree, all_results) end - if pos.type == "test" or pos.type == "namespace" then - results[pos.path] = nil + if root.type == "test" or root.type == "namespace" then + all_results[root.path] = nil end - self._state:update_results(adapter_id, results) + self._state:update_results(adapter_id, all_results) end ---@async diff --git a/lua/neotest/client/runner.lua b/lua/neotest/client/runner.lua index 61d6c758..fe880285 100644 --- a/lua/neotest/client/runner.lua +++ b/lua/neotest/client/runner.lua @@ -19,83 +19,146 @@ end ---@param tree neotest.Tree ---@param args table ---@param adapter neotest.Adapter ----@return table -function TestRunner:_run_tree(tree, args, adapter) - args = args or {} - args.strategy = args.strategy or "integrated" - local position = tree:data() +function TestRunner:run_tree(tree, args, adapter, on_results) + local results = {} + local results_callback = function(results_) + on_results(results_) + for pos_id, result in ipairs(results_) do + results[pos_id] = result + end + end - local spec = adapter.build_spec(vim.tbl_extend("force", args, { - tree = tree, - })) + args = vim.tbl_extend("keep", args or {}, { strategy = "integrated" }) - local results = {} + self:_run_tree(tree, args, adapter, results_callback) + on_results(results) +end + +function TestRunner:_run_tree(tree, args, adapter, results_callback) + local spec = adapter.build_spec(vim.tbl_extend("force", args, { tree = tree })) if not spec then - local function run_pos_types(pos_type) - local async_runners = {} - for _, node in tree:iter_nodes() do - if node:data().type == pos_type then - table.insert(async_runners, function() - return self:_run_tree(node, args, adapter) - end) - end + self:_run_broken_down_tree(tree, args, adapter, results_callback) + return + end + self:_run_spec(spec, tree, args, adapter, results_callback) +end + +function TestRunner:_stream_queue() + local sender, receiver = async.control.channel.mpsc() + + local producer = function(output_stream) + local orig = "" + local pending_data = nil + for data in output_stream do + orig = orig .. data + local ends_with_newline = vim.endswith(data, "\n") + local next_lines = vim.split(data, "\n", { plain = true, trimempty = true }) + if pending_data then + next_lines[1] = pending_data .. next_lines[1] + pending_data = nil end - local all_results = {} - if #async_runners == 0 then - return {} + if not ends_with_newline then + pending_data = table.remove(next_lines, #next_lines) end - for i, res in ipairs(async.util.join(async_runners)) do - all_results[i] = res[1] + for _, line in ipairs(next_lines) do + sender.send(line) end - return vim.tbl_extend("error", {}, unpack(all_results)) end + end - if position.type == "dir" then - logger.warn(("%s doesn't support running directories, attempting files"):format(adapter.name)) - results = run_pos_types("file") - elseif position.type ~= "test" then - logger.warn(("%s doesn't support running %ss"):format(adapter.name, position.type)) - results = run_pos_types("test") - else - error(("%s returned no data to run tests"):format(adapter.name)) - end - else - spec.strategy = - vim.tbl_extend("force", spec.strategy or {}, config.strategies[args.strategy] or {}) + local consumer = function() + return receiver.recv() + end + return producer, consumer +end +---@param spec neotest.RunSpec +---@param adapter neotest.Adapter +function TestRunner:_run_spec(spec, tree, args, adapter, results_callback) + local position = tree:data() + spec.strategy = + vim.tbl_extend("force", spec.strategy or {}, config.strategies[args.strategy] or {}) spec.env = vim.tbl_extend("force", spec.env or {}, args.env or {}) spec.cwd = args.cwd or spec.cwd if vim.tbl_isempty(spec.env or {}) then spec.env = nil end - local process_result = - self._processes:run(self:_create_process_key(adapter.name, position.id), spec, args) - results = adapter.results(spec, process_result, tree) - if vim.tbl_isempty(results) then - if #tree:children() ~= 0 then - logger.warn("Results returned were empty, setting all positions to failed") - for _, pos in tree:iter() do - results[pos.id] = { - status = "failed", - errors = {}, - output = process_result.output, - } - end - else - results[tree:data().id] = { status = "skipped", output = process_result.output } - end - else - for _, result in pairs(results) do - if not result.output then - result.output = process_result.output - end + + + local proc_key = self:_create_process_key(adapter.name, position.id) + local producer, consumer = self:_stream_queue() + + local process_result = self._processes:run(proc_key, spec, args, spec.stream and producer) + if spec.stream then + async.run(function() + for stream_results in spec.stream(consumer) do + results_callback(stream_results) end + end) + end + + local results = adapter.results(spec, process_result, tree) + + if vim.tbl_isempty(results) then + results_callback(self:_fill_empty_results(tree, process_result.output)) + return + end + + self:fill_results(tree, results) + + for _, result in pairs(results) do + if not result.output then + result.output = process_result.output end end + + results_callback(results) +end + +function TestRunner:_fill_empty_results(tree, output_path) + if #tree:children() == 0 then + return { [tree:data().id] = { status = "skipped", output = output_path } } + end + local results = {} + logger.warn("Results returned were empty, setting all positions to failed") + for _, pos in tree:iter() do + results[pos.id] = { + status = "failed", + errors = {}, + output = output_path, + } + end return results end +function TestRunner:_run_broken_down_tree(tree, args, adapter, results_callback) + local position = tree:data() + local function run_pos_types(pos_type) + local async_runners = {} + for _, node in tree:iter_nodes() do + if node:data().type == pos_type then + table.insert(async_runners, function() + self:_run_tree(node, args, adapter, results_callback) + end) + end + end + if #async_runners == 0 then + return {} + end + async.util.join(async_runners) + end + + if position.type == "dir" then + logger.warn(("%s doesn't support running directories, attempting files"):format(adapter.name)) + return run_pos_types("file") + elseif position.type ~= "test" then + logger.warn(("%s doesn't support running %ss"):format(adapter.name, position.type)) + return run_pos_types("test") + end + error(("%s returned no data to run tests"):format(adapter.name)) +end + function TestRunner:_create_process_key(adapter_id, pos_id) return adapter_id .. "-" .. pos_id end @@ -146,8 +209,11 @@ function TestRunner:attach(position, adapter_id) end ---@async -function TestRunner:collect_results(tree, results) +---@param tree neotest.Tree +---@param results table +function TestRunner:fill_results(tree, results) local root = tree:data() + local missing_tests = {} for _, node in tree:iter_nodes() do local pos = node:data() @@ -176,6 +242,10 @@ function TestRunner:collect_results(tree, results) results[parent_pos.id] = parent_result end + else + if pos.type == "test" then + missing_tests[#missing_tests + 1] = pos.id + end end end @@ -186,7 +256,7 @@ function TestRunner:collect_results(tree, results) if pos.type == "file" then -- Files not being present means that they were skipped (probably) if not results[pos.id] and root_result then - results[pos.id] = { status = "skipped", output = root.output } + results[pos.id] = { status = "skipped", output = root_result.output } end else -- Tests and namespaces not being present means that they failed to even start, count as root result @@ -196,6 +266,12 @@ function TestRunner:collect_results(tree, results) end end end + + for _, test_id in ipairs(missing_tests) do + for parent in tree:get_key(test_id):iter_parents() do + results[parent:data().id] = nil + end + end end return function(processes) diff --git a/lua/neotest/client/strategies/init.lua b/lua/neotest/client/strategies/init.lua index a0f9fd31..323d678a 100644 --- a/lua/neotest/client/strategies/init.lua +++ b/lua/neotest/client/strategies/init.lua @@ -26,7 +26,8 @@ end ---@param spec neotest.RunSpec ---@param args? table ---@return neotest.StrategyResult -function NeotestProcessTracker:run(pos_id, spec, args) +function NeotestProcessTracker:run(pos_id, spec, args, process_stream) + --TODO Break this up so we can use instance.output_stream before awaiting finish local strategy = self:_get_strategy(args) logger.info("Starting process", pos_id, "with strategy", args.strategy) logger.debug("Strategy spec", spec) @@ -38,6 +39,13 @@ function NeotestProcessTracker:run(pos_id, spec, args) return { code = 1, output = output_path } end self._instances[pos_id] = instance + if process_stream then + async.run(function() + for data in instance.output_stream() do + process_stream(data) + end + end) + end local code = instance.result() logger.info("Process for position", pos_id, "exited with code", code) local output = instance.output() diff --git a/lua/neotest/client/strategies/integrated/init.lua b/lua/neotest/client/strategies/integrated/init.lua index 2620d962..6913d73c 100644 --- a/lua/neotest/client/strategies/integrated/init.lua +++ b/lua/neotest/client/strategies/integrated/init.lua @@ -1,5 +1,24 @@ local async = require("neotest.async") local lib = require("neotest.lib") +local FanoutAccum = require("neotest.types").FanoutAccum + +local function first(...) + local functions = { ... } + local send_ran, await_ran = async.control.channel.oneshot() + local result, ran + for _, func in ipairs(functions) do + async.run(function() + local func_result = func() + if not ran then + result = func_result + ran = true + send_ran() + end + end) + end + await_ran() + return result +end ---@class integratedStrategyConfig ---@field height integer @@ -14,12 +33,23 @@ return function(spec) local finish_cond = async.control.Condvar.new() local result_code = nil local command = spec.command + local data_accum = FanoutAccum(function(prev, new) + if not prev then + return new + end + return prev .. new + end, nil) - local unread_data = "" local attach_win, attach_buf, attach_chan local output_path = async.fn.tempname() local open_err, output_fd = async.uv.fs_open(output_path, "w", 438) assert(not open_err, open_err) + + data_accum:subscribe(function(data) + local write_err, _ = async.uv.fs_write(output_fd, data) + assert(not write_err, write_err) + end) + local success, job = pcall(async.fn.jobstart, command, { cwd = cwd, env = env, @@ -28,14 +58,7 @@ return function(spec) width = spec.strategy.width, on_stdout = function(_, data) async.run(function() - data = table.concat(data, "\n") - unread_data = unread_data .. data - local write_err, _ = async.uv.fs_write(output_fd, data) - assert(not write_err, write_err) - if attach_chan then - async.api.nvim_chan_send(attach_chan, unread_data) - unread_data = "" - end + data_accum:push(table.concat(data, "\n")) end) end, on_exit = function(_, code) @@ -59,6 +82,13 @@ return function(spec) stop = function() async.fn.jobstop(job) end, + output_stream = function() + local sender, receiver = async.control.channel.mpsc() + data_accum:subscribe(sender.send) + return function() + return first(finish_cond:wait(), receiver.recv) + end + end, attach = function() attach_buf = attach_buf or vim.api.nvim_create_buf(false, true) attach_chan = attach_chan @@ -81,10 +111,9 @@ return function(spec) }) attach_win:jump_to() - if unread_data ~= "" then - async.api.nvim_chan_send(attach_chan, unread_data) - unread_data = "" - end + data_accum:subscribe(function(data) + async.api.nvim_chan_send(attach_chan, data) + end) end, result = function() if result_code == nil then diff --git a/lua/neotest/types/fanout_accum.lua b/lua/neotest/types/fanout_accum.lua new file mode 100644 index 00000000..6cb67a06 --- /dev/null +++ b/lua/neotest/types/fanout_accum.lua @@ -0,0 +1,38 @@ +---Accumulates provided data and stores it, while sending to consumers. +---Allows consuming all data ever pushed while subscribing at any point in time. +---@class FanoutAccum +---@field consumers fun(data: T)[] +---@field data T | nil +---@field accum fun(prev: T, new: any): T A function to combine previous data and new data +local FanoutAccum = {} + +---@generic T +---@param accum fun(prev: T, new: any): T +---@param init T +---@return FanoutAccum +function FanoutAccum:new(accum, init) + self.__index = self + return setmetatable({ + data = init, + accum = accum, + consumers = {}, + }, self) +end + +function FanoutAccum:subscribe(cb) + self.consumers[#self.consumers + 1] = cb + if self.data then + cb(self.data) + end +end + +function FanoutAccum:push(data) + self.data = self.accum(self.data, data) + for _, cb in ipairs(self.consumers) do + cb(data) + end +end + +return function(accum, init) + return FanoutAccum:new(accum, init) +end diff --git a/lua/neotest/types/init.lua b/lua/neotest/types/init.lua index 73b8340d..fe6b7a97 100644 --- a/lua/neotest/types/init.lua +++ b/lua/neotest/types/init.lua @@ -16,11 +16,12 @@ ---@field line? integer ---@class neotest.Process ----@field output async fun():string Output data ----@field is_complete fun(): boolean Is process complete ----@field result async fun(): integer Get result code of process (async) +---@field output async fun()string Output data +---@field is_complete fun() boolean Is process complete +---@field result async fun() integer Get result code of process (async) ---@field attach async fun() Attach to the running process for user input ---@field stop async fun() Stop the running process +---@field output_stream fun(): string | nil Async iterator of process output ---@alias neotest.Strategy async fun(spec: neotest.RunSpec): neotest.Process @@ -39,6 +40,7 @@ ---@field cwd? string ---@field context? table Arbitrary data to preserve state between running and result collection ---@field strategy? table Arguments for strategy +---@field stream fun(output_stream: fun(): string[]): fun(): table ---@class neotest.ConsumerListeners ---@field discover_positions fun(adapter_id: string, path: string, tree: neotest.Tree) @@ -54,5 +56,6 @@ local M = {} M.Tree = require("neotest.types.tree") M.FIFOQueue = require("neotest.types.queue") +M.FanoutAccum = require("neotest.types.fanout_accum") return M