Skip to content

Commit

Permalink
Setup progress handlers with rich Progress
Browse files Browse the repository at this point in the history
  • Loading branch information
nitanmarcel committed Sep 10, 2024
1 parent 18e6cda commit 8b71ce7
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 2 deletions.
4 changes: 2 additions & 2 deletions r2ai/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .const import R2AI_HOMEDIR
from . import auto, LOGGER, logging
from .web import stop_http_server, server_running
from .progress import progress_bar

try:
from openai import OpenAI, OpenAIError
Expand Down Expand Up @@ -729,10 +730,10 @@ def keywords_ai(self, text):
mm = None
return [word.strip() for word in text0.split(',')]

@progress_bar("Thinking", color="yellow")
def chat(self, message=None):
global print
global Ginterrupted

if self.print is not None:
print = self.print

Expand Down Expand Up @@ -920,7 +921,6 @@ def respond(self):
model=openai_model,
max_tokens=maxtokens,
temperature=float(self.env["llm.temperature"]),
repeat_penalty=float(self.env["llm.repeat_penalty"]),
messages=self.messages,
extra_headers={
"HTTP-Referer": "https://rada.re", # openrouter specific: Optional, for including your app on openrouter.ai rankings.
Expand Down
2 changes: 2 additions & 0 deletions r2ai/pipe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import traceback
import r2pipe
from .progress import progress_bar

have_rlang = False
r2lang = None
Expand Down Expand Up @@ -62,6 +63,7 @@ def get_r2_inst():
global r2
return r2

@progress_bar("Loading", color="yellow")
def open_r2(file, flags=[]):
global r2
r2 = r2pipe.open(file, flags=flags)
100 changes: 100 additions & 0 deletions r2ai/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, Task
from rich.console import Console
from .web import server_running, server_in_background
from inspect import signature
from functools import wraps


def _support_total(sig, *args, **kwargs):
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
return "__progress_cls" in bound.arguments and "__progress_task" in bound.arguments


def progress_bar(text, color=None, total=None, infinite=False):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
sig = signature(func)
has_total = total is not None and _support_total(
sig, *args, **kwargs)
is_infinite = infinite or not has_total

if server_running() and not server_in_background():
return func(*args, **kwargs)

with Progress(SpinnerColumn(), *Progress.get_default_columns(), console=Console(no_color=not bool(color)), transient=True) as p:
task_text = f"[{color}]{text}" if color else text
task = p.add_task(
task_text, total=None if is_infinite else total)

if has_total:
result = func(
*args,
**kwargs,
__progress_cls=p,
__progress_task=task)
else:
result = func(*args, **kwargs)

return result
return wrapper
return decorator

# For consistency with the above
class ProgressBar:
def __init__(self, text, color=None, total=None, infinite=False) -> None:
self.text = text
self.color = color
self.total = total
self.infinite = infinite
self.progress: Progress = None
self.task: Task = None

def __enter__(self):
self.progress = Progress(
SpinnerColumn(),
*Progress.get_default_columns(),
console=Console(
no_color=not bool(
self.color)),
transient=True)
if self.color:
self.task = self.progress.add_task(
f"[{self.color}]{self.text}", total=None if self.infinite else self.total)
else:
self.task = self.progress.add_task(
f"{self.text}", total=None if self.infinite else self.total)
self.progress.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.progress:
self.progress.stop()

# from .progress import ProgressBar, progress_bar
# progress_bar is used as an decorator
# @progress_bar(Title, color="yellow")
# def foo():
# bar
#
#
# unlike in the class, the decorated functin can only use progressive progress only if
# __progress_cls and __progress_task are used as positional arguments. else it defaults to infinite
# @progress_bar("Title", color="yellow", total=100)
# def foo(a,b, __progress_cls=None, __progress_task=None):
# i = 1
# while True:
# progress_cls.update(p.task, advance=i)
# i+=1
# time.sleep(1)
#
#
# ProgressBar is made for consistency with the decorator
# import time
# with ProgressBar("Title", color="Yellow", total=50) as p:
# i = 0
# while True:
# p.progress.update(p.task, advance=i)
# i+=1
# time.sleep(1)

0 comments on commit 8b71ce7

Please sign in to comment.