From df957cd97859806758c5ebf66b465b48312902a8 Mon Sep 17 00:00:00 2001
From: daniel nakov
Date: Mon, 21 Oct 2024 15:15:30 -0400
Subject: [PATCH] UI fixes and cleanup

---
 r2ai/repl.py            |   2 +-
 r2ai/ui/app.py          |  75 +++++++++--------
 r2ai/ui/app.tcss        |  42 +++++++---
 r2ai/ui/chat.py         | 174 ++++++----------------------------
 r2ai/ui/db.py           |  21 -----
 r2ai/ui/model_select.py |  47 +++++++----
 6 files changed, 124 insertions(+), 237 deletions(-)
 delete mode 100644 r2ai/ui/db.py

diff --git a/r2ai/repl.py b/r2ai/repl.py
index 2372164..a1a5c7e 100644
--- a/r2ai/repl.py
+++ b/r2ai/repl.py
@@ -197,7 +197,7 @@ def runline(ai, usertext):
         traceback.print_exc()
     if usertext.startswith("-VV"):
         from ui.app import R2AIApp # pylint: disable=import-error
-        R2AIApp().run()
+        R2AIApp(ai=ai).run()
         return
     if usertext.startswith("?V") or usertext.startswith("-v"):
         r2ai_version()
diff --git a/r2ai/ui/app.py b/r2ai/ui/app.py
index 9201a61..7a99b5d 100644
--- a/r2ai/ui/app.py
+++ b/r2ai/ui/app.py
@@ -2,7 +2,7 @@
 from textual.containers import ScrollableContainer, Container, Horizontal, VerticalScroll, Grid, Vertical # Add Vertical to imports
 from textual.widgets import Header, Footer, Input, Button, Static, DirectoryTree, Label, Tree, Markdown
 from textual.command import CommandPalette, Command, Provider, Hits, Hit
-from textual.screen import Screen, ModalScreen
+from textual.screen import Screen, ModalScreen, SystemModalScreen
 from textual.message import Message
 from textual.reactive import reactive
 from .model_select import ModelSelect
@@ -18,23 +18,11 @@
 from markdown_it import MarkdownIt
 # from ..repl import set_model, r2ai_singleton
 # ai = r2ai_singleton()
-from .chat import chat, messages
+from .chat import chat
 import asyncio
-from .db import get_env
 import json
 
-class ModelSelectProvider(Provider):
-    async def search(self, query: str) -> Hits:
-        yield Hit("Select Model", "Select Model", self.action_select_model)
-
-class ModelSelectDialog(ModalScreen):
-    def compose(self) -> ComposeResult:
-        yield Grid(ModelSelect(), id="model-select-dialog")
-
-    def on_model_select_model_selected(self, event: ModelSelect.ModelSelected) -> None:
-        self.dismiss(event.model)
-
-class ModelConfigDialog(ModalScreen):
+class ModelConfigDialog(SystemModalScreen):
     def __init__(self, keys: list[str]) -> None:
         super().__init__()
         self.keys = keys
@@ -109,12 +97,19 @@ class R2AIApp(App):
     SUB_TITLE = None
 
     def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__()
+        if 'ai' in kwargs:
+            self.ai = kwargs['ai']
+        else:
+            class FakeAI:
+                model = 'gpt-4o'
+                messages = []
+            self.ai = FakeAI()
         self.update_sub_title(get_filename())
 
     def update_sub_title(self, binary: str = None) -> str:
         sub_title = None
-        model = get_env('model')
+        model = self.ai.model
         if binary and model:
             binary = Path(binary).name
             sub_title = f"{model} | {binary}"
@@ -145,6 +140,7 @@ def compose(self) -> ComposeResult:
 
     def on_mount(self) -> None:
         self.install_screen(CommandPalette(), name="command_palette")
+        self.query_one("#chat-input", Input).focus()
         # self.install_screen(BinarySelectDialog(), name="binary_select_dialog")
 
     def action_show_command_palette(self) -> None:
@@ -152,10 +148,11 @@ def action_show_command_palette(self) -> None:
 
 
     async def select_model(self) -> None:
-        model = await self.push_screen_wait(ModelSelectDialog())
+        model = await self.push_screen_wait(ModelSelect())
         if model:
             await self.validate_model()
