diff --git a/r2ai/interpreter.py b/r2ai/interpreter.py index dbd4137..f1f4177 100644 --- a/r2ai/interpreter.py +++ b/r2ai/interpreter.py @@ -23,6 +23,7 @@ from .const import R2AI_HOMEDIR from . import auto, LOGGER, logging from .web import stop_http_server, server_running +from .progress import progress_bar try: from openai import OpenAI, OpenAIError @@ -729,10 +730,10 @@ def keywords_ai(self, text): mm = None return [word.strip() for word in text0.split(',')] + @progress_bar("Thinking", color="yellow") def chat(self, message=None): global print global Ginterrupted - if self.print is not None: print = self.print @@ -920,7 +921,6 @@ def respond(self): model=openai_model, max_tokens=maxtokens, temperature=float(self.env["llm.temperature"]), - repeat_penalty=float(self.env["llm.repeat_penalty"]), messages=self.messages, extra_headers={ "HTTP-Referer": "https://rada.re", # openrouter specific: Optional, for including your app on openrouter.ai rankings. diff --git a/r2ai/pipe.py b/r2ai/pipe.py index adcecf8..e0e7a57 100644 --- a/r2ai/pipe.py +++ b/r2ai/pipe.py @@ -1,6 +1,7 @@ import os import traceback import r2pipe +from .progress import progress_bar have_rlang = False r2lang = None @@ -62,6 +63,7 @@ def get_r2_inst(): global r2 return r2 +@progress_bar("Loading", color="yellow") def open_r2(file, flags=[]): global r2 r2 = r2pipe.open(file, flags=flags) diff --git a/r2ai/progress.py b/r2ai/progress.py new file mode 100644 index 0000000..944dda0 --- /dev/null +++ b/r2ai/progress.py @@ -0,0 +1,100 @@ +from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, Task +from rich.console import Console +from .web import server_running, server_in_background +from inspect import signature +from functools import wraps + + +def _support_total(sig, *args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + return "__progress_cls" in bound.arguments and "__progress_task" in bound.arguments + + +def progress_bar(text, color=None, total=None, infinite=False): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + sig = signature(func) + has_total = total is not None and _support_total( + sig, *args, **kwargs) + is_infinite = infinite or not has_total + + if server_running() and not server_in_background(): + return func(*args, **kwargs) + + with Progress(SpinnerColumn(), *Progress.get_default_columns(), console=Console(no_color=not bool(color)), transient=True) as p: + task_text = f"[{color}]{text}" if color else text + task = p.add_task( + task_text, total=None if is_infinite else total) + + if has_total: + result = func( + *args, + **kwargs, + __progress_cls=p, + __progress_task=task) + else: + result = func(*args, **kwargs) + + return result + return wrapper + return decorator + +# For consistency with the above +class ProgressBar: + def __init__(self, text, color=None, total=None, infinite=False) -> None: + self.text = text + self.color = color + self.total = total + self.infinite = infinite + self.progress: Progress = None + self.task: Task = None + + def __enter__(self): + self.progress = Progress( + SpinnerColumn(), + *Progress.get_default_columns(), + console=Console( + no_color=not bool( + self.color)), + transient=True) + if self.color: + self.task = self.progress.add_task( + f"[{self.color}]{self.text}", total=None if self.infinite else self.total) + else: + self.task = self.progress.add_task( + f"{self.text}", total=None if self.infinite else self.total) + self.progress.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.progress: + self.progress.stop() + +# from .progress import ProgressBar, progress_bar +# progress_bar is used as an decorator +# @progress_bar(Title, color="yellow") +# def foo(): +# bar +# +# +# unlike in the class, the decorated functin can only use progressive progress only if +# __progress_cls and __progress_task are used as positional arguments. else it defaults to infinite +# @progress_bar("Title", color="yellow", total=100) +# def foo(a,b, __progress_cls=None, __progress_task=None): +# i = 1 +# while True: + # progress_cls.update(p.task, advance=i) + # i+=1 + # time.sleep(1) +# +# +# ProgressBar is made for consistency with the decorator +# import time +# with ProgressBar("Title", color="Yellow", total=50) as p: +# i = 0 +# while True: + # p.progress.update(p.task, advance=i) + # i+=1 + # time.sleep(1)