Skip to content

Commit

Permalink
need to fix interferene
Browse files Browse the repository at this point in the history
  • Loading branch information
vysakh0 committed Aug 11, 2024
1 parent f833b14 commit 75695fe
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 57 deletions.
5 changes: 5 additions & 0 deletions src/drd/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
filemode='a')

# Also log to console
# console = logging.StreamHandler()
# console.setLevel(logging.INFO)
# formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
# console.setFormatter(formatter)
# logging.getLogger('').addHandler(console)


def parse_multiline_input(input_string):
Expand Down
21 changes: 11 additions & 10 deletions src/drd/cli/monitor/error_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,31 @@
from ...prompts.monitor_error_resolution import get_error_resolution_prompt
from ..query.file_operations import get_files_to_modify
from ...utils.file_utils import get_file_content
from ...utils.input import get_user_confirmation
import logging


def monitoring_handle_error_with_dravid(error, error_trace, monitor):
if not get_user_confirmation("Do you want to proceed with the fix from Dravid?"):
return True

print_error(f"Error detected: {error}")
logger = logging.getLogger(__name__)
logger.info(f"Starting error handling for: {error}")

# error_message = str(error)
# error_type = type(error).__name__
# error_trace = ''.join(traceback.format_exception(
# type(error), error, error.__traceback__))

print("the type is *******")
print(error)
print("++++000-----")
project_context = monitor.metadata_manager.get_project_context()

