From c1fa76e47eddefe37e6fdb342ca43ac6be0848d9 Mon Sep 17 00:00:00 2001
From: Daniel Nakov
Date: Mon, 11 Nov 2024 20:43:43 +0000
Subject: [PATCH] Use lazy imports to speed up initial loading time

---
 r2ai/interpreter.py | 12 +++++++-----
 r2ai/models.py      | 10 ++++------
 r2ai/pipe.py        |  3 +--
 3 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/r2ai/interpreter.py b/r2ai/interpreter.py
index 0645eba..7d84267 100644
--- a/r2ai/interpreter.py
+++ b/r2ai/interpreter.py
@@ -26,11 +26,9 @@
 from .models import get_hf_llm, new_get_hf_llm, get_default_model
 from .voice import tts
 from .const import R2AI_HOMEDIR
-from . import auto, LOGGER, logging
+from . import LOGGER, logging
 from .web import stop_http_server, server_running
 from .progress import progress_bar
-import litellm
-from .completion import messages_to_prompt
 
 file_dir = os.path.dirname(__file__)
 sys.path.append(file_dir)
@@ -90,6 +88,7 @@ def ddg(m):
     return f"Considering:\n```{res}\n```\n"
 
 def is_litellm_model(model):
+    from litellm import models_by_provider
     provider = None
     model_name = None
     if model.startswith ("/"):
@@ -98,7 +97,7 @@ def is_litellm_model(model):
         provider, model_name = model.split(":")
     elif "/" in model:
         provider, model_name = model.split("/")
-    if provider in litellm.models_by_provider and model_name in litellm.models_by_provider[provider]:
+    if provider in models_by_provider and model_name in models_by_provider[provider]:
         return True
     return False
 
@@ -378,6 +377,7 @@ def respond(self):
         # builtins.print(prompt)
         response = None
         if self.auto_run:
+            from . import auto
            if(is_litellm_model(self.model)):
                 response = auto.chat(self)
             else:
@@ -416,7 +416,8 @@ def respond(self):
             # {"role": "system", "content": "You are a poetic assistant, be creative."},
             # {"role": "user", "content": "Compose a poem that explains the concept of recursion in programming."}
             # ]
-            completion = litellm.completion(
+            from litellm import completion as litellm_completion
+            completion = litellm_completion(
                 model=self.model.replace(":", "/"),
                 messages=self.messages,
                 max_completion_tokens=maxtokens,
@@ -452,6 +453,7 @@ def respond(self):
                 "max_tokens": maxtokens
             }
             if self.env["chat.rawdog"] == "true":
+                from .completion import messages_to_prompt
                 prompt = messages_to_prompt(self, messages)
                 response = self.llama_instance(prompt, **chat_args)
             else:
diff --git a/r2ai/models.py b/r2ai/models.py
index ff48c30..c35912d 100644
--- a/r2ai/models.py
+++ b/r2ai/models.py
@@ -1,20 +1,14 @@
 from .utils import slurp, dump
-from huggingface_hub import hf_hub_download, login
-from huggingface_hub import HfApi, list_repo_tree, get_paths_info
 from typing import Dict, List, Union
 import appdirs
 import builtins
 import inquirer
 import json
-import llama_cpp
 import os
 import shutil
 import subprocess
 import sys
 import traceback
-from transformers import AutoTokenizer
-from llama_cpp.llama_tokenizer import LlamaHFTokenizer
-
 # DEFAULT_MODEL = "TheBloke/CodeLlama-34B-Instruct-GGUF"
 # DEFAULT_MODEL = "TheBloke/llama2-7b-chat-codeCherryPop-qLoRA-GGUF"
 # DEFAULT_MODEL = "-m TheBloke/dolphin-2_6-phi-2-GGUF"
@@ -263,6 +257,7 @@ def get_hf_llm(ai, repo_id, debug_mode, context_window):
 
     # Check if model was originally split
     split_files = [model["filename"] for model in raw_models if selected_model in model["filename"]]
+    from huggingface_hub import hf_hub_download
     if len(split_files) > 1:
         # Download splits
         for split_file in split_files:
@@ -466,6 +461,7 @@ def list_gguf_files(repo_id: str) -> List[Dict[str, Union[str, float]]]:
     """
 
     try:
+        from huggingface_hub import HfApi
         api = HfApi()
         tree = list(api.list_repo_tree(repo_id))
         files_info = [file for file in tree if file.path.endswith('.gguf')]
@@ -634,7 +630,9 @@ def supports_metal():
 
 
 def get_llama_inst(repo_id, **kwargs):
+    import llama_cpp
     if 'functionary' in repo_id:
+        from llama_cpp.llama_tokenizer import LlamaHFTokenizer
         kwargs['tokenizer'] = LlamaHFTokenizer.from_pretrained(repo_id)
     filename = os.path.basename(kwargs.pop('model_path'))
     kwargs['echo'] = True
diff --git a/r2ai/pipe.py b/r2ai/pipe.py
index 731954b..17beb4c 100644
--- a/r2ai/pipe.py
+++ b/r2ai/pipe.py
@@ -1,7 +1,7 @@
 import os
 import traceback
 import r2pipe
-from .progress import progress_bar
+
 
 have_rlang = False
 r2lang = None
@@ -63,7 +63,6 @@ def r2singleton():
 def get_r2_inst():
     return r2singleton()
 
-@progress_bar("Loading", color="yellow")
 def open_r2(file, flags=[]):
     global r2, filename, r2lang
     r2 = r2pipe.open(file, flags=flags)
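
The hunks above all apply the same lazy-import pattern: a module-level import of a heavy
dependency (litellm, llama_cpp, huggingface_hub) is moved into the function that first
needs it, so importing the r2ai package no longer pays that cost up front. A minimal
sketch of the pattern, using a hypothetical heavy_dependency package:

    # Eager: heavy_dependency is loaded as soon as this module is imported.
    import heavy_dependency

    def convert(data):
        return heavy_dependency.process(data)

    # Lazy: heavy_dependency is loaded on the first call to convert() instead.
    def convert(data):
        import heavy_dependency  # hypothetical package, imported on demand
        return heavy_dependency.process(data)

Python caches imported modules in sys.modules, so after the first call the inner import
statement is only a dictionary lookup.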