[WIP] UI #47

Merged
merged 1 commit on Sep 11, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -18,7 +18,7 @@ r2ai = "r2ai:main"

[tool.setuptools.packages.find]
where = [".", "r2ai"]
include = ["main", "r2ai"]
include = ["main", "r2ai", "r2ai/ui"]
namespaces = true

[tool.setuptools.dynamic]
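A quick way to sanity-check this packaging change after an editable install (for example "pip install -e .") is to import the new subpackage; the module names below simply mirror this PR's layout.

```python
# Minimal check that the new UI subpackage is importable after installation.
import r2ai
import r2ai.ui

print(r2ai.__file__)
print(r2ai.ui.__file__)
```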
7 changes: 5 additions & 2 deletions r2ai/interpreter.py
@@ -797,16 +797,18 @@ def chat(self, message=None):

# If it was, we respond non-interactively
self.messages.append({"role": "user", "content": message})
response = None
try:
self.respond()
response = self.respond()
self.clear_hints()
except Exception:
if Ginterrupted:
Ginterrupted = False
else:
traceback.print_exc()
self.end_active_block()

return response

def end_active_block(self):
# if self.env["chat.code"] == "false":
# return
@@ -938,6 +940,7 @@ def respond(self):
if self.env["chat.reply"] == "true":
self.messages.append({"role": "assistant", "content": response})
print(response)
return response
else:
self.logger.warn("OpenAi python not found. Falling back to requests library")
response = openapi.chat(self.messages, self.api_base, openai_model, self.api_key)
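With respond() now returning the assistant text and chat() propagating it, callers can consume the reply programmatically instead of only reading stdout. A minimal usage sketch; the class name Interpreter, the no-argument constructor, and the model id are assumptions about r2ai/interpreter.py.

```python
# Hedged sketch: chat() returns the assistant's reply, or None if the call was
# interrupted or raised. Interpreter() and the model id below are assumptions.
from r2ai.interpreter import Interpreter

ai = Interpreter()
ai.model = "ollama/llama3"
reply = ai.chat("describe what the current function does")
if reply is not None:
    print(reply)
```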
1 change: 1 addition & 0 deletions r2ai/pipe.py
@@ -65,3 +65,4 @@ def get_r2_inst():
def open_r2(file, flags=[]):
global r2
r2 = r2pipe.open(file, flags=flags)
return r2
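Returning the pipe handle means callers can use the session directly instead of re-reading the module-level global; this is how the new UI analyzes a freshly opened binary. A minimal sketch (the target path is illustrative):

```python
# Sketch: open_r2() still stores the session in the global `r2`, but now also
# returns it. "/bin/ls" is just an example target.
from r2ai.pipe import open_r2

r2 = open_r2("/bin/ls")
r2.cmd("aaa")                 # full analysis, as BinarySelectDialog does after loading
print(r2.cmd("afl~main"))     # list analyzed functions matching "main"
```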
18 changes: 13 additions & 5 deletions r2ai/repl.py
@@ -82,6 +82,8 @@ def r2_cmd(x):
r2ai -v show r2ai version (same as ?V)
r2ai -w ([port]) start webserver (curl -D hello http://localhost:8000)
r2ai -W ([port]) start webserver in background
r2ai -VV visual mode
r2ai -V (num) set log level for this session
0: NOTSET, 1: DEBUG, 2: INFO,
3: WARNING, 4: ERROR, 5: CRITICAL
@@ -174,6 +176,12 @@ def slurp_until(endword):
text += line
return text

def set_model(model):
ai = ais[0].ai
ai.model = model
ai.env["llm.model"] = ai.model
set_default_model(ai.model)

def runline(ai, usertext):
global print
global autoai
@@ -197,6 +205,10 @@ def runline(ai, usertext):
return r2ai_vars(ai, usertext[2:].strip())
except Exception:
traceback.print_exc()
if usertext.startswith("-VV"):
from ui.app import R2AIApp
R2AIApp().run()
return
if usertext.startswith("?V") or usertext.startswith("-v"):
r2ai_version()
elif usertext.startswith("<<"):
@@ -224,11 +236,7 @@ def runline(ai, usertext):
elif usertext.startswith("-m"):
words = usertext.split(" ")
if len(words) > 1:
if ai.model is not words[1]:
ai.llama_instance = None
ai.model = words[1]
ai.env["llm.model"] = ai.model
set_default_model(ai.model)
set_model(words[1])
else:
print(ai.model)
elif usertext == "reset" or usertext.startswith("-R"):
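The new -VV handler boils down to importing the Textual app and running it; as a standalone sketch it is roughly the following (the dotted path r2ai.ui.app is an assumption based on the package layout, while runline() itself imports ui.app).

```python
# What "-VV" does, sketched outside the REPL. The import path is an assumption.
from r2ai.ui.app import R2AIApp

R2AIApp().run()
```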
Empty file added r2ai/ui/__init__.py
Empty file.
242 changes: 242 additions & 0 deletions r2ai/ui/app.py
@@ -0,0 +1,242 @@
from textual.app import App, ComposeResult, SystemCommand
from textual.containers import ScrollableContainer, Container, Horizontal, VerticalScroll, Grid, Vertical # Add Vertical to imports
from textual.widgets import Header, Footer, Input, Button, Static, DirectoryTree, Label, Tree, Markdown
from textual.command import CommandPalette, Command, Provider, Hits, Hit
from textual.screen import Screen
from textual.message import Message
from textual.timer import Timer # Add this import
from textual.reactive import reactive
from .model_select import ModelSelect
from r2ai.pipe import open_r2
from typing import Any, Iterable
import os
from pathlib import Path
from textual import work
from textual.widget import Widget
from textual.css.query import NoMatches

# from ..repl import set_model, r2ai_singleton
# ai = r2ai_singleton()
from .chat import chat, messages
import asyncio
from .db import get_env
r2 = None
class ModelSelectProvider(Provider):
async def search(self, query: str) -> Hits:
yield Hit("Select Model", "Select Model", self.action_select_model)


class ModelSelectDialog(Screen):
def compose(self) -> ComposeResult:
yield Grid(ModelSelect(), id="model-select-dialog")

def on_model_select_model_selected(self, event: ModelSelect.ModelSelected) -> None:
self.dismiss(event.model)


class ChatMessage(Markdown):
markdown = ""
    def __init__(self, id: str, sender: str, content: str) -> None:
self.markdown = f"*{sender}:* {content}"
super().__init__(id=id, markdown=self.markdown)
def add_text(self, markdown: str) -> None:
self.markdown += markdown
self.update(self.markdown)


class R2AIApp(App):
CSS_PATH = "app.tcss"
BINDINGS = [
("ctrl+p", "show_command_palette", "Command Palette"),
]
TITLE = "r2ai"

def compose(self) -> ComposeResult:
yield Header()
yield Container(
VerticalScroll(

id="chat-container",
),
ScrollableContainer(
Horizontal(
Input(placeholder="Type your message here...", id="chat-input"),
Button("Send", variant="primary", id="send-button"),
id="input-container",
),
id="input-area",
),
id="content",
)
yield Footer()

def on_mount(self) -> None:
self.install_screen(CommandPalette(), name="command_palette")
# self.install_screen(BinarySelectDialog(), name="binary_select_dialog")

def action_show_command_palette(self) -> None:
self.push_screen("command_palette")

def action_select_model(self) -> None:
        # push_screen() does not return the dialog's result; use the dismiss callback.
        def on_model_chosen(model) -> None:
            if model:
                self.notify(f"Selected model: {get_env('model')}")
        self.push_screen(ModelSelectDialog(), on_model_chosen)

def action_load_binary(self) -> None:
self.push_screen(BinarySelectDialog())

def get_system_commands(self, screen: Screen) -> Iterable[SystemCommand]:
yield from super().get_system_commands(screen)
yield SystemCommand("Models", "Select Model", self.action_select_model)
yield SystemCommand("Load Binary", "Load Binary", self.action_load_binary) # Add this command

    async def on_button_pressed(self, event: Button.Pressed) -> None:
        if event.button.id == "send-button":
            # send_message() is a coroutine, so it has to be awaited here as well.
            await self.send_message()
def on_model_select_model_selected(self, event: ModelSelect.ModelSelected) -> None:
self.notify(f"Selected model: {event.model}")

async def on_input_submitted(self, event: Input.Submitted) -> None:
await self.send_message()

    def on_message(self, type: str, message: Any) -> None:
if type == 'message':
existing = None
try:
existing = self.query_one(f"#{message['id']}")
except NoMatches:
existing = self.add_message(message["id"], "AI", "")
print(existing)
existing.add_text(message["content"])
elif type == 'tool_call':
self.add_message(message["id"], "AI", f"*Tool Call:* {message['function']['name']}")
elif type == 'tool_response':
self.add_message(message["id"], "AI", f"*Tool Response:* {message['content']}")

async def send_message(self) -> None:
input_widget = self.query_one("#chat-input", Input)
message = input_widget.value.strip()
if message:
self.add_message(None, "User", message)
input_widget.value = ""
# Process the message and get AI response
# await chat(message, self.on_message)
resp = chat(message)
self.add_message(None, "AI", resp)
# self.add_message("AI", response)

    def add_message(self, id: str, sender: str, content: str) -> ChatMessage:
chat_container = self.query_one("#chat-container", VerticalScroll)
msg = ChatMessage(id, sender, content)
chat_container.mount(msg)
self.scroll_to_bottom()
return msg

def scroll_to_bottom(self) -> None:
chat_scroll = self.query_one("#chat-container", VerticalScroll)
chat_scroll.scroll_end(animate=False)

class Message(Widget):
    def __init__(self, message) -> None:  # message objects carry .role and .content
super().__init__()
self.content = f'[bold]{message.role}[/] {message.content}'

    def render(self) -> str:
        # render() must return a renderable (here a Rich-markup string), not a widget.
        return self.content

class Messages(Container):
    def __init__(self, messages) -> None:
        super().__init__()
        self.messages = messages
def compose(self) -> ComposeResult:
for message in self.messages:
yield Message(message)


class BinarySelectDialog(Screen):
BINDINGS = [
("up", "cursor_up", "Move cursor up"),
("down", "cursor_down", "Move cursor down"),
("enter", "select", "Select item"),
("escape", "app.pop_screen", "Close"),
("backspace", "go_up", "Go up one level"), # Add this binding
]

def compose(self) -> ComposeResult:
yield Grid(
Vertical(
Horizontal(
Label("Enter path:"),
Button("⬆️", id="up-button", variant="primary"),
id="path-header"
),
Input(placeholder="Enter path here...", id="path-input"),
DirectoryTree(Path.home(), id="file-browser"),
),
id="binary-select-dialog"
)

def on_mount(self) -> None:
self.path_input = self.query_one("#path-input", Input)
self.file_browser = self.query_one("#file-browser", DirectoryTree)
self.up_button = self.query_one("#up-button", Button)
self.set_focus(self.file_browser)
self.watch(self.path_input, "value", self.update_tree)

@work(thread=True)
def update_tree(self) -> None:
path = Path(self.path_input.value)
if path.exists():
self.file_browser.path = str(path)
elif path.parent.exists():
self.file_browser.path = str(path.parent)

def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "up-button":
self.go_up()

def action_go_up(self) -> None:
current_path = Path(self.file_browser.path)
parent_path = current_path.parent
if parent_path != current_path:
self.file_browser.path = str(parent_path)
self.path_input.value = str(parent_path)

def on_directory_tree_file_selected(self, event: DirectoryTree.FileSelected) -> None:
self.path_input.value = str(event.path)
self.open_and_analyze_binary(str(event.path))
self.dismiss(str(event.path))

@work(thread=True)
def open_and_analyze_binary(self, path: str) -> None:
global r2
r2 = open_r2(path)
r2.cmd("aaa")

def on_directory_tree_directory_selected(self, event: DirectoryTree.DirectorySelected) -> None:
self.path_input.value = str(event.path)

def action_cursor_up(self) -> None:
self.file_browser.action_cursor_up()

def action_cursor_down(self) -> None:
self.file_browser.action_cursor_down()

def action_select(self) -> None:
node = self.file_browser.cursor_node
if node.data.is_file:
self.dismiss(str(node.data.path))
else:
self.file_browser.toggle_node(node)

# def on_tree_node_highlighted(self, event: Tree.NodeHighlighted) -> None:
# self.path_input.value = str(event.node.data.path)
# self.path_input.cursor_position = len(self.path_input.value)

# def filter_paths(self, paths: Iterable[Path]) -> Iterable[Path]:
# return [path for path in paths if not path.name.startswith(".")]



if __name__ == "__main__":
    # Avoid launching the UI as a side effect of importing this module (e.g. via "-VV").
    app = R2AIApp()
    app.run()
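For reference, the screen-result pattern relied on above (Screen.dismiss() delivering a value to a push_screen() callback) looks like this in a minimal standalone Textual app; every name in the sketch is illustrative and not part of r2ai.

```python
# Minimal, self-contained sketch of Textual's dismiss/callback pattern used by
# ModelSelectDialog and BinarySelectDialog. All names here are illustrative.
from textual.app import App, ComposeResult
from textual.screen import Screen
from textual.widgets import Button, Label


class PickScreen(Screen):
    def compose(self) -> ComposeResult:
        yield Button("pick 'demo-model'", id="pick")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        self.dismiss("demo-model")  # value handed to the push_screen() callback


class DemoApp(App):
    def compose(self) -> ComposeResult:
        yield Label("press p to pick a model", id="status")

    def on_key(self, event) -> None:
        if event.key == "p":
            self.push_screen(PickScreen(), self.on_picked)

    def on_picked(self, model: str) -> None:
        self.query_one("#status", Label).update(f"picked: {model}")


if __name__ == "__main__":
    DemoApp().run()
```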