print_info("Identifying relevant files for error context...")
error_details = """
There is an error in the project. Strictly suggest only the files needed to fix the error.
error_trace: {error_trace}
error_details = f"""
There is an error in the project. Identify ony the files related to it
error_trace: {error}
"""
input("testing >")
input("testing1 >")
return True

files_to_check = run_with_loader(
lambda: get_files_to_modify(error_details, project_context),
Expand Down Expand Up @@ -85,10 +86,10 @@ def monitoring_handle_error_with_dravid(error, error_trace, monitor):
logger.info(f"User response to restart: ")
if requires_restart:
print_info("The applied fix requires a server restart.")
restart_input = confirm_with_user(
restart_input = input(
"Do you want to restart the server now? [y/N]: "
)
if restart_input:
if restart_input.lower() == 'y':
print_info("Requesting server restart...")
monitor.perform_restart()
else:
Expand Down
15 changes: 13 additions & 2 deletions src/drd/cli/monitor/input_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from ...utils import print_info, print_error
from ...prompts.instructions import get_instruction_prompt
from .input_parser import InputParser
from ...utils.input import get_input_with_timeout
from ..query.main import execute_dravid_command
import logging
from .state import ServerState


class InputHandler:
Expand All @@ -15,10 +17,14 @@ def __init__(self, monitor):
self.logger = logging.getLogger(__name__)

def handle_input(self):
if self.monitor.get_state() != ServerState.NORMAL:
self.logger.info(
"Input handling skipped: Server not in NORMAL state")
return
self.logger.info("InputHandler triggered to handle input")
print_info("\nNo more tasks to auto-process. What can I do next?")
self._show_options()
user_input = input("> ")
user_input = input("")
self.logger.info(f"Received user input: {user_input}")
self._process_input(user_input)

Expand All @@ -30,8 +36,13 @@ def _show_options(self):
print_info("\nType your choice or command:")

def _process_input(self, user_input):
state = self.monitor.get_state()
if state == ServerState.ERROR_HANDLING:
print("state from error handling interfering", state)
self.monitor.resume_error_handling(user_input, skip=True)
print("\n")
return True
self.monitor.processing_input.set()
print("processing your input")
try:
if user_input.lower() == 'exit':
confirm_exit = input(
Expand Down
32 changes: 10 additions & 22 deletions src/drd/cli/monitor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,19 @@ def run_dev_server_with_monitoring(command: str):


def handle_module_not_found(error_msg, monitor):
monitor.error_handling_in_progress.set()
try:
match = re.search(
r"(?:Cannot find module|Module not found|ImportError|No module named).*['\"](.*?)['\"]", error_msg, re.IGNORECASE)
if match:
module_name = match.group(1)
error = ImportError(f"Module '{module_name}' not found")
monitoring_handle_error_with_dravid(error, error_msg, monitor)
finally:
monitor.error_handling_in_progress.clear()
match = re.search(
r"(?:Cannot find module|Module not found|ImportError|No module named).*['\"](.*?)['\"]", error_msg, re.IGNORECASE)
if match:
module_name = match.group(1)
error = ImportError(f"Module '{module_name}' not found")
monitoring_handle_error_with_dravid(error, error_msg, monitor)


def handle_syntax_error(error_msg, monitor):
monitor.error_handling_in_progress.set()
try:
error = SyntaxError(f"Syntax error detected: {error_msg}")
monitoring_handle_error_with_dravid(error, error_msg, monitor)
finally:
monitor.error_handling_in_progress.clear()
error = SyntaxError(f"Syntax error detected: {error_msg}")
monitoring_handle_error_with_dravid(error, error_msg, monitor)


def handle_general_error(error_msg, monitor):
monitor.error_handling_in_progress.set()
try:
error = Exception(f"General error detected: {error_msg}")
monitoring_handle_error_with_dravid(error, error_msg, monitor)
finally:
monitor.error_handling_in_progress.clear()
error = Exception(f"General error detected: {error_msg}")
monitoring_handle_error_with_dravid(error, error_msg, monitor)
6 changes: 5 additions & 1 deletion src/drd/cli/monitor/output_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import threading
from collections import deque
from .state import ServerState
import re

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -39,6 +40,9 @@ def _monitor_output(self):

logger.info("Starting output monitoring")
while not self.stop_event.is_set():
if self.monitor.get_state() != ServerState.NORMAL:
time.sleep(0.1) # Small sleep when not in normal state
continue
if self.monitor.process is None or self.monitor.process.poll() is not None:
logger.info(
"Server process ended or not started. Waiting for restart...")
Expand Down Expand Up @@ -91,8 +95,8 @@ def _check_for_errors(self, line):
return False

def _handle_error(self, error_context):
self.monitor.set_state(ServerState.ERROR_DETECTED)
full_error = '\n'.join(error_context)
logger.error(f"Full error context:\n{full_error}")
self.monitor.handle_error(full_error)

def _check_idle_state(self):
Expand Down
70 changes: 58 additions & 12 deletions src/drd/cli/monitor/server_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import time
import re
import threading
from threading import Lock
import subprocess
from queue import Queue
from .input_handler import InputHandler
from .output_monitor import OutputMonitor
from ...utils import print_info, print_success, print_error, print_header, print_prompt, print_warning
from ...metadata.project_metadata import ProjectMetadataManager
from ...utils.input import get_user_confirmation
from .state import ServerState

import logging

Expand All @@ -21,10 +24,13 @@ def __init__(self, project_dir: str, error_handlers: dict, command: str):
self.MAX_RETRIES = 3
self.error_handlers = error_handlers
self.command = command
self.error_context = ""
self.error_handler = None
self.process = None
self.should_stop = threading.Event()
self.restart_requested = threading.Event()
self.processing_input = threading.Event()
self.skip_input = False
self.input_handler = InputHandler(self)
self.output_monitor = OutputMonitor(self)
self.retry_count = 0
Expand All @@ -35,6 +41,8 @@ def __init__(self, project_dir: str, error_handlers: dict, command: str):

}
self.error_handlers['default'] = self.default_error_handler
self.state = ServerState.NORMAL
self.state_lock = Lock()
logger.info(
f"Initialized error handlers: {list(self.error_handlers.keys())}")

Expand Down Expand Up @@ -112,23 +120,45 @@ def start_process(self):
def _main_loop(self):
try:
while not self.should_stop.is_set():
if self.error_handling_in_progress.is_set():
current_state = self.get_state()
if current_state == ServerState.NORMAL:
if self.output_monitor.idle_detected.is_set():
self.input_handler.handle_input()
self.output_monitor.idle_detected.clear()
elif current_state == ServerState.ERROR_DETECTED:
pass
elif current_state == ServerState.ERROR_HANDLING:
# Wait for error handling to complete
self.error_handling_in_progress.wait()
elif self.output_monitor.idle_detected.is_set():
self.input_handler.handle_input()
self.output_monitor.idle_detected.clear()
else:
# Small sleep to prevent busy waiting
self.should_stop.wait(timeout=0.1)
pass
elif current_state == ServerState.FIX_APPLYING:
# Wait for fix to be applied
pass

# Small sleep to prevent busy waiting
self.should_stop.wait(timeout=0.1)
except KeyboardInterrupt:
print_info("Stopping server...")
logger.info("Stopping server...")
finally:
self.stop()

def resume_error_handling(self, user_input, skip=False):
print("user_input", user_input, self.get_state())
if user_input.lower() == 'y' and self.get_state() == ServerState.ERROR_HANDLING:
self.set_state(ServerState.FIX_APPLYING)
self.error_handler(self.error_context, self)
logger.info("CLEANING UP....")
self.clean_handlers()
self.skip_input = skip

def clean_handlers(self):
if not self.skip_input:
self.error_context = ""
self.error_handler = None
self.set_state(ServerState.NORMAL)

def handle_error(self, error_context):
logger.info("Entering handle_error method")
self.error_handling_in_progress.set()
self.set_state(ServerState.ERROR_HANDLING)
self.output_monitor.idle_detected.clear()

# print_warning("An error has been detected. Here's the context:")
Expand All @@ -141,7 +171,14 @@ def handle_error(self, error_context):
logger.info(f"Checking error pattern: {pattern}")
if re.search(pattern, error_context, re.IGNORECASE):
logger.info(f"Matched error pattern: {pattern}")
handler(error_context, self)
self.error_context = error_context
self.error_handler = handler
if not get_user_confirmation("Do you want to proceed with the fix from Dravid?"):
print("inside the confirmation...", self.get_state())
self.clean_handlers()
return True

self.resume_error_handling('y')
break
else:
logger.warning(
Expand All @@ -153,7 +190,7 @@ def handle_error(self, error_context):
logger.error(f"Error during error handling: {str(e)}")
print_error(f"Failed to handle the error: {str(e)}")

self.error_handling_in_progress.clear()
self.clean_handlers()
logger.info("Exiting handle_error method")

def default_error_handler(self, error_context, monitor):
Expand All @@ -166,3 +203,12 @@ def default_error_handler(self, error_context, monitor):

def request_restart(self):
self.restart_requested.set()

def set_state(self, new_state: ServerState):
with self.state_lock:
self.state = new_state
logger.info(f"Server state changed to: {self.state.name}")

def get_state(self) -> ServerState:
with self.state_lock:
return self.state
8 changes: 8 additions & 0 deletions src/drd/cli/monitor/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import enum


class ServerState(enum.Enum):
NORMAL = 1
ERROR_DETECTED = 2
ERROR_HANDLING = 3
FIX_APPLYING = 4
3 changes: 3 additions & 0 deletions src/drd/cli/query/file_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@


def get_files_to_modify(query, project_context):
print("The query")
print(query)
print("======")
file_query = get_files_to_modify_prompt(query, project_context)
response = call_dravid_api_with_pagination(
file_query, include_context=True)
Expand Down
26 changes: 16 additions & 10 deletions src/drd/utils/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ def confirm_with_user(msg):
return click.confirm(f"{Fore.YELLOW} {msg} {Style.RESET_ALL}", default=False)


def get_user_confirmation(prompt):
sys.stdout.write(f"{prompt} [y/N] ")
sys.stdout.flush()
user_input = input().strip().lower()
if user_input in ['y', 'yes']:
return True
elif user_input in ['', 'n', 'no']:
return False
else:
print("Invalid input. Please enter 'y' for yes or 'n' for no.")
def get_user_confirmation(msg):
return click.confirm(f"{Fore.YELLOW} {msg} {Style.RESET_ALL}", default=False)


def get_input_with_timeout(msg, timeout):
try:
return click.prompt(
msg,
type=str,
default='',
show_default=False,
prompt_suffix=''
)
except click.exceptions.Abort:
print("\nInput timeout reached.")
return None

0 comments on commit 75695fe

Please sign in to comment.