From 32ff2ac21135a372a42b38ae131e531e64833bd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=B3n=C3=A1n=20Carrigan?= Date: Thu, 11 Jul 2024 15:12:28 +0100 Subject: [PATCH] feat(run): augment args Allow users to augment the arguments to all tests being run from a singular function. ```lua local nio = require("nio") neotest.setup({ run = { augment = function(tree, args) local name = nio.ui.input({ prompt = "What is your name?" }) args.env = { USER_NAME = name } return args end, }, }) ``` See #431 --- doc/neotest.txt | 9 +++++++++ lua/neotest/config/init.lua | 5 +++++ lua/neotest/consumers/run.lua | 18 ++++++++++++++++-- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/doc/neotest.txt b/doc/neotest.txt index 5d31f905..8a7cddb6 100644 --- a/doc/neotest.txt +++ b/doc/neotest.txt @@ -193,6 +193,7 @@ Default values: elixir = , go = " ;query\n ;Captures imported types\n (qualified_type name: (type_identifier) @symbol)\n ;Captures package-local and built-in types\n (type_identifier)@symbol\n ;Captures imported function calls and variables/constants\n (selector_expression field: (field_identifier) @symbol)\n ;Captures package-local functions calls\n (call_expression function: (identifier) @symbol)\n ", haskell = " ;query\n ;explicit import\n ((import_item [(variable)]) @symbol)\n ;symbols that may be imported implicitly\n ((type) @symbol)\n (qualified_variable (variable) @symbol)\n (exp_apply (exp_name (variable) @symbol))\n ((constructor) @symbol)\n ((operator) @symbol)\n ", + java = " ;query\n ;captures imported classes\n (import_declaration\n (scoped_identifier name: ((identifier) @symbol))\n )\n ", javascript = ' ;query\n ;Captures named imports\n (import_specifier name: (identifier) @symbol)\n ;Captures default import\n (import_clause (identifier) @symbol)\n ;Capture require statements\n (variable_declarator \n name: (identifier) @symbol\n value: (call_expression (identifier) @function (#eq? @function "require")))\n ;Capture namespace imports\n (namespace_import (identifier) @symbol)\n', lua = ' ;query\n ;Captures module names in require calls\n (function_call\n name: ((identifier) @function (#eq? @function "require"))\n arguments: (arguments (string) @symbol))\n ', python = " ;query\n ;Captures imports and modules they're imported from\n (import_from_statement (_ (identifier) @symbol))\n (import_statement (_ (identifier) @symbol))\n ", @@ -234,6 +235,7 @@ Fields~ {highlights} `(table)` {floating} `(neotest.Config.floating)` {strategies} `(neotest.Config.strategies)` +{run} `(neotest.Config.run)` {summary} `(neotest.Config.summary)` {output} `(neotest.Config.output)` {output_panel} `(neotest.Config.output_panel)` @@ -280,6 +282,13 @@ Fields~ Fields~ {integrated} `(neotest.Config.strategies.integrated)` + *neotest.Config.run* +Fields~ +{enabled} `(boolean)` +{augment?} `(fun(tree: neotest.Tree, arg: +neotest.run.RunArgs):neotest.run.RunArgs)` A function to augment the arguments +any tests being run + *neotest.Config.summary* Fields~ {enabled} `(boolean)` diff --git a/lua/neotest/config/init.lua b/lua/neotest/config/init.lua index 230ce890..01f9fa6b 100644 --- a/lua/neotest/config/init.lua +++ b/lua/neotest/config/init.lua @@ -54,6 +54,7 @@ local js_watch_query = [[ ---@field highlights table ---@field floating neotest.Config.floating ---@field strategies neotest.Config.strategies +---@field run neotest.Config.run ---@field summary neotest.Config.summary ---@field output neotest.Config.output ---@field output_panel neotest.Config.output_panel @@ -87,6 +88,10 @@ local js_watch_query = [[ ---@class neotest.Config.strategies ---@field integrated neotest.Config.strategies.integrated +---@class neotest.Config.run +---@field enabled boolean +---@field augment? fun(tree: neotest.Tree, arg: neotest.run.RunArgs):neotest.run.RunArgs A function to augment the arguments any tests being run + ---@class neotest.Config.summary ---@field enabled boolean ---@field animated boolean Enable/disable animation of icons diff --git a/lua/neotest/consumers/run.lua b/lua/neotest/consumers/run.lua index 262313e3..78807233 100644 --- a/lua/neotest/consumers/run.lua +++ b/lua/neotest/consumers/run.lua @@ -1,5 +1,6 @@ local nio = require("nio") local lib = require("neotest.lib") +local config = require("neotest.config") ---@private ---@type neotest.Client @@ -46,6 +47,17 @@ end ---@field [1] string? Position ID to run ---@field suite boolean Run the entire suite instead of a single position +local function augment_args(tree, args) + args = type(args) == "string" and { args } or args + args = args or {} + local aug = config.run.augment + if not aug then + return args + end + nio.scheduler() + return aug(tree, args) +end + --- Run the given position or the nearest position if not given. --- All arguments are optional --- @@ -70,7 +82,7 @@ function neotest.run.run(args) lib.notify("No tests found") return end - client:run_tree(tree, type(args) == "string" and { args } or args) + client:run_tree(tree, augment_args(tree, args)) end neotest.run.run = nio.create(neotest.run.run, 1) @@ -103,7 +115,7 @@ function neotest.run.run_last(args) lib.notify("Last test run no longer exists") return end - client:run_tree(tree, args) + client:run_tree(tree, augment_args(tree, args)) end) end @@ -147,6 +159,7 @@ function neotest.run.stop(args) end client:stop(pos, args) end + neotest.run.stop = nio.create(neotest.run.stop, 1) ---@class neotest.run.AttachArgs : neotest.client.AttachArgs @@ -173,6 +186,7 @@ function neotest.run.attach(args) end client:attach(pos, args) end + neotest.run.attach = nio.create(neotest.run.attach, 1) --- Get the list of all known adapter IDs.