From db60fdd7bcad2de953a3f3183ed85a5cd54fdc01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 24 Oct 2024 22:39:05 +0200 Subject: [PATCH] black python formatter applyed --- llama_bringup/launch/base.launch.py | 95 ++++++++++++------- llama_bringup/launch/minicpm-2.6.launch.py | 15 ++- llama_bringup/launch/spaetzle.launch.py | 15 ++- llama_bringup/llama_bringup/utils.py | 49 ++++++---- llama_cli/llama_cli/api/__init__.py | 12 +-- llama_cli/llama_cli/command/llama.py | 5 +- llama_cli/llama_cli/verb/launch.py | 3 +- llama_cli/llama_cli/verb/prompt.py | 29 +++--- llama_cli/setup.py | 5 +- llama_cli/test/test_copyright.py | 8 +- llama_cli/test/test_flake8.py | 6 +- llama_cli/test/test_pep257.py | 4 +- .../llama_demos/chatllama_demo_node.py | 37 +++++--- llama_demos/llama_demos/llama_demo_node.py | 13 +-- .../llama_demos/llama_embeddings_demo_node.py | 6 +- .../llama_demos/llama_rag_demo_node.py | 23 +++-- .../llama_demos/llama_rerank_demo_node.py | 2 +- llama_demos/llama_demos/llava_demo_node.py | 29 +++--- .../llama_ros/langchain/chat_llama_ros.py | 42 ++++---- .../langchain/chat_prompt_formatter.py | 3 +- llama_ros/llama_ros/langchain/llama_ros.py | 5 +- .../llama_ros/langchain/llama_ros_common.py | 3 +- .../langchain/llama_ros_embeddings.py | 3 +- llama_ros/llama_ros/llama_client_node.py | 43 +++++---- 24 files changed, 272 insertions(+), 183 deletions(-) diff --git a/llama_bringup/launch/base.launch.py b/llama_bringup/launch/base.launch.py index eabc241..995c5e3 100644 --- a/llama_bringup/launch/base.launch.py +++ b/llama_bringup/launch/base.launch.py @@ -41,25 +41,25 @@ def run_llama(context: LaunchContext, embedding, reranking): "n_ctx": LaunchConfiguration("n_ctx", default=512), "n_batch": LaunchConfiguration("n_batch", default=2048), "n_ubatch": LaunchConfiguration("n_batch", default=512), - + # GPU params "n_gpu_layers": LaunchConfiguration("n_gpu_layers", default=0), "split_mode": LaunchConfiguration("split_mode", default="layer"), "main_gpu": LaunchConfiguration("main_gpu", default=0), "tensor_split": LaunchConfiguration("tensor_split", default="[0.0]"), - + # attn params "grp_attn_n": LaunchConfiguration("grp_attn_n", default=1), "grp_attn_w": LaunchConfiguration("grp_attn_w", default=512), - + # rope params "rope_freq_base": LaunchConfiguration("rope_freq_base", default=0.0), "rope_freq_scale": LaunchConfiguration("rope_freq_scale", default=0.0), "rope_scaling_type": LaunchConfiguration("rope_scaling_type", default=""), - + # yarn params "yarn_ext_factor": LaunchConfiguration("yarn_ext_factor", default=-1.0), "yarn_attn_factor": LaunchConfiguration("yarn_attn_factor", default=1.0), "yarn_beta_fast": LaunchConfiguration("yarn_beta_fast", default=32.0), "yarn_beta_slow": LaunchConfiguration("yarn_beta_slow", default=1.0), "yarn_orig_ctx": LaunchConfiguration("yarn_orig_ctx", default=0), - + # bool params "embedding": embedding, "reranking": reranking, "logits_all": LaunchConfiguration("logits_all", default=False), @@ -68,44 +68,67 @@ def run_llama(context: LaunchContext, embedding, reranking): "warmup": LaunchConfiguration("warmup", default=True), "check_tensors": LaunchConfiguration("check_tensors", default=False), "flash_attn": LaunchConfiguration("flash_attn", default=False), - + # cache params "no_kv_offload": LaunchConfiguration("no_kv_offload", default=False), "cache_type_k": LaunchConfiguration("cache_type_k", default="f16"), "cache_type_v": LaunchConfiguration("cache_type_v", default="f16"), - + # CPU params "n_threads": LaunchConfiguration("n_threads", default=1), "cpu_mask": LaunchConfiguration("cpu_mask", default=""), "cpu_range": LaunchConfiguration("cpu_range", default=""), "priority": LaunchConfiguration("priority", default="normal"), "strict_cpu": LaunchConfiguration("strict_cpu", default=False), "poll": LaunchConfiguration("poll", default=50), - + # batch CPU params "n_threads_batch": LaunchConfiguration("n_threads_batch", default=1), "cpu_mask_batch": LaunchConfiguration("cpu_mask_batch", default=""), "cpu_range_batch": LaunchConfiguration("cpu_range_batch", default=""), "priority_batch": LaunchConfiguration("priority_batch", default="normal"), "strict_cpu_batch": LaunchConfiguration("strict_cpu_batch", default=False), "poll_batch": LaunchConfiguration("poll_batch", default=50), - + # switch context params "n_predict": LaunchConfiguration("n_predict", default=128), "n_keep": LaunchConfiguration("n_keep", default=-1), - + # paths params "model": LaunchConfiguration("model", default=""), - "lora_adapters": ParameterValue(LaunchConfiguration("lora_adapters", default=[""]), value_type=List[str]), - "lora_adapters_scales": ParameterValue(LaunchConfiguration("lora_adapters_scales", default=[0.0]), value_type=List[float]), + "lora_adapters": ParameterValue( + LaunchConfiguration("lora_adapters", default=[""]), value_type=List[str] + ), + "lora_adapters_scales": ParameterValue( + LaunchConfiguration("lora_adapters_scales", default=[0.0]), + value_type=List[float], + ), "mmproj": LaunchConfiguration("mmproj", default=""), "numa": LaunchConfiguration("numa", default="none"), "pooling_type": LaunchConfiguration("pooling_type", default=""), - - "prefix": ParameterValue(LaunchConfiguration("prefix", default=""), value_type=str), - "suffix": ParameterValue(LaunchConfiguration("suffix", default=""), value_type=str), - "stopping_words": ParameterValue(LaunchConfiguration("stopping_words", default=[""]), value_type=List[str]), - "image_prefix": ParameterValue(LaunchConfiguration("image_prefix", default=""), value_type=str), - "image_suffix": ParameterValue(LaunchConfiguration("image_suffix", default=""), value_type=str), - "image_text": ParameterValue(LaunchConfiguration("image_text", default=""), value_type=str), - - "system_prompt": ParameterValue(LaunchConfiguration("system_prompt", default=""), value_type=str), - "system_prompt_file": ParameterValue(LaunchConfiguration("system_prompt_file", default=""), value_type=str), + # prefix/suffix + "prefix": ParameterValue( + LaunchConfiguration("prefix", default=""), value_type=str + ), + "suffix": ParameterValue( + LaunchConfiguration("suffix", default=""), value_type=str + ), + "stopping_words": ParameterValue( + LaunchConfiguration("stopping_words", default=[""]), + value_type=List[str], + ), + "image_prefix": ParameterValue( + LaunchConfiguration("image_prefix", default=""), value_type=str + ), + "image_suffix": ParameterValue( + LaunchConfiguration("image_suffix", default=""), value_type=str + ), + "image_text": ParameterValue( + LaunchConfiguration("image_text", default=""), value_type=str + ), + # prompt params + "system_prompt": ParameterValue( + LaunchConfiguration("system_prompt", default=""), value_type=str + ), + "system_prompt_file": ParameterValue( + LaunchConfiguration("system_prompt_file", default=""), value_type=str + ), + # debug "debug": LaunchConfiguration("debug", default=True), } @@ -123,32 +146,36 @@ def run_llama(context: LaunchContext, embedding, reranking): name=llama_node_name, namespace="llama", parameters=[params], - condition=UnlessCondition(PythonExpression( - [LaunchConfiguration("use_llava")])) + condition=UnlessCondition( + PythonExpression([LaunchConfiguration("use_llava")]) + ), ), Node( package="llama_ros", executable="llava_node", name="llava_node", namespace="llama", parameters=[params], - condition=IfCondition(PythonExpression( - [LaunchConfiguration("use_llava")])) + condition=IfCondition(PythonExpression([LaunchConfiguration("use_llava")])), ) embedding = LaunchConfiguration("embedding") embedding_cmd = DeclareLaunchArgument( "embedding", default_value="False", - description="Whether the model is an embedding model") + description="Whether the model is an embedding model", + ) reranking = LaunchConfiguration("reranking") reranking_cmd = DeclareLaunchArgument( "reranking", default_value="False", - description="Whether the model is an reranking model") - - return LaunchDescription([ - embedding_cmd, - reranking_cmd, - OpaqueFunction(function=run_llama, args=[embedding, reranking]) - ]) + description="Whether the model is an reranking model", + ) + + return LaunchDescription( + [ + embedding_cmd, + reranking_cmd, + OpaqueFunction(function=run_llama, args=[embedding, reranking]), + ] + ) diff --git a/llama_bringup/launch/minicpm-2.6.launch.py b/llama_bringup/launch/minicpm-2.6.launch.py index 63b6a47..16a924a 100644 --- a/llama_bringup/launch/minicpm-2.6.launch.py +++ b/llama_bringup/launch/minicpm-2.6.launch.py @@ -28,7 +28,14 @@ def generate_launch_description(): - return LaunchDescription([ - create_llama_launch_from_yaml(os.path.join( - get_package_share_directory("llama_bringup"), "models", "MiniCPM-2.6.yaml")) - ]) + return LaunchDescription( + [ + create_llama_launch_from_yaml( + os.path.join( + get_package_share_directory("llama_bringup"), + "models", + "MiniCPM-2.6.yaml", + ) + ) + ] + ) diff --git a/llama_bringup/launch/spaetzle.launch.py b/llama_bringup/launch/spaetzle.launch.py index 2966b53..1528cc9 100644 --- a/llama_bringup/launch/spaetzle.launch.py +++ b/llama_bringup/launch/spaetzle.launch.py @@ -28,7 +28,14 @@ def generate_launch_description(): - return LaunchDescription([ - create_llama_launch_from_yaml(os.path.join( - get_package_share_directory("llama_bringup"), "models", "Spaetzle.yaml")) - ]) + return LaunchDescription( + [ + create_llama_launch_from_yaml( + os.path.join( + get_package_share_directory("llama_bringup"), + "models", + "Spaetzle.yaml", + ) + ) + ] + ) diff --git a/llama_bringup/llama_bringup/utils.py b/llama_bringup/llama_bringup/utils.py index e5fc22b..35b6725 100644 --- a/llama_bringup/llama_bringup/utils.py +++ b/llama_bringup/llama_bringup/utils.py @@ -33,34 +33,41 @@ def download_model(repo: str, file: str) -> str: - match = re.search(r'-(\d+)-of-(\d+)\.gguf', file) + match = re.search(r"-(\d+)-of-(\d+)\.gguf", file) if match: total_shards = int(match.group(2)) - base_name = file[:match.start()] + base_name = file[: match.start()] # download shards for i in range(1, total_shards + 1): shard_file = f"{base_name}-{i:05d}-of-{total_shards:05d}.gguf" - hf_hub_download(repo_id=repo, filename=shard_file, - force_download=False) + hf_hub_download(repo_id=repo, filename=shard_file, force_download=False) # return first shard return hf_hub_download( repo_id=repo, filename=f"{base_name}-00001-of-{total_shards:05d}.gguf", - force_download=False + force_download=False, ) return hf_hub_download(repo_id=repo, filename=file, force_download=False) def load_prompt_type(prompt_file_name: str) -> Tuple: - file_path = os.path.join(get_package_share_directory( - "llama_bringup"), "prompts", f"{prompt_file_name}.yaml") + file_path = os.path.join( + get_package_share_directory("llama_bringup"), + "prompts", + f"{prompt_file_name}.yaml", + ) with open(file_path, "r") as file: yaml_data = yaml.safe_load(file) - return yaml_data["prefix"], yaml_data["suffix"], yaml_data["stopping_words"], yaml_data["system_prompt"] + return ( + yaml_data["prefix"], + yaml_data["suffix"], + yaml_data["stopping_words"], + yaml_data["system_prompt"], + ) def create_llama_launch_from_yaml(file_path: str) -> IncludeLaunchDescription: @@ -70,8 +77,11 @@ def create_llama_launch_from_yaml(file_path: str) -> IncludeLaunchDescription: def create_llama_launch(**kwargs) -> IncludeLaunchDescription: - prompt_data = load_prompt_type(kwargs["system_prompt_type"]) if kwargs.get( - "system_prompt_type") else ("", "", [], "") + prompt_data = ( + load_prompt_type(kwargs["system_prompt_type"]) + if kwargs.get("system_prompt_type") + else ("", "", [], "") + ) kwargs["prefix"] = kwargs.get("prefix", prompt_data[0]) kwargs["suffix"] = kwargs.get("suffix", prompt_data[1]) kwargs["system_prompt"] = kwargs.get("system_prompt", prompt_data[3]) @@ -83,9 +93,14 @@ def create_llama_launch(**kwargs) -> IncludeLaunchDescription: # load models for key in ["model", "mmproj"]: - if not kwargs.get(key) and kwargs.get(f"{key}_repo") and kwargs.get(f"{key}_filename"): + if ( + not kwargs.get(key) + and kwargs.get(f"{key}_repo") + and kwargs.get(f"{key}_filename") + ): kwargs[key] = download_model( - kwargs[f"{key}_repo"], kwargs[f"{key}_filename"]) + kwargs[f"{key}_repo"], kwargs[f"{key}_filename"] + ) # load lora adapters lora_adapters = [] @@ -122,8 +137,10 @@ def create_llama_launch(**kwargs) -> IncludeLaunchDescription: kwargs["use_llava"] = False return IncludeLaunchDescription( - PythonLaunchDescriptionSource(os.path.join( - get_package_share_directory("llama_bringup"), "launch", "base.launch.py")), - launch_arguments={key: str(value) - for key, value in kwargs.items()}.items() + PythonLaunchDescriptionSource( + os.path.join( + get_package_share_directory("llama_bringup"), "launch", "base.launch.py" + ) + ), + launch_arguments={key: str(value) for key, value in kwargs.items()}.items(), ) diff --git a/llama_cli/llama_cli/api/__init__.py b/llama_cli/llama_cli/api/__init__.py index fcd447e..5308408 100644 --- a/llama_cli/llama_cli/api/__init__.py +++ b/llama_cli/llama_cli/api/__init__.py @@ -52,16 +52,15 @@ def launch_llm(file_path: str) -> None: print(f"File '{file_path}' does not exists") return - ld = LaunchDescription([ - create_llama_launch_from_yaml(file_path) - ]) + ld = LaunchDescription([create_llama_launch_from_yaml(file_path)]) ls = LaunchService() ls.include_launch_description(ld) ls.run() -def prompt_llm(prompt: str, reset: bool = False, - temp: float = 0.8, image_url: str = "") -> None: +def prompt_llm( + prompt: str, reset: bool = False, temp: float = 0.8, image_url: str = "" +) -> None: rclpy.init() llama_client = LlamaClientNode() @@ -71,8 +70,7 @@ def prompt_llm(prompt: str, reset: bool = False, goal.sampling_config.temp = temp if image_url: - req = urllib.request.Request( - image_url, headers={"User-Agent": "Mozilla/5.0"}) + req = urllib.request.Request(image_url, headers={"User-Agent": "Mozilla/5.0"}) response = urllib.request.urlopen(req) arr = np.asarray(bytearray(response.read()), dtype=np.uint8) img = cv2.imdecode(arr, -1) diff --git a/llama_cli/llama_cli/command/llama.py b/llama_cli/llama_cli/command/llama.py index e988ad2..1f4638f 100644 --- a/llama_cli/llama_cli/command/llama.py +++ b/llama_cli/llama_cli/command/llama.py @@ -29,8 +29,9 @@ class LlamaCommand(CommandExtension): def add_arguments(self, parser, cli_name): self._subparser = parser - add_subparsers_on_demand(parser, cli_name, "_verb", - "llama_cli.verb", required=False) + add_subparsers_on_demand( + parser, cli_name, "_verb", "llama_cli.verb", required=False + ) def main(self, *, parser, args): if not hasattr(args, "_verb"): diff --git a/llama_cli/llama_cli/verb/launch.py b/llama_cli/llama_cli/verb/launch.py index 563cbce..384b3a3 100644 --- a/llama_cli/llama_cli/verb/launch.py +++ b/llama_cli/llama_cli/verb/launch.py @@ -28,8 +28,7 @@ class LaunchVerb(VerbExtension): def add_arguments(self, parser, cli_name): - arg = parser.add_argument( - "file_path", help="path to the YAML of the LLM") + arg = parser.add_argument("file_path", help="path to the YAML of the LLM") def main(self, *, args): launch_llm(args.file_path) diff --git a/llama_cli/llama_cli/verb/prompt.py b/llama_cli/llama_cli/verb/prompt.py index f7576d7..00937ac 100644 --- a/llama_cli/llama_cli/verb/prompt.py +++ b/llama_cli/llama_cli/verb/prompt.py @@ -28,21 +28,26 @@ class PromptVerb(VerbExtension): def add_arguments(self, parser, cli_name): - arg = parser.add_argument( - "prompt", help="prompt text for the LLM") + arg = parser.add_argument("prompt", help="prompt text for the LLM") parser.add_argument( - "-r", "--reset", + "-r", + "--reset", action="store_true", - help="Whether to reset the LLM and its context before prompting") + help="Whether to reset the LLM and its context before prompting", + ) parser.add_argument( - "-t", "--temp", - metavar="N", type=positive_float, default=0.8, - help="Temperature value (default: 0.8)") + "-t", + "--temp", + metavar="N", + type=positive_float, + default=0.8, + help="Temperature value (default: 0.8)", + ) parser.add_argument( - "--image-url", - type=str, default="", - help="Image URL to sent to the VLM") + "--image-url", type=str, default="", help="Image URL to sent to the VLM" + ) def main(self, *, args): - prompt_llm(args.prompt, reset=args.reset, - temp=args.temp, image_url=args.image_url) + prompt_llm( + args.prompt, reset=args.reset, temp=args.temp, image_url=args.image_url + ) diff --git a/llama_cli/setup.py b/llama_cli/setup.py index 0f4c3d9..1d424d1 100644 --- a/llama_cli/setup.py +++ b/llama_cli/setup.py @@ -1,10 +1,9 @@ - from setuptools import setup from setuptools import find_packages setup( name="llama_cli", - version="0.0.0", + version="4.0.5", packages=find_packages(exclude=["test"]), zip_safe=True, author="Miguel Ángel González Santamarta", @@ -25,5 +24,5 @@ "launch = llama_cli.verb.launch:LaunchVerb", "prompt = llama_cli.verb.prompt:PromptVerb", ], - } + }, ) diff --git a/llama_cli/test/test_copyright.py b/llama_cli/test/test_copyright.py index 97a3919..ceffe89 100644 --- a/llama_cli/test/test_copyright.py +++ b/llama_cli/test/test_copyright.py @@ -17,9 +17,11 @@ # Remove the `skip` decorator once the source file(s) have a copyright header -@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.') +@pytest.mark.skip( + reason="No copyright header has been placed in the generated source file." +) @pytest.mark.copyright @pytest.mark.linter def test_copyright(): - rc = main(argv=['.', 'test']) - assert rc == 0, 'Found errors' + rc = main(argv=[".", "test"]) + assert rc == 0, "Found errors" diff --git a/llama_cli/test/test_flake8.py b/llama_cli/test/test_flake8.py index 27ee107..ee79f31 100644 --- a/llama_cli/test/test_flake8.py +++ b/llama_cli/test/test_flake8.py @@ -20,6 +20,6 @@ @pytest.mark.linter def test_flake8(): rc, errors = main_with_errors(argv=[]) - assert rc == 0, \ - 'Found %d code style errors / warnings:\n' % len(errors) + \ - '\n'.join(errors) + assert rc == 0, "Found %d code style errors / warnings:\n" % len( + errors + ) + "\n".join(errors) diff --git a/llama_cli/test/test_pep257.py b/llama_cli/test/test_pep257.py index b234a38..a2c3deb 100644 --- a/llama_cli/test/test_pep257.py +++ b/llama_cli/test/test_pep257.py @@ -19,5 +19,5 @@ @pytest.mark.linter @pytest.mark.pep257 def test_pep257(): - rc = main(argv=['.', 'test']) - assert rc == 0, 'Found code style errors / warnings' + rc = main(argv=[".", "test"]) + assert rc == 0, "Found code style errors / warnings" diff --git a/llama_demos/llama_demos/chatllama_demo_node.py b/llama_demos/llama_demos/chatllama_demo_node.py index 82ea274..977ba33 100644 --- a/llama_demos/llama_demos/chatllama_demo_node.py +++ b/llama_demos/llama_demos/chatllama_demo_node.py @@ -41,9 +41,9 @@ def __init__(self) -> None: super().__init__("chat_llama_demo_node") self.declare_parameter( - "prompt", "Who is the character in the middle of the image?") - self.prompt = self.get_parameter( - "prompt").get_parameter_value().string_value + "prompt", "Who is the character in the middle of the image?" + ) + self.prompt = self.get_parameter("prompt").get_parameter_value().string_value self.cv_bridge = CvBridge() @@ -58,20 +58,27 @@ def send_prompt(self) -> None: penalty_last_n=8, ) - self.prompt = ChatPromptTemplate.from_messages([ - SystemMessage( - "You are a IA that just answer with a single word."), - HumanMessagePromptTemplate.from_template(template=[ - {"type": "text", "text": f"{self.prompt}"}, - {"type": "image_url", "image_url": "{image_url}"} - ]) - ]) + self.prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage("You are a IA that just answer with a single word."), + HumanMessagePromptTemplate.from_template( + template=[ + {"type": "text", "text": f"{self.prompt}"}, + {"type": "image_url", "image_url": "{image_url}"}, + ] + ), + ] + ) self.chain = self.prompt | self.chat | StrOutputParser() self.initial_time = time.time() - for text in self.chain.stream({"image_url": "https://pics.filmaffinity.com/Dragon_Ball_Bola_de_Dragaon_Serie_de_TV-973171538-large.jpg"}): + for text in self.chain.stream( + { + "image_url": "https://pics.filmaffinity.com/Dragon_Ball_Bola_de_Dragaon_Serie_de_TV-973171538-large.jpg" + } + ): self.tokens += 1 print(text, end="", flush=True) if self.eval_time < 0: @@ -81,10 +88,10 @@ def send_prompt(self) -> None: self.get_logger().info("END") end_time = time.time() + self.get_logger().info(f"Time to eval: {self.eval_time - self.initial_time} s") self.get_logger().info( - f"Time to eval: {self.eval_time - self.initial_time} s") - self.get_logger().info( - f"Prediction speed: {self.tokens / (end_time - self.eval_time)} t/s") + f"Prediction speed: {self.tokens / (end_time - self.eval_time)} t/s" + ) def main(): diff --git a/llama_demos/llama_demos/llama_demo_node.py b/llama_demos/llama_demos/llama_demo_node.py index 134f03e..17f61e5 100644 --- a/llama_demos/llama_demos/llama_demo_node.py +++ b/llama_demos/llama_demos/llama_demo_node.py @@ -36,9 +36,10 @@ def __init__(self) -> None: super().__init__("llama_demo_node") self.declare_parameter( - "prompt", "Do you know the city of León from Spain? Can you tell me a bit about its history?") - self.prompt = self.get_parameter( - "prompt").get_parameter_value().string_value + "prompt", + "Do you know the city of León from Spain? Can you tell me a bit about its history?", + ) + self.prompt = self.get_parameter("prompt").get_parameter_value().string_value self.tokens = 0 self.initial_time = -1 @@ -66,10 +67,10 @@ def send_prompt(self) -> None: self.get_logger().info("END") end_time = time.time() + self.get_logger().info(f"Time to eval: {self.eval_time - self.initial_time} s") self.get_logger().info( - f"Time to eval: {self.eval_time - self.initial_time} s") - self.get_logger().info( - f"Prediction speed: {self.tokens / (end_time - self.eval_time)} t/s") + f"Prediction speed: {self.tokens / (end_time - self.eval_time)} t/s" + ) def main(): diff --git a/llama_demos/llama_demos/llama_embeddings_demo_node.py b/llama_demos/llama_demos/llama_embeddings_demo_node.py index 586b966..4e6dba9 100644 --- a/llama_demos/llama_demos/llama_embeddings_demo_node.py +++ b/llama_demos/llama_demos/llama_embeddings_demo_node.py @@ -36,9 +36,9 @@ def __init__(self) -> None: super().__init__("llama_embeddings_demo_node") self.declare_parameter( - "prompt", "This is the test to create embeddings using llama_ros") - self.prompt = self.get_parameter( - "prompt").get_parameter_value().string_value + "prompt", "This is the test to create embeddings using llama_ros" + ) + self.prompt = self.get_parameter("prompt").get_parameter_value().string_value self._llama_client = LlamaClientNode.get_instance() diff --git a/llama_demos/llama_demos/llama_rag_demo_node.py b/llama_demos/llama_demos/llama_rag_demo_node.py index 31a26c1..cb8552e 100644 --- a/llama_demos/llama_demos/llama_rag_demo_node.py +++ b/llama_demos/llama_demos/llama_rag_demo_node.py @@ -49,22 +49,22 @@ ) docs = loader.load() -text_splitter = RecursiveCharacterTextSplitter( - chunk_size=1000, chunk_overlap=200) +text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) splits = text_splitter.split_documents(docs) -vectorstore = Chroma.from_documents( - documents=splits, embedding=LlamaROSEmbeddings()) +vectorstore = Chroma.from_documents(documents=splits, embedding=LlamaROSEmbeddings()) # retrieve and generate using the relevant snippets of the blog retriever = vectorstore.as_retriever(search_kwargs={"k": 20}) # create prompt -prompt = ChatPromptTemplate.from_messages([ - SystemMessage("You are an AI assistant that answer questions."), - HumanMessagePromptTemplate.from_template( - "Taking into account the followin context:\n{context}\nAnswer this question: {question}" - ) -]) +prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage("You are an AI assistant that answer questions."), + HumanMessagePromptTemplate.from_template( + "Taking into account the followin context:\n{context}\nAnswer this question: {question}" + ), + ] +) # create rerank compression retriever compressor = LlamaROSReranker(top_n=3) @@ -79,8 +79,7 @@ def format_docs(docs): # create and use the chain rag_chain = ( - {"context": compression_retriever | format_docs, - "question": RunnablePassthrough()} + {"context": compression_retriever | format_docs, "question": RunnablePassthrough()} | prompt | ChatLlamaROS(temp=0.2) | StrOutputParser() diff --git a/llama_demos/llama_demos/llama_rerank_demo_node.py b/llama_demos/llama_demos/llama_rerank_demo_node.py index 7e96d59..5790b74 100644 --- a/llama_demos/llama_demos/llama_rerank_demo_node.py +++ b/llama_demos/llama_demos/llama_rerank_demo_node.py @@ -36,7 +36,7 @@ "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", - "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.", ] scores = LlamaClientNode.get_instance().rerank_documents(rerank_req).scores diff --git a/llama_demos/llama_demos/llava_demo_node.py b/llama_demos/llama_demos/llava_demo_node.py index 47631d6..2054cd6 100644 --- a/llama_demos/llama_demos/llava_demo_node.py +++ b/llama_demos/llama_demos/llava_demo_node.py @@ -44,18 +44,22 @@ def __init__(self) -> None: self.cv_bridge = CvBridge() self.declare_parameter( - "prompt", "Who is the character in the middle of the image?") - self.prompt = self.get_parameter( - "prompt").get_parameter_value().string_value + "prompt", "Who is the character in the middle of the image?" + ) + self.prompt = self.get_parameter("prompt").get_parameter_value().string_value self.declare_parameter("use_image", True) - self.use_image = self.get_parameter( - "use_image").get_parameter_value().bool_value + self.use_image = ( + self.get_parameter("use_image").get_parameter_value().bool_value + ) self.declare_parameter( - "image_url", "https://pics.filmaffinity.com/Dragon_Ball_Bola_de_Dragaon_Serie_de_TV-973171538-large.jpg") - self.image = self.load_image_from_url(self.get_parameter( - "image_url").get_parameter_value().string_value) + "image_url", + "https://pics.filmaffinity.com/Dragon_Ball_Bola_de_Dragaon_Serie_de_TV-973171538-large.jpg", + ) + self.image = self.load_image_from_url( + self.get_parameter("image_url").get_parameter_value().string_value + ) self.tokens = 0 self.initial_time = -1 @@ -65,8 +69,7 @@ def __init__(self) -> None: @staticmethod def load_image_from_url(url): - req = urllib.request.Request( - url, headers={"User-Agent": "Mozilla/5.0"}) + req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"}) response = urllib.request.urlopen(req) arr = np.asarray(bytearray(response.read()), dtype=np.uint8) img = cv2.imdecode(arr, -1) @@ -95,10 +98,10 @@ def send_prompt(self) -> None: self.get_logger().info("END") end_time = time.time() + self.get_logger().info(f"Time to eval: {self.eval_time - self.initial_time} s") self.get_logger().info( - f"Time to eval: {self.eval_time - self.initial_time} s") - self.get_logger().info( - f"Prediction speed: {self.tokens / (end_time - self.eval_time)} t/s") + f"Prediction speed: {self.tokens / (end_time - self.eval_time)} t/s" + ) def main(): diff --git a/llama_ros/llama_ros/langchain/chat_llama_ros.py b/llama_ros/llama_ros/langchain/chat_llama_ros.py index e0dc8d5..d24cfd0 100644 --- a/llama_ros/llama_ros/langchain/chat_llama_ros.py +++ b/llama_ros/llama_ros/langchain/chat_llama_ros.py @@ -48,8 +48,7 @@ def _llm_type(self) -> str: return "chatllamaros" def _messages_to_chat_messages( - self, - messages: List[BaseMessage] + self, messages: List[BaseMessage] ) -> tuple[FormatChatMessages.Request, Optional[str], Optional[np.ndarray]]: chat_messages = FormatChatMessages.Request() @@ -63,15 +62,18 @@ def _messages_to_chat_messages( if type(message.content) == str: chat_messages.messages.append( - Message(role=role, content=message.content)) + Message(role=role, content=message.content) + ) else: for single_content in message.content: if type(single_content) == str: chat_messages.messages.append( - Message(role=role, content=single_content)) + Message(role=role, content=single_content) + ) elif single_content["type"] == "text": chat_messages.messages.append( - Message(role=role, content=single_content["text"])) + Message(role=role, content=single_content["text"]) + ) elif single_content["type"] == "image_url": image_text = single_content["image_url"]["url"] if "data:image" in image_text: @@ -94,21 +96,21 @@ def _generate( llama_client = self.llama_client.get_instance() - chat_messages, image_url, image = self._messages_to_chat_messages( - messages) + chat_messages, image_url, image = self._messages_to_chat_messages(messages) formatted_prompt = llama_client.format_chat_prompt( - chat_messages).formatted_prompt + chat_messages + ).formatted_prompt goal_action = self._create_action_goal( - formatted_prompt, stop, image_url, image, **kwargs) + formatted_prompt, stop, image_url, image, **kwargs + ) result, status = LlamaClientNode.get_instance().generate_response(goal_action) if status != GoalStatus.STATUS_SUCCEEDED: return "" - generation = ChatGeneration( - message=AIMessage(content=result.response.text)) + generation = ChatGeneration(message=AIMessage(content=result.response.text)) return ChatResult(generations=[generation]) def _stream( @@ -121,17 +123,23 @@ def _stream( llama_client = self.llama_client.get_instance() - chat_messages, image_url, image = self._messages_to_chat_messages( - messages) + chat_messages, image_url, image = self._messages_to_chat_messages(messages) formatted_prompt = llama_client.format_chat_prompt( - chat_messages).formatted_prompt + chat_messages + ).formatted_prompt goal_action = self._create_action_goal( - formatted_prompt, stop, image_url, image, **kwargs) + formatted_prompt, stop, image_url, image, **kwargs + ) - for pt in LlamaClientNode.get_instance().generate_response(goal_action, stream=True): + for pt in LlamaClientNode.get_instance().generate_response( + goal_action, stream=True + ): if run_manager: - run_manager.on_llm_new_token(pt.text, verbose=self.verbose,) + run_manager.on_llm_new_token( + pt.text, + verbose=self.verbose, + ) yield ChatGenerationChunk(message=AIMessageChunk(content=pt.text)) diff --git a/llama_ros/llama_ros/langchain/chat_prompt_formatter.py b/llama_ros/llama_ros/langchain/chat_prompt_formatter.py index 67baa00..e207207 100644 --- a/llama_ros/llama_ros/langchain/chat_prompt_formatter.py +++ b/llama_ros/llama_ros/langchain/chat_prompt_formatter.py @@ -40,6 +40,7 @@ def ChatPromptFormatter(messages): output_msgs.append(new_msg) response = client.format_chat_prompt( - FormatChatMessages.Request(messages=output_msgs)) + FormatChatMessages.Request(messages=output_msgs) + ) return response.formatted_prompt diff --git a/llama_ros/llama_ros/langchain/llama_ros.py b/llama_ros/llama_ros/langchain/llama_ros.py index 83fdfc4..ea367ac 100644 --- a/llama_ros/llama_ros/langchain/llama_ros.py +++ b/llama_ros/llama_ros/langchain/llama_ros.py @@ -72,7 +72,10 @@ def _stream( for pt in LlamaClientNode.get_instance().generate_response(goal, stream=True): if run_manager: - run_manager.on_llm_new_token(pt.text, verbose=self.verbose,) + run_manager.on_llm_new_token( + pt.text, + verbose=self.verbose, + ) yield GenerationChunk(text=pt.text) diff --git a/llama_ros/llama_ros/langchain/llama_ros_common.py b/llama_ros/llama_ros/langchain/llama_ros_common.py index c0c6ded..ba23895 100644 --- a/llama_ros/llama_ros/langchain/llama_ros_common.py +++ b/llama_ros/llama_ros/langchain/llama_ros_common.py @@ -109,7 +109,8 @@ def _create_action_goal( if image_url and image is None: req = urllib.request.Request( - image_url, headers={"User-Agent": "Mozilla/5.0"}) + image_url, headers={"User-Agent": "Mozilla/5.0"} + ) response = urllib.request.urlopen(req) arr = np.asarray(bytearray(response.read()), dtype=np.uint8) image = cv2.imdecode(arr, -1) diff --git a/llama_ros/llama_ros/langchain/llama_ros_embeddings.py b/llama_ros/llama_ros/langchain/llama_ros_embeddings.py index ea0f1fa..d9db32a 100644 --- a/llama_ros/llama_ros/langchain/llama_ros_embeddings.py +++ b/llama_ros/llama_ros/langchain/llama_ros_embeddings.py @@ -50,8 +50,7 @@ def __call_generate_embedding_srv(self, text: str) -> List[int]: return self.llama_client.generate_embeddings(req).embeddings def embed_documents(self, texts: List[str]) -> List[List[float]]: - embeddings = [self.__call_generate_embedding_srv( - text) for text in texts] + embeddings = [self.__call_generate_embedding_srv(text) for text in texts] return [list(map(float, e)) for e in embeddings] def embed_query(self, text: str) -> List[float]: diff --git a/llama_ros/llama_ros/llama_client_node.py b/llama_ros/llama_ros/llama_client_node.py index e8c7603..db5a55a 100644 --- a/llama_ros/llama_ros/llama_client_node.py +++ b/llama_ros/llama_ros/llama_client_node.py @@ -79,43 +79,38 @@ def __init__(self, namespace: str = "llama") -> None: raise Exception("This class is a Singleton") super().__init__( - f"client_{str(uuid.uuid4()).replace('-', '_')}_node", namespace=namespace) + f"client_{str(uuid.uuid4()).replace('-', '_')}_node", namespace=namespace + ) self._action_client = ActionClient( self, GenerateResponse, "generate_response", - callback_group=self._callback_group + callback_group=self._callback_group, ) self._tokenize_srv_client = self.create_client( - Tokenize, - "tokenize", - callback_group=self._callback_group + Tokenize, "tokenize", callback_group=self._callback_group ) self._detokenize_srv_client = self.create_client( - Detokenize, - "detokenize", - callback_group=self._callback_group + Detokenize, "detokenize", callback_group=self._callback_group ) self._embeddings_srv_client = self.create_client( GenerateEmbeddings, "generate_embeddings", - callback_group=self._callback_group + callback_group=self._callback_group, ) self._rerank_srv_client = self.create_client( - RerankDocuments, - "rerank_documents", - callback_group=self._callback_group + RerankDocuments, "rerank_documents", callback_group=self._callback_group ) self._format_chat_srv_client = self.create_client( FormatChatMessages, "format_chat_prompt", - callback_group=self._callback_group + callback_group=self._callback_group, ) # executor @@ -132,15 +127,21 @@ def detokenize(self, req: Detokenize.Request) -> Detokenize.Response: self._detokenize_srv_client.wait_for_service() return self._detokenize_srv_client.call(req) - def generate_embeddings(self, req: GenerateEmbeddings.Request) -> GenerateEmbeddings.Response: + def generate_embeddings( + self, req: GenerateEmbeddings.Request + ) -> GenerateEmbeddings.Response: self._embeddings_srv_client.wait_for_service() return self._embeddings_srv_client.call(req) - def rerank_documents(self, req: RerankDocuments.Request) -> RerankDocuments.Response: + def rerank_documents( + self, req: RerankDocuments.Request + ) -> RerankDocuments.Response: self._rerank_srv_client.wait_for_service() return self._rerank_srv_client.call(req) - def format_chat_prompt(self, req: FormatChatMessages.Request) -> FormatChatMessages.Response: + def format_chat_prompt( + self, req: FormatChatMessages.Request + ) -> FormatChatMessages.Response: self._format_chat_srv_client.wait_for_service() return self._format_chat_srv_client.call(req) @@ -148,8 +149,11 @@ def generate_response( self, goal: GenerateResponse.Goal, feedback_cb: Callable = None, - stream: bool = False - ) -> Union[Tuple[GenerateResponse.Result, GoalStatus], Generator[PartialResponse, None, None]]: + stream: bool = False, + ) -> Union[ + Tuple[GenerateResponse.Result, GoalStatus], + Generator[PartialResponse, None, None], + ]: self._action_done = False self._action_result = None @@ -161,7 +165,8 @@ def generate_response( feedback_cb = self._feedback_callback send_goal_future = self._action_client.send_goal_async( - goal, feedback_callback=feedback_cb) + goal, feedback_callback=feedback_cb + ) send_goal_future.add_done_callback(self._goal_response_callback) # Wait for action to be done