forked from lmstudio-ai/mlx-engine
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
89 lines (75 loc) · 2.43 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import base64
from mlx_engine.generate import load_model, create_generator, tokenize
# Prompt used when --prompt is not given on the command line.
# NOTE(review): the <|...|> markers look like Llama-3-style chat-template
# tokens, so this default presumably only suits models trained on that
# template — confirm before relying on it with other model families.
DEFAULT_PROMPT = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Explain the rules of sudoku<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
def setup_arg_parser():
    """Build the command-line parser for the inference demo.

    Options:
        --model         required; path to the local model directory.
        --prompt        text sent to the model; defaults to DEFAULT_PROMPT.
        --images        one or more image paths (None when omitted).
        --stop-strings  one or more strings that end generation (None when
                        omitted).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        (
            "--model",
            dict(
                required=True,
                type=str,
                help="The path to the local model directory.",
            ),
        ),
        (
            "--prompt",
            dict(
                default=DEFAULT_PROMPT,
                type=str,
                help="Message to be processed by the model",
            ),
        ),
        (
            "--images",
            dict(
                type=str,
                nargs="+",
                help="Path of the images to process",
            ),
        ),
        (
            "--stop-strings",
            dict(
                type=str,
                nargs="+",
                help="Strings that will stop the generation",
            ),
        ),
    )

    arg_parser = argparse.ArgumentParser(
        description="LM Studio mlx-engine inference script"
    )
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)
    return arg_parser
def image_to_base64(image_path):
    """Return the contents of the file at *image_path* as a base64 string.

    The file is read in binary mode in one go, base64-encoded, and the
    resulting bytes are decoded to ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def main():
    """Run a single prompt through a local model and stream the reply to stdout.

    Steps: parse the command line, load the model from ``--model``, tokenize
    ``--prompt``, base64-encode any ``--images``, then iterate the generator
    and print each chunk of text as it arrives. When a result carries a stop
    condition, its stop reason (and stop string, if any) is printed as well.
    """
    args = setup_arg_parser().parse_args()

    # Load the model
    model_kit = load_model(args.model, max_kv_size=4096, trust_remote_code=False)

    # Tokenize the prompt
    prompt_tokens = tokenize(model_kit, args.prompt)

    # Handle optional images. --images is declared with nargs="+", so argparse
    # yields either a list of paths or None; no str-to-list coercion is needed.
    images_base64 = [image_to_base64(img_path) for img_path in args.images or []]

    # Generate the response
    generator = create_generator(
        model_kit,
        prompt_tokens,
        None,  # NOTE(review): third positional argument left unset — confirm its meaning
        images_base64,
        args.stop_strings,
        {"max_tokens": 1024},
    )

    for generation_result in generator:
        # Flush on every chunk so the text appears as it is generated.
        print(generation_result.text, end="", flush=True)

        stop_condition = generation_result.stop_condition
        if not stop_condition:
            continue
        print(f"\n\nStopped generation due to: {stop_condition.stop_reason}")
        if stop_condition.stop_string:
            print(f"Stop string: {stop_condition.stop_string}")

    # Finish the output with a newline.
    print()


if __name__ == "__main__":
    main()