-            self.notify(f"Selected model: {get_env('model')}")
+            self.ai.model = model
+            self.notify(f"Selected model: {self.ai.model}")
             self.update_sub_title()
 
     @work
@@ -209,15 +206,15 @@ async def send_message(self) -> None:
             input_widget.value = ""
             try:
                 await self.validate_model()
-                await chat(message, self.on_message)
+                await chat(self.ai, message, self.on_message)
             except Exception as e:
                 self.notify(str(e), severity="error")
 
     async def validate_model(self) -> None:
-        model = get_env("model")
+        model = self.ai.model
         if not model:
            await self.select_model()
-           model = get_env("model")
+           model = self.ai.model
        keys = validate_environment(model)
        if keys['keys_in_environment'] is False:
            await self.push_screen_wait(ModelConfigDialog(keys['missing_keys']))
@@ -229,10 +226,12 @@ def add_message(self, id: str, sender: str, content: str) -> None:
         try:
             msg = ChatMessage(id, sender, content)
             chat_container.mount(msg)
+            return msg
         except Exception as e:
             pass
-        self.scroll_to_bottom()
-        return msg
+        finally:
+            self.scroll_to_bottom()
+
 
     def scroll_to_bottom(self) -> None:
         chat_scroll = self.query_one("#chat-container", VerticalScroll)
@@ -252,8 +251,14 @@ def compose(self) -> ComposeResult:
         for message in self.messages:
             yield Message(message)
 
+class FilteredDirectoryTree(DirectoryTree):
+    input_path: reactive[Path] = reactive(Path.home())
+
+    def filter_paths(self, paths: Iterable[Path]) -> Iterable[Path]:
+        ps = [path for path in paths if str(path).startswith(str(self.input_path))]
+        return ps
 
-class BinarySelectDialog(ModalScreen):
+class BinarySelectDialog(SystemModalScreen):
     BINDINGS = [
         ("up", "cursor_up", "Move cursor up"),
         ("down", "cursor_down", "Move cursor down"),
@@ -263,27 +268,26 @@ class BinarySelectDialog(ModalScreen):
     ]
 
     def compose(self) -> ComposeResult:
-        yield Grid(
-            Vertical(
-                Input(placeholder="Enter path here...", id="path-input"),
-                DirectoryTree(Path.home(), id="file-browser"),
-            ),
-            id="binary-select-dialog"
-        )
+        with Vertical():
+            yield Input(placeholder="Enter path here...", id="path-input")
+            yield FilteredDirectoryTree(Path.home(), id="file-browser")
 
     def on_mount(self) -> None:
         self.path_input = self.query_one("#path-input", Input)
         self.file_browser = self.query_one("#file-browser", DirectoryTree)
-        self.set_focus(self.file_browser)
+        self.path_input.value = str(get_filename() or Path.home())
+        self.path_input.focus()
         self.watch(self.path_input, "value", self.update_tree)
 
     @work(thread=True)
     def update_tree(self) -> None:
         path = Path(self.path_input.value)
+
         if path.exists():
             self.file_browser.path = str(path)
         elif path.parent.exists():
             self.file_browser.path = str(path.parent)
+        self.file_browser.input_path = str(path)
 
     def on_button_pressed(self, event: Button.Pressed) -> None:
         if event.button.id == "up-button":
@@ -316,6 +320,7 @@ def action_cursor_up(self) -> None:
         self.file_browser.action_cursor_up()
 
     def action_cursor_down(self) -> None:
+        self.file_browser.focus()
         self.file_browser.action_cursor_down()
 
@@ -324,5 +329,5 @@ def action_select(self) -> None:
         node = self.file_browser.cursor_node
         if node and node.data:
             self.open_and_analyze_binary(str(node.data.path))
             self.dismiss(str(node.data.path))
 
-app = R2AIApp()
-app.run()
+# app = R2AIApp()
+# app.run()
diff --git a/r2ai/ui/app.tcss b/r2ai/ui/app.tcss
index a14f867..a850568 100644
--- a/r2ai/ui/app.tcss
+++ b/r2ai/ui/app.tcss
@@ -7,9 +7,9 @@ Placeholder {
     width: 15;
     height: 100;
 }
