diff --git a/src/drd/cli/commands.py b/src/drd/cli/commands.py index 0510a22..5b6cfc1 100644 --- a/src/drd/cli/commands.py +++ b/src/drd/cli/commands.py @@ -22,6 +22,11 @@ filemode='a') # Also log to console +# console = logging.StreamHandler() +# console.setLevel(logging.INFO) +# formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') +# console.setFormatter(formatter) +# logging.getLogger('').addHandler(console) def parse_multiline_input(input_string): diff --git a/src/drd/cli/monitor/error_resolver.py b/src/drd/cli/monitor/error_resolver.py index b6450b7..950f7d7 100644 --- a/src/drd/cli/monitor/error_resolver.py +++ b/src/drd/cli/monitor/error_resolver.py @@ -6,15 +6,11 @@ from ...prompts.monitor_error_resolution import get_error_resolution_prompt from ..query.file_operations import get_files_to_modify from ...utils.file_utils import get_file_content -from ...utils.input import get_user_confirmation import logging def monitoring_handle_error_with_dravid(error, error_trace, monitor): - if not get_user_confirmation("Do you want to proceed with the fix from Dravid?"): - return True - print_error(f"Error detected: {error}") logger = logging.getLogger(__name__) logger.info(f"Starting error handling for: {error}") @@ -22,14 +18,19 @@ def monitoring_handle_error_with_dravid(error, error_trace, monitor): # error_type = type(error).__name__ # error_trace = ''.join(traceback.format_exception( # type(error), error, error.__traceback__)) - + print("the type is *******") + print(error) + print("++++000-----") project_context = monitor.metadata_manager.get_project_context() print_info("Identifying relevant files for error context...") - error_details = """ - There is an error in the project. Strictly suggest only the files needed to fix the error. - error_trace: {error_trace} + error_details = f""" + There is an error in the project. Identify ony the files related to it + error_trace: {error} """ + input("testing >") + input("testing1 >") + return True files_to_check = run_with_loader( lambda: get_files_to_modify(error_details, project_context), @@ -85,10 +86,10 @@ def monitoring_handle_error_with_dravid(error, error_trace, monitor): logger.info(f"User response to restart: ") if requires_restart: print_info("The applied fix requires a server restart.") - restart_input = confirm_with_user( + restart_input = input( "Do you want to restart the server now? [y/N]: " ) - if restart_input: + if restart_input.lower() == 'y': print_info("Requesting server restart...") monitor.perform_restart() else: diff --git a/src/drd/cli/monitor/input_handler.py b/src/drd/cli/monitor/input_handler.py index b2b0c50..aa0498d 100644 --- a/src/drd/cli/monitor/input_handler.py +++ b/src/drd/cli/monitor/input_handler.py @@ -5,8 +5,10 @@ from ...utils import print_info, print_error from ...prompts.instructions import get_instruction_prompt from .input_parser import InputParser +from ...utils.input import get_input_with_timeout from ..query.main import execute_dravid_command import logging +from .state import ServerState class InputHandler: @@ -15,10 +17,14 @@ def __init__(self, monitor): self.logger = logging.getLogger(__name__) def handle_input(self): + if self.monitor.get_state() != ServerState.NORMAL: + self.logger.info( + "Input handling skipped: Server not in NORMAL state") + return self.logger.info("InputHandler triggered to handle input") print_info("\nNo more tasks to auto-process. What can I do next?") self._show_options() - user_input = input("> ") + user_input = input("") self.logger.info(f"Received user input: {user_input}") self._process_input(user_input) @@ -30,8 +36,13 @@ def _show_options(self): print_info("\nType your choice or command:") def _process_input(self, user_input): + state = self.monitor.get_state() + if state == ServerState.ERROR_HANDLING: + print("state from error handling interfering", state) + self.monitor.resume_error_handling(user_input, skip=True) + print("\n") + return True self.monitor.processing_input.set() - print("processing your input") try: if user_input.lower() == 'exit': confirm_exit = input( diff --git a/src/drd/cli/monitor/main.py b/src/drd/cli/monitor/main.py index 701029b..7766d30 100644 --- a/src/drd/cli/monitor/main.py +++ b/src/drd/cli/monitor/main.py @@ -26,31 +26,19 @@ def run_dev_server_with_monitoring(command: str): def handle_module_not_found(error_msg, monitor): - monitor.error_handling_in_progress.set() - try: - match = re.search( - r"(?:Cannot find module|Module not found|ImportError|No module named).*['\"](.*?)['\"]", error_msg, re.IGNORECASE) - if match: - module_name = match.group(1) - error = ImportError(f"Module '{module_name}' not found") - monitoring_handle_error_with_dravid(error, error_msg, monitor) - finally: - monitor.error_handling_in_progress.clear() + match = re.search( + r"(?:Cannot find module|Module not found|ImportError|No module named).*['\"](.*?)['\"]", error_msg, re.IGNORECASE) + if match: + module_name = match.group(1) + error = ImportError(f"Module '{module_name}' not found") + monitoring_handle_error_with_dravid(error, error_msg, monitor) def handle_syntax_error(error_msg, monitor): - monitor.error_handling_in_progress.set() - try: - error = SyntaxError(f"Syntax error detected: {error_msg}") - monitoring_handle_error_with_dravid(error, error_msg, monitor) - finally: - monitor.error_handling_in_progress.clear() + error = SyntaxError(f"Syntax error detected: {error_msg}") + monitoring_handle_error_with_dravid(error, error_msg, monitor) def handle_general_error(error_msg, monitor): - monitor.error_handling_in_progress.set() - try: - error = Exception(f"General error detected: {error_msg}") - monitoring_handle_error_with_dravid(error, error_msg, monitor) - finally: - monitor.error_handling_in_progress.clear() + error = Exception(f"General error detected: {error_msg}") + monitoring_handle_error_with_dravid(error, error_msg, monitor) diff --git a/src/drd/cli/monitor/output_monitor.py b/src/drd/cli/monitor/output_monitor.py index 2807f39..3a0bfcf 100644 --- a/src/drd/cli/monitor/output_monitor.py +++ b/src/drd/cli/monitor/output_monitor.py @@ -4,6 +4,7 @@ import time import threading from collections import deque +from .state import ServerState import re logger = logging.getLogger(__name__) @@ -39,6 +40,9 @@ def _monitor_output(self): logger.info("Starting output monitoring") while not self.stop_event.is_set(): + if self.monitor.get_state() != ServerState.NORMAL: + time.sleep(0.1) # Small sleep when not in normal state + continue if self.monitor.process is None or self.monitor.process.poll() is not None: logger.info( "Server process ended or not started. Waiting for restart...") @@ -91,8 +95,8 @@ def _check_for_errors(self, line): return False def _handle_error(self, error_context): + self.monitor.set_state(ServerState.ERROR_DETECTED) full_error = '\n'.join(error_context) - logger.error(f"Full error context:\n{full_error}") self.monitor.handle_error(full_error) def _check_idle_state(self): diff --git a/src/drd/cli/monitor/server_monitor.py b/src/drd/cli/monitor/server_monitor.py index 97cdffb..d845f67 100644 --- a/src/drd/cli/monitor/server_monitor.py +++ b/src/drd/cli/monitor/server_monitor.py @@ -2,12 +2,15 @@ import time import re import threading +from threading import Lock import subprocess from queue import Queue from .input_handler import InputHandler from .output_monitor import OutputMonitor from ...utils import print_info, print_success, print_error, print_header, print_prompt, print_warning from ...metadata.project_metadata import ProjectMetadataManager +from ...utils.input import get_user_confirmation +from .state import ServerState import logging @@ -21,10 +24,13 @@ def __init__(self, project_dir: str, error_handlers: dict, command: str): self.MAX_RETRIES = 3 self.error_handlers = error_handlers self.command = command + self.error_context = "" + self.error_handler = None self.process = None self.should_stop = threading.Event() self.restart_requested = threading.Event() self.processing_input = threading.Event() + self.skip_input = False self.input_handler = InputHandler(self) self.output_monitor = OutputMonitor(self) self.retry_count = 0 @@ -35,6 +41,8 @@ def __init__(self, project_dir: str, error_handlers: dict, command: str): } self.error_handlers['default'] = self.default_error_handler + self.state = ServerState.NORMAL + self.state_lock = Lock() logger.info( f"Initialized error handlers: {list(self.error_handlers.keys())}") @@ -112,23 +120,45 @@ def start_process(self): def _main_loop(self): try: while not self.should_stop.is_set(): - if self.error_handling_in_progress.is_set(): + current_state = self.get_state() + if current_state == ServerState.NORMAL: + if self.output_monitor.idle_detected.is_set(): + self.input_handler.handle_input() + self.output_monitor.idle_detected.clear() + elif current_state == ServerState.ERROR_DETECTED: + pass + elif current_state == ServerState.ERROR_HANDLING: # Wait for error handling to complete - self.error_handling_in_progress.wait() - elif self.output_monitor.idle_detected.is_set(): - self.input_handler.handle_input() - self.output_monitor.idle_detected.clear() - else: - # Small sleep to prevent busy waiting - self.should_stop.wait(timeout=0.1) + pass + elif current_state == ServerState.FIX_APPLYING: + # Wait for fix to be applied + pass + + # Small sleep to prevent busy waiting + self.should_stop.wait(timeout=0.1) except KeyboardInterrupt: - print_info("Stopping server...") + logger.info("Stopping server...") finally: self.stop() + def resume_error_handling(self, user_input, skip=False): + print("user_input", user_input, self.get_state()) + if user_input.lower() == 'y' and self.get_state() == ServerState.ERROR_HANDLING: + self.set_state(ServerState.FIX_APPLYING) + self.error_handler(self.error_context, self) + logger.info("CLEANING UP....") + self.clean_handlers() + self.skip_input = skip + + def clean_handlers(self): + if not self.skip_input: + self.error_context = "" + self.error_handler = None + self.set_state(ServerState.NORMAL) + def handle_error(self, error_context): logger.info("Entering handle_error method") - self.error_handling_in_progress.set() + self.set_state(ServerState.ERROR_HANDLING) self.output_monitor.idle_detected.clear() # print_warning("An error has been detected. Here's the context:") @@ -141,7 +171,14 @@ def handle_error(self, error_context): logger.info(f"Checking error pattern: {pattern}") if re.search(pattern, error_context, re.IGNORECASE): logger.info(f"Matched error pattern: {pattern}") - handler(error_context, self) + self.error_context = error_context + self.error_handler = handler + if not get_user_confirmation("Do you want to proceed with the fix from Dravid?"): + print("inside the confirmation...", self.get_state()) + self.clean_handlers() + return True + + self.resume_error_handling('y') break else: logger.warning( @@ -153,7 +190,7 @@ def handle_error(self, error_context): logger.error(f"Error during error handling: {str(e)}") print_error(f"Failed to handle the error: {str(e)}") - self.error_handling_in_progress.clear() + self.clean_handlers() logger.info("Exiting handle_error method") def default_error_handler(self, error_context, monitor): @@ -166,3 +203,12 @@ def default_error_handler(self, error_context, monitor): def request_restart(self): self.restart_requested.set() + + def set_state(self, new_state: ServerState): + with self.state_lock: + self.state = new_state + logger.info(f"Server state changed to: {self.state.name}") + + def get_state(self) -> ServerState: + with self.state_lock: + return self.state diff --git a/src/drd/cli/monitor/state.py b/src/drd/cli/monitor/state.py new file mode 100644 index 0000000..ce48e50 --- /dev/null +++ b/src/drd/cli/monitor/state.py @@ -0,0 +1,8 @@ +import enum + + +class ServerState(enum.Enum): + NORMAL = 1 + ERROR_DETECTED = 2 + ERROR_HANDLING = 3 + FIX_APPLYING = 4 diff --git a/src/drd/cli/query/file_operations.py b/src/drd/cli/query/file_operations.py index 563007e..7c500eb 100644 --- a/src/drd/cli/query/file_operations.py +++ b/src/drd/cli/query/file_operations.py @@ -7,6 +7,9 @@ def get_files_to_modify(query, project_context): + print("The query") + print(query) + print("======") file_query = get_files_to_modify_prompt(query, project_context) response = call_dravid_api_with_pagination( file_query, include_context=True) diff --git a/src/drd/utils/input.py b/src/drd/utils/input.py index d1e8ee1..b785afa 100644 --- a/src/drd/utils/input.py +++ b/src/drd/utils/input.py @@ -7,13 +7,19 @@ def confirm_with_user(msg): return click.confirm(f"{Fore.YELLOW} {msg} {Style.RESET_ALL}", default=False) -def get_user_confirmation(prompt): - sys.stdout.write(f"{prompt} [y/N] ") - sys.stdout.flush() - user_input = input().strip().lower() - if user_input in ['y', 'yes']: - return True - elif user_input in ['', 'n', 'no']: - return False - else: - print("Invalid input. Please enter 'y' for yes or 'n' for no.") +def get_user_confirmation(msg): + return click.confirm(f"{Fore.YELLOW} {msg} {Style.RESET_ALL}", default=False) + + +def get_input_with_timeout(msg, timeout): + try: + return click.prompt( + msg, + type=str, + default='', + show_default=False, + prompt_suffix='' + ) + except click.exceptions.Abort: + print("\nInput timeout reached.") + return None