7
7
from time import time , sleep
8
8
from rich .console import Console # import Rich's Console to enable recording
9
9
import re
10
- import io
11
10
12
11
# Use the package logger consistently
13
12
logger = get_logger (__name__ )
@@ -23,38 +22,35 @@ def setup_config():
23
22
os .environ .pop ("AUDIOKIT_LOG_LEVEL" , None )
24
23
os .environ .pop ("AUDIOKIT_LOG_FILE" , None )
25
24
26
- runner = CliRunner ()
27
-
28
- class AccumulatingStream (io .StringIO ):
29
- def __init__ (self , * args , ** kwargs ):
30
- super ().__init__ (* args , ** kwargs )
31
- self .accumulated = []
32
-
33
- def write (self , s ):
34
- self .accumulated .append (s )
35
- return super ().write (s )
25
+ @pytest .fixture (autouse = True )
26
+ def patch_console (monkeypatch ):
27
+ """
28
+ Monkeypatch the global console in audiokit.cli to use a recording Console.
29
+ """
30
+ import audiokit .cli as cli # import the module that created the console
31
+ recording_console = Console (force_terminal = True , record = True )
32
+ monkeypatch .setattr (cli , "console" , recording_console )
33
+ return recording_console
36
34
37
- def get_accumulated (self ):
38
- return '' .join (self .accumulated )
35
+ runner = CliRunner ()
39
36
40
37
def strip_ansi (text : str ) -> str :
41
38
# Regular expression for ANSI escape sequences
42
39
ansi_escape = re .compile (r'\x1B\[[0-?]*[ -/]*[@-~]' )
43
40
return ansi_escape .sub ('' , text )
44
41
45
- def test_cli_analyze (sample_audio_path ):
42
+ def test_cli_analyze (sample_audio_path , patch_console ):
46
43
"""Test CLI analyze command"""
47
44
logger .info ("Testing CLI analyze command" )
48
45
49
- # Create an accumulating stream and override the runner's output stream method.
50
- accum_stream = AccumulatingStream ()
51
- original_method = runner ._get_output_stream
52
- runner ._get_output_stream = lambda : accum_stream
53
46
result = runner .invoke (
54
47
app , ["analyze" , str (sample_audio_path )], catch_exceptions = False
55
48
)
56
- runner ._get_output_stream = original_method
57
- final_output = accum_stream .get_accumulated ()
49
+ # Allow a brief pause to let all output be recorded.
50
+ sleep (0.2 )
51
+ final_output = patch_console .export_text (clear = False )
52
+ if not final_output :
53
+ final_output = result .stdout
58
54
59
55
logger .debug ("CLI analyze command rendered output: {}" , final_output )
60
56
logger .debug ("CLI analyze command exit code: {}" , result .exit_code )
@@ -75,13 +71,9 @@ def test_cli_analyze(sample_audio_path):
75
71
assert "0.85" in final_output_clean # Guitar
76
72
assert "0.90" in final_output_clean # Drums
77
73
78
- def test_cli_process (sample_audio_path , tmp_path ):
74
+ def test_cli_process (sample_audio_path , tmp_path , patch_console ):
79
75
"""Test CLI process command"""
80
76
logger .info ("Testing CLI process command" )
81
- # Create an accumulating stream and override the runner's output stream method.
82
- accum_stream = AccumulatingStream ()
83
- original_method = runner ._get_output_stream
84
- runner ._get_output_stream = lambda : accum_stream
85
77
result = runner .invoke (
86
78
app , [
87
79
"process" ,
@@ -91,8 +83,10 @@ def test_cli_process(sample_audio_path, tmp_path):
91
83
],
92
84
catch_exceptions = False
93
85
)
94
- runner ._get_output_stream = original_method
95
- final_output = accum_stream .get_accumulated ()
86
+ sleep (0.2 )
87
+ final_output = patch_console .export_text (clear = False )
88
+ if not final_output :
89
+ final_output = result .stdout
96
90
97
91
logger .debug ("CLI process command rendered output: {}" , final_output )
98
92
logger .debug ("CLI process command exit code: {}" , result .exit_code )
0 commit comments