-
 Screen {
     background: $surface;
+    layers: base overlay;
 }
 
 #content {
@@ -27,9 +27,11 @@ Screen {
     padding: 1;
     overflow-y: scroll;
     layout: vertical;
+    layer: base;
 }
 
 #input-area {
+    layer: base;
     height: auto;
     margin-top: 1;
 }
@@ -54,13 +56,6 @@ Screen {
 /* ... existing styles ... */
 
-#binary-select-dialog {
-    align: center middle;
-    width: 100%;
-    height: 100%;
-    background: $surface;
-    border: solid $primary;
-}
 
 #binary-select-dialog Label {
     padding: 1 2;
@@ -81,11 +76,6 @@ Screen {
     margin-left: 1;
 }
 
-#file-browser {
-    height: 100%;
-    width: 100%;
-    align: center middle;
-}
 
 .chat-message-container {
     height: auto;
@@ -110,4 +100,30 @@ Static.label_sender {
     overflow: auto;
     padding-bottom: 1;
     border: solid white;
 }
+
+ModelSelect {
+    display: none;
+    align: center middle;
+    Container {
+        margin: 0;
+
+        width: 60%;
+        height: 50%;
+        overflow: hidden;
+        display: block;
+        OptionList {
+            margin-bottom: 3;
+        }
+    }
+}
+
+BinarySelectDialog {
+    display: none;
+    align: center middle;
+    Vertical {
+        background: transparent;
+        height: 50%;
+        width: 50%;
+    }
+}
\ No newline at end of file
diff --git a/r2ai/ui/chat.py b/r2ai/ui/chat.py
index 4baa385..647a357 100644
--- a/r2ai/ui/chat.py
+++ b/r2ai/ui/chat.py
@@ -1,157 +1,31 @@
 from litellm import acompletion, ChatCompletionAssistantToolCall, ChatCompletionToolCallFunctionChunk
 import asyncio
-from .db import get_env
 from r2ai.pipe import get_r2_inst
-from .r2cmd import r2cmd
-import json
-from r2ai.repl import r2ai_singleton
+from r2ai.auto import ChatAuto, SYSTEM_PROMPT_AUTO
+import signal
+from r2ai.tools import run_python, r2cmd
 
-SYSTEM_PROMPT_AUTO = """
-You are a reverse engineer and you are using radare2 to analyze a binary.
-The binary has already been loaded.
-The user will ask questions about the binary and you will respond with the answer to the best of your ability.
-Assume the user is always asking you about the binary, unless they're specifically asking you for radare2 help.
-`this` or `here` might refer to the current address in the binary or the binary itself.
-If you need more information, try to use the r2cmd tool to run commands before answering.
-You can use the r2cmd tool multiple times if you need or you can pass a command with pipes if you need to chain commands.
-If you're asked to decompile a function, make sure to return the code in the language you think it was originally written and rewrite it to be as easy as possible to be understood. Make sure you use descriptive variable and function names and add comments.
-Don't just regurgitate the same code, figure out what it's doing and rewrite it to be more understandable.
-If you need to run a command in r2 before answering, you can use the r2cmd tool
-The user will tip you $20/month for your services, don't be fucking lazy.
-Do not repeat commands if you already know the answer.
-"""
+def signal_handler(signum, frame):
+    raise KeyboardInterrupt
 
-def run_python(command: str):
-    """runs a python script and returns the results"""
-    with open('r2ai_tmp.py', 'w') as f:
-        f.write(command)
-    # builtins.print('\x1b[1;32mRunning \x1b[4m' + "python code" + '\x1b[0m')
-    # builtins.print(command)
-    r2 = get_r2_inst()
-    r2.cmd('#!python r2ai_tmp.py > $tmp')
-    res = r2.cmd('cat $tmp')
-    r2.cmd('rm r2ai_tmp.py')
-    # builtins.print('\x1b[1;32mResult\x1b[0m\n' + res)
-    return res
+async def chat(ai, message, cb):
+    model = ai.model.replace(":", "/")
+    tools = [r2cmd, run_python]
+    messages = ai.messages + [{"role": "user", "content": message}]
+    tool_choice = 'auto'
 
-tools = [{
-    "type": "function",
-    "function": {
-        "name": "r2cmd",
-        "description": "runs commands in radare2. You can run it multiple times or chain commands with pipes/semicolons. You can also use r2 interpreters to run scripts using the `#`, '#!', etc. commands. The output could be long, so try to use filters if possible or limit. This is your preferred tool.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "command": {
-                    "type": "string",
-                    "description": "command to run in radare2"
-                }
-            },
-            "required": ["command"]
-        },
-    }
-}, {
-    "type": "function",
-    "function": {
-        "name": "run_python",
-        "description": "runs a python script and returns the results",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "command": {
-                    "type": "string",
-                    "description": "python script to run"
-                }
-            },
-            "required": ["command"]
-        }
-    }
-}]
-
-messages = [{"role": "system", "content": SYSTEM_PROMPT_AUTO}]
-tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
-
-async def process_tool_calls(tool_calls, cb):
-    if tool_calls:
-        for tool_call in tool_calls:
-            tool_name = tool_call["function"]["name"]
-            tool_args = json.loads(tool_call["function"]["arguments"])
-            if cb:
-                cb('tool_call', { "id": tool_call["id"], "function": { "name": tool_name, "arguments": tool_args } })
-            if tool_name == "r2cmd":
-                res = r2cmd(tool_args["command"])
-                messages.append({"role": "tool", "name": tool_name, "content": res['output'] + tool_end_message, "tool_call_id": tool_call["id"]})
-                if cb:
-                    cb('tool_response', { "id": tool_call["id"] + '_response', "content": res['output'] })
-            elif tool_name == "run_python":
-                res = run_python(tool_args["command"])
-                messages.append({"role": "tool", "name": tool_name, "content": res + tool_end_message, "tool_call_id": tool_call["id"]})
-                if cb:
-                    cb('tool_response', { "id": tool_call["id"] + '_response', "content": res })
-
-    return await get_completion(cb)
+    chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
+
+    original_handler = signal.getsignal(signal.SIGINT)
 
