black python formatter applied
mgonzs13 committed Oct 24, 2024
1 parent ff3e80c commit db60fdd
Showing 24 changed files with 272 additions and 183 deletions.
95 changes: 61 additions & 34 deletions llama_bringup/launch/base.launch.py
@@ -41,25 +41,25 @@ def run_llama(context: LaunchContext, embedding, reranking):
"n_ctx": LaunchConfiguration("n_ctx", default=512),
"n_batch": LaunchConfiguration("n_batch", default=2048),
"n_ubatch": LaunchConfiguration("n_batch", default=512),

# GPU params
"n_gpu_layers": LaunchConfiguration("n_gpu_layers", default=0),
"split_mode": LaunchConfiguration("split_mode", default="layer"),
"main_gpu": LaunchConfiguration("main_gpu", default=0),
"tensor_split": LaunchConfiguration("tensor_split", default="[0.0]"),

# attn params
"grp_attn_n": LaunchConfiguration("grp_attn_n", default=1),
"grp_attn_w": LaunchConfiguration("grp_attn_w", default=512),

# rope params
"rope_freq_base": LaunchConfiguration("rope_freq_base", default=0.0),
"rope_freq_scale": LaunchConfiguration("rope_freq_scale", default=0.0),
"rope_scaling_type": LaunchConfiguration("rope_scaling_type", default=""),

# yarn params
"yarn_ext_factor": LaunchConfiguration("yarn_ext_factor", default=-1.0),
"yarn_attn_factor": LaunchConfiguration("yarn_attn_factor", default=1.0),
"yarn_beta_fast": LaunchConfiguration("yarn_beta_fast", default=32.0),
"yarn_beta_slow": LaunchConfiguration("yarn_beta_slow", default=1.0),
"yarn_orig_ctx": LaunchConfiguration("yarn_orig_ctx", default=0),

# bool params
"embedding": embedding,
"reranking": reranking,
"logits_all": LaunchConfiguration("logits_all", default=False),
@@ -68,44 +68,67 @@ def run_llama(context: LaunchContext, embedding, reranking):
"warmup": LaunchConfiguration("warmup", default=True),
"check_tensors": LaunchConfiguration("check_tensors", default=False),
"flash_attn": LaunchConfiguration("flash_attn", default=False),

# cache params
"no_kv_offload": LaunchConfiguration("no_kv_offload", default=False),
"cache_type_k": LaunchConfiguration("cache_type_k", default="f16"),
"cache_type_v": LaunchConfiguration("cache_type_v", default="f16"),

# CPU params
"n_threads": LaunchConfiguration("n_threads", default=1),
"cpu_mask": LaunchConfiguration("cpu_mask", default=""),
"cpu_range": LaunchConfiguration("cpu_range", default=""),
"priority": LaunchConfiguration("priority", default="normal"),
"strict_cpu": LaunchConfiguration("strict_cpu", default=False),
"poll": LaunchConfiguration("poll", default=50),

# batch CPU params
"n_threads_batch": LaunchConfiguration("n_threads_batch", default=1),
"cpu_mask_batch": LaunchConfiguration("cpu_mask_batch", default=""),
"cpu_range_batch": LaunchConfiguration("cpu_range_batch", default=""),
"priority_batch": LaunchConfiguration("priority_batch", default="normal"),
"strict_cpu_batch": LaunchConfiguration("strict_cpu_batch", default=False),
"poll_batch": LaunchConfiguration("poll_batch", default=50),

# switch context params
"n_predict": LaunchConfiguration("n_predict", default=128),
"n_keep": LaunchConfiguration("n_keep", default=-1),

# paths params
"model": LaunchConfiguration("model", default=""),
"lora_adapters": ParameterValue(LaunchConfiguration("lora_adapters", default=[""]), value_type=List[str]),
"lora_adapters_scales": ParameterValue(LaunchConfiguration("lora_adapters_scales", default=[0.0]), value_type=List[float]),
"lora_adapters": ParameterValue(
LaunchConfiguration("lora_adapters", default=[""]), value_type=List[str]
),
"lora_adapters_scales": ParameterValue(
LaunchConfiguration("lora_adapters_scales", default=[0.0]),
value_type=List[float],
),
"mmproj": LaunchConfiguration("mmproj", default=""),
"numa": LaunchConfiguration("numa", default="none"),
"pooling_type": LaunchConfiguration("pooling_type", default=""),

"prefix": ParameterValue(LaunchConfiguration("prefix", default=""), value_type=str),
"suffix": ParameterValue(LaunchConfiguration("suffix", default=""), value_type=str),
"stopping_words": ParameterValue(LaunchConfiguration("stopping_words", default=[""]), value_type=List[str]),
"image_prefix": ParameterValue(LaunchConfiguration("image_prefix", default=""), value_type=str),
"image_suffix": ParameterValue(LaunchConfiguration("image_suffix", default=""), value_type=str),
"image_text": ParameterValue(LaunchConfiguration("image_text", default="<image>"), value_type=str),

"system_prompt": ParameterValue(LaunchConfiguration("system_prompt", default=""), value_type=str),
"system_prompt_file": ParameterValue(LaunchConfiguration("system_prompt_file", default=""), value_type=str),
# prefix/suffix
"prefix": ParameterValue(
LaunchConfiguration("prefix", default=""), value_type=str
),
"suffix": ParameterValue(
LaunchConfiguration("suffix", default=""), value_type=str
),
"stopping_words": ParameterValue(
LaunchConfiguration("stopping_words", default=[""]),
value_type=List[str],
),
"image_prefix": ParameterValue(
LaunchConfiguration("image_prefix", default=""), value_type=str
),
"image_suffix": ParameterValue(
LaunchConfiguration("image_suffix", default=""), value_type=str
),
"image_text": ParameterValue(
LaunchConfiguration("image_text", default="<image>"), value_type=str
),
# prompt params
"system_prompt": ParameterValue(
LaunchConfiguration("system_prompt", default=""), value_type=str
),
"system_prompt_file": ParameterValue(
LaunchConfiguration("system_prompt_file", default=""), value_type=str
),
# debug
"debug": LaunchConfiguration("debug", default=True),
}

@@ -123,32 +123,36 @@ def run_llama(context: LaunchContext, embedding, reranking):
name=llama_node_name,
namespace="llama",
parameters=[params],
condition=UnlessCondition(PythonExpression(
[LaunchConfiguration("use_llava")]))
condition=UnlessCondition(
PythonExpression([LaunchConfiguration("use_llava")])
),
), Node(
package="llama_ros",
executable="llava_node",
name="llava_node",
namespace="llama",
parameters=[params],
condition=IfCondition(PythonExpression(
[LaunchConfiguration("use_llava")]))
condition=IfCondition(PythonExpression([LaunchConfiguration("use_llava")])),
)

embedding = LaunchConfiguration("embedding")
embedding_cmd = DeclareLaunchArgument(
"embedding",
default_value="False",
description="Whether the model is an embedding model")
description="Whether the model is an embedding model",
)

reranking = LaunchConfiguration("reranking")
reranking_cmd = DeclareLaunchArgument(
"reranking",
default_value="False",
description="Whether the model is an reranking model")

return LaunchDescription([
embedding_cmd,
reranking_cmd,
OpaqueFunction(function=run_llama, args=[embedding, reranking])
])
description="Whether the model is an reranking model",
)

return LaunchDescription(
[
embedding_cmd,
reranking_cmd,
OpaqueFunction(function=run_llama, args=[embedding, reranking]),
]
)
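For context, base.launch.py is normally consumed through an IncludeLaunchDescription, as create_llama_launch in utils.py (further down in this commit) does. A minimal sketch of doing the same by hand; the model path and argument values are placeholders, not defaults from this commit:

import os

from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import IncludeLaunchDescription
from launch.launch_description_sources import PythonLaunchDescriptionSource


def generate_launch_description():
    return LaunchDescription(
        [
            IncludeLaunchDescription(
                PythonLaunchDescriptionSource(
                    os.path.join(
                        get_package_share_directory("llama_bringup"),
                        "launch",
                        "base.launch.py",
                    )
                ),
                # all launch arguments travel as strings, mirroring
                # create_llama_launch in llama_bringup/utils.py
                launch_arguments={
                    "model": "/tmp/model.gguf",  # placeholder path
                    "n_ctx": "2048",
                    "use_llava": "False",
                }.items(),
            )
        ]
    )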
15 changes: 11 additions & 4 deletions llama_bringup/launch/minicpm-2.6.launch.py
@@ -28,7 +28,14 @@


def generate_launch_description():
return LaunchDescription([
create_llama_launch_from_yaml(os.path.join(
get_package_share_directory("llama_bringup"), "models", "MiniCPM-2.6.yaml"))
])
return LaunchDescription(
[
create_llama_launch_from_yaml(
os.path.join(
get_package_share_directory("llama_bringup"),
"models",
"MiniCPM-2.6.yaml",
)
)
]
)
15 changes: 11 additions & 4 deletions llama_bringup/launch/spaetzle.launch.py
@@ -28,7 +28,14 @@


def generate_launch_description():
return LaunchDescription([
create_llama_launch_from_yaml(os.path.join(
get_package_share_directory("llama_bringup"), "models", "Spaetzle.yaml"))
])
return LaunchDescription(
[
create_llama_launch_from_yaml(
os.path.join(
get_package_share_directory("llama_bringup"),
"models",
"Spaetzle.yaml",
)
)
]
)
49 changes: 33 additions & 16 deletions llama_bringup/llama_bringup/utils.py
@@ -33,34 +33,41 @@

def download_model(repo: str, file: str) -> str:

match = re.search(r'-(\d+)-of-(\d+)\.gguf', file)
match = re.search(r"-(\d+)-of-(\d+)\.gguf", file)

if match:
total_shards = int(match.group(2))
base_name = file[:match.start()]
base_name = file[: match.start()]

# download shards
for i in range(1, total_shards + 1):
shard_file = f"{base_name}-{i:05d}-of-{total_shards:05d}.gguf"
hf_hub_download(repo_id=repo, filename=shard_file,
force_download=False)
hf_hub_download(repo_id=repo, filename=shard_file, force_download=False)

# return first shard
return hf_hub_download(
repo_id=repo,
filename=f"{base_name}-00001-of-{total_shards:05d}.gguf",
force_download=False
force_download=False,
)

return hf_hub_download(repo_id=repo, filename=file, force_download=False)

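A quick illustration of the shard-detection logic above; the filename is invented for the example:

import re

file = "Some-Model-00001-of-00003.gguf"  # hypothetical sharded GGUF name
match = re.search(r"-(\d+)-of-(\d+)\.gguf", file)
if match:
    total_shards = int(match.group(2))  # 3
    base_name = file[: match.start()]  # "Some-Model"
    # each shard is then f"{base_name}-{i:05d}-of-{total_shards:05d}.gguf"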

def load_prompt_type(prompt_file_name: str) -> Tuple:
file_path = os.path.join(get_package_share_directory(
"llama_bringup"), "prompts", f"{prompt_file_name}.yaml")
file_path = os.path.join(
get_package_share_directory("llama_bringup"),
"prompts",
f"{prompt_file_name}.yaml",
)
with open(file_path, "r") as file:
yaml_data = yaml.safe_load(file)
return yaml_data["prefix"], yaml_data["suffix"], yaml_data["stopping_words"], yaml_data["system_prompt"]
return (
yaml_data["prefix"],
yaml_data["suffix"],
yaml_data["stopping_words"],
yaml_data["system_prompt"],
)


def create_llama_launch_from_yaml(file_path: str) -> IncludeLaunchDescription:
@@ -70,8 +77,11 @@ def create_llama_launch_from_yaml(file_path: str) -> IncludeLaunchDescription:


def create_llama_launch(**kwargs) -> IncludeLaunchDescription:
prompt_data = load_prompt_type(kwargs["system_prompt_type"]) if kwargs.get(
"system_prompt_type") else ("", "", [], "")
prompt_data = (
load_prompt_type(kwargs["system_prompt_type"])
if kwargs.get("system_prompt_type")
else ("", "", [], "")
)
kwargs["prefix"] = kwargs.get("prefix", prompt_data[0])
kwargs["suffix"] = kwargs.get("suffix", prompt_data[1])
kwargs["system_prompt"] = kwargs.get("system_prompt", prompt_data[3])
@@ -83,9 +93,14 @@ def create_llama_launch(**kwargs) -> IncludeLaunchDescription:

# load models
for key in ["model", "mmproj"]:
if not kwargs.get(key) and kwargs.get(f"{key}_repo") and kwargs.get(f"{key}_filename"):
if (
not kwargs.get(key)
and kwargs.get(f"{key}_repo")
and kwargs.get(f"{key}_filename")
):
kwargs[key] = download_model(
kwargs[f"{key}_repo"], kwargs[f"{key}_filename"])
kwargs[f"{key}_repo"], kwargs[f"{key}_filename"]
)

# load lora adapters
lora_adapters = []
@@ -122,8 +137,10 @@ def create_llama_launch(**kwargs) -> IncludeLaunchDescription:
kwargs["use_llava"] = False

return IncludeLaunchDescription(
PythonLaunchDescriptionSource(os.path.join(
get_package_share_directory("llama_bringup"), "launch", "base.launch.py")),
launch_arguments={key: str(value)
for key, value in kwargs.items()}.items()
PythonLaunchDescriptionSource(
os.path.join(
get_package_share_directory("llama_bringup"), "launch", "base.launch.py"
)
),
launch_arguments={key: str(value) for key, value in kwargs.items()}.items(),
)
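A hedged usage sketch of create_llama_launch; the repo id, filename, and prompt type below are placeholders, not values shipped with llama_ros:

from llama_bringup.utils import create_llama_launch

llama = create_llama_launch(
    model_repo="some-org/Some-Model-GGUF",  # assumed Hugging Face repo id
    model_filename="some-model.Q4_K_M.gguf",  # assumed GGUF filename
    n_ctx=2048,
    system_prompt_type="ChatML",  # assumed YAML name under prompts/
)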
12 changes: 5 additions & 7 deletions llama_cli/llama_cli/api/__init__.py
@@ -52,16 +52,15 @@ def launch_llm(file_path: str) -> None:
print(f"File '{file_path}' does not exists")
return

ld = LaunchDescription([
create_llama_launch_from_yaml(file_path)
])
ld = LaunchDescription([create_llama_launch_from_yaml(file_path)])
ls = LaunchService()
ls.include_launch_description(ld)
ls.run()


def prompt_llm(prompt: str, reset: bool = False,
temp: float = 0.8, image_url: str = "") -> None:
def prompt_llm(
prompt: str, reset: bool = False, temp: float = 0.8, image_url: str = ""
) -> None:

rclpy.init()
llama_client = LlamaClientNode()
@@ -71,8 +70,7 @@ def prompt_llm(prompt: str, reset: bool = False,
goal.sampling_config.temp = temp

if image_url:
req = urllib.request.Request(
image_url, headers={"User-Agent": "Mozilla/5.0"})
req = urllib.request.Request(image_url, headers={"User-Agent": "Mozilla/5.0"})
response = urllib.request.urlopen(req)
arr = np.asarray(bytearray(response.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
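The same API can also be driven from Python; a minimal sketch, assuming a llama node is already running (the URL is a placeholder):

from llama_cli.api import prompt_llm

prompt_llm(
    "Describe this image",
    temp=0.5,
    image_url="https://example.com/photo.jpg",  # placeholder URL
)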
5 changes: 3 additions & 2 deletions llama_cli/llama_cli/command/llama.py
@@ -29,8 +29,9 @@ class LlamaCommand(CommandExtension):

def add_arguments(self, parser, cli_name):
self._subparser = parser
add_subparsers_on_demand(parser, cli_name, "_verb",
"llama_cli.verb", required=False)
add_subparsers_on_demand(
parser, cli_name, "_verb", "llama_cli.verb", required=False
)

def main(self, *, parser, args):
if not hasattr(args, "_verb"):
3 changes: 1 addition & 2 deletions llama_cli/llama_cli/verb/launch.py
@@ -28,8 +28,7 @@
class LaunchVerb(VerbExtension):

def add_arguments(self, parser, cli_name):
arg = parser.add_argument(
"file_path", help="path to the YAML of the LLM")
arg = parser.add_argument("file_path", help="path to the YAML of the LLM")

def main(self, *, args):
launch_llm(args.file_path)
29 changes: 17 additions & 12 deletions llama_cli/llama_cli/verb/prompt.py
@@ -28,21 +28,26 @@
class PromptVerb(VerbExtension):

def add_arguments(self, parser, cli_name):
arg = parser.add_argument(
"prompt", help="prompt text for the LLM")
arg = parser.add_argument("prompt", help="prompt text for the LLM")
parser.add_argument(
"-r", "--reset",
"-r",
"--reset",
action="store_true",
help="Whether to reset the LLM and its context before prompting")
help="Whether to reset the LLM and its context before prompting",
)
parser.add_argument(
"-t", "--temp",
metavar="N", type=positive_float, default=0.8,
help="Temperature value (default: 0.8)")
"-t",
"--temp",
metavar="N",
type=positive_float,
default=0.8,
help="Temperature value (default: 0.8)",
)
parser.add_argument(
"--image-url",
type=str, default="",
help="Image URL to sent to the VLM")
"--image-url", type=str, default="", help="Image URL to sent to the VLM"
)

def main(self, *, args):
prompt_llm(args.prompt, reset=args.reset,
temp=args.temp, image_url=args.image_url)
prompt_llm(
args.prompt, reset=args.reset, temp=args.temp, image_url=args.image_url
)
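Assuming the extension is registered as a ros2 llama command (the entry points are not part of this diff), the verb above would be invoked along the lines of: ros2 llama prompt "Who are you?" -t 0.5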