-async def process_streaming_response(resp, cb):
-    tool_calls = []
-    msgs = []
-    async for chunk in resp:
-        delta = None
-        choice = chunk.choices[0]
-        delta = choice.delta
-        if delta.tool_calls:
-            delta_tool_calls = delta.tool_calls[0]
-            index = delta_tool_calls.index
-            fn_delta = delta_tool_calls.function
-            tool_call_id = delta_tool_calls.id
-            if len(tool_calls) < index + 1:
-                tool_calls.append({
-                    "id": tool_call_id,
-                    "type": "function",
-                    "function": {
-                        "name":fn_delta.name,
-                        "arguments": fn_delta.arguments
-                    }
-                }
-                )
-            # handle some bug in llama-cpp-python streaming, tool_call.arguments is sometimes blank, but function_call has it.
-            # if fn_delta.arguments == '':
-            tool_calls[index]["function"]["arguments"] += fn_delta.arguments
-            # else:
-            #     tool_calls[index]["function"]["arguments"] += fn_delta.arguments
-        else:
-            m = None
-            if delta.content is not None:
-                m = delta.content
-            if m is not None:
-                msgs.append(m)
-                if cb:
-                    cb('message', { "content": m, "id": 'message_' + chunk.id, 'done': False })
-        if 'finish_reason' in choice and choice['finish_reason'] is 'stop':
-            if cb:
-                cb('message', { "content": "", "id": 'message_' + chunk.id, 'done': True })
-    if (len(tool_calls) > 0):
-        messages.append({"role": "assistant", "tool_calls": tool_calls})
-        await process_tool_calls(tool_calls, cb)
-    if len(msgs) > 0:
-        response_message = ''.join(msgs)
-        messages.append({"role": "assistant", "content": response_message})
-        return response_message
-
-async def get_completion(cb):
-    response = await acompletion(
-        model=get_env("model"),
-        messages=messages,
-        max_tokens=1024,
-        temperature=0.5,
-        tools=tools,
-        tool_choice="auto",
-        stream=True
-    )
-    return await process_streaming_response(response, cb)
-
-
-async def chat(message: str, cb) -> str:
-    messages.append({"role": "user", "content": message})
-    if not get_env("model"):
-        raise Exception("No model selected")
-    response = await get_completion(cb)
-    return response
+    try:
+        signal.signal(signal.SIGINT, signal_handler)
+        return await chat_auto.chat()
+    except KeyboardInterrupt:
+        tasks = asyncio.all_tasks()
+        for task in tasks:
+            task.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        return None
+    finally:
+        signal.signal(signal.SIGINT, original_handler)
\ No newline at end of file
diff --git a/r2ai/ui/db.py b/r2ai/ui/db.py
deleted file mode 100644
index 978039c..0000000
--- a/r2ai/ui/db.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import dbm
-
-db = dbm.open(".r2ai.db", "c")
-
-def set_env(k, v):
-    db[k] = v.encode('utf-8') # Encode the value to bytes
-
-def get_env(k):
-    value = db.get(k)
-    if value is None:
-        return None
-    return value.decode('utf-8') # Decode bytes to string
-
-def close_db():
-    if db is not None:
-        db.close()
-
-# Ensure the database is closed when the module is unloaded
-import atexit
-atexit.register(close_db)
-
diff --git a/r2ai/ui/model_select.py b/r2ai/ui/model_select.py
index 0b6b60c..924b0ed 100644
--- a/r2ai/ui/model_select.py
+++ b/r2ai/ui/model_select.py
@@ -2,24 +2,34 @@
 from textual.widgets import Input, OptionList
 from textual.widget import Widget
 from textual.widgets.option_list import Option
+from textual.containers import Container
 from textual.message import Message
 from textual.binding import Binding
+from textual.screen import ModalScreen, SystemModalScreen
 from textual import log
 
 # from ..models import models
 # from ..repl import set_model, r2ai_singleton
 # ai = r2ai_singleton()
 # MODELS = models().split("\n")
-from litellm import model_list
-from .db import get_env, set_env
-MODELS = model_list
+from litellm import models_by_provider
+MODELS = []
+for provider in models_by_provider:
+    for model in models_by_provider[provider]:
+        MODELS.append(f"{provider}/{model}")
+
+class ModalInput(Input):
+    BINDINGS = [
+        Binding("down", "cursor_down", "Move down"),
+    ]
+
-class ModelSelect(Widget):
-    # BINDINGS = [
-    #     Binding("up", "cursor_up", "Move up"),
-    #     Binding("down", "cursor_down", "Move down"),
-    #     Binding("enter", "select", "Select model"),
-    # ]
+class ModelSelect(SystemModalScreen):
+    BINDINGS = [
+        Binding("up", "cursor_up", "Move up"),
+        Binding("down", "cursor_down", "Move down"),
+        Binding("enter", "select", "Select model"),
+        Binding("escape", "app.pop_screen", "Close"),
+    ]
 
     class ModelSelected(Message):
         """Event emitted when a model is selected."""
@@ -28,10 +38,12 @@ def __init__(self, model: str) -> None:
             super().__init__()
 
     def compose(self) -> ComposeResult:
-        self.input = Input(placeholder="Type to filter...")
+        self.input = ModalInput(placeholder="Type to filter...")
         self.option_list = OptionList()
-        yield self.input
-        yield self.option_list
+        with Container():
+            yield self.input
+            yield self.option_list
+
 
     def on_mount(self) -> None:
         self.options = []
@@ -42,8 +54,7 @@ def on_mount(self) -> None:
             self.options.append(Option(t, id=t))
         self.option_list.add_options(self.options)
         self.filtered_options = self.options.copy()
-
-        self.option_list.focus()
+        self.input.focus()
 
     def update_options(self, options):
         self.option_list.clear_options()
@@ -59,12 +70,14 @@ def action_cursor_up(self) -> None:
         self.option_list.action_cursor_up()
 
     def action_cursor_down(self) -> None:
-        self.option_list.action_cursor_down()
+        if self.option_list.has_focus:
+            self.option_list.action_cursor_down()
+        else:
+            self.option_list.focus()
 
     def on_option_list_option_selected(self, index) -> None:
         selected_index = index.option_index
         if 0 <= selected_index < len(self.filtered_options):
             selected_option = self.filtered_options[selected_index]
             if not selected_option.disabled:
-                set_env("model", selected_option.id)
-                self.post_message(self.ModelSelected(selected_option.id))
+                self.dismiss(selected_option.id)