Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
kenarsa committed May 16, 2024
1 parent bef7789 commit cb5ff80
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 44 deletions.
6 changes: 6 additions & 0 deletions recipes/llm-voice-assistant/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ python main.py --help

## Profiling

### Realtime Factor

### Tokens per Second

### Latency

```console
python3 main.py --access_key ${ACCESS_KEY} --picollm_model_path ${PICOLLM_MODEL_PATH} --profile
```
Expand Down
68 changes: 24 additions & 44 deletions recipes/llm-voice-assistant/python/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import signal
import time
from argparse import ArgumentParser
from enum import Enum
from multiprocessing import (
Pipe,
Process,
Expand All @@ -18,23 +17,6 @@
from pvrecorder import PvRecorder


class Logger:
    """Tiny two-level console logger: INFO messages always print, DEBUG
    messages print only when the logger was constructed at DEBUG level.
    Output is flushed immediately so interleaved prompts appear in order."""

    class Levels(Enum):
        DEBUG = 'DEBUG'
        INFO = 'INFO'

    def __init__(self, level: 'Logger.Levels' = Levels.INFO) -> None:
        # Verbosity threshold chosen at construction time.
        self._level = level

    def debug(self, message: str, end: str = '\n') -> None:
        """Emit `message` only if this logger is at DEBUG verbosity."""
        if self._level is not self.Levels.DEBUG:
            return
        print(message, end=end, flush=True)

    # noinspection PyMethodMayBeStatic
    def info(self, message: str, end: str = '\n') -> None:
        """Emit `message` unconditionally."""
        print(message, end=end, flush=True)


class RTFProfiler:
def __init__(self, sample_rate: int) -> None:
self._sample_rate = sample_rate
Expand Down Expand Up @@ -234,11 +216,7 @@ def main() -> None:
default=0.,
help="Duration of the synthesized audio to buffer before streaming it out. A higher value helps slower "
"(e.g., Raspberry Pi) to keep up with real-time at the cost of increasing the initial delay.")
parser.add_argument(
'--log_level',
choices=[x.value for x in Logger.Levels],
default=Logger.Levels.INFO.value,
help='Log level verbosity.')
parser.add_argument('--profile', action='store_true', help='Show runtime profiling information.')
args = parser.parse_args()

access_key = args.access_key
Expand All @@ -252,34 +230,32 @@ def main() -> None:
picollm_temperature = args.picollm_temperature
picollm_top_p = args.picollm_top_p
orca_warmup_sec = args.orca_warmup_sec
log_level = Logger.Levels(args.log_level)

log = Logger(log_level)
profile = args.profile

if keyword_model_path is None:
porcupine = pvporcupine.create(access_key=access_key, keywords=['picovoice'])
else:
porcupine = pvporcupine.create(access_key=access_key, keyword_paths=[keyword_model_path])
log.info(f"→ Porcupine V{porcupine.version}")
print(f"→ Porcupine V{porcupine.version}")

cheetah = pvcheetah.create(access_key=access_key, endpoint_duration_sec=cheetah_endpoint_duration_sec)
log.info(f"→ Cheetah V{cheetah.version}")
print(f"→ Cheetah V{cheetah.version}")

pllm = picollm.create(access_key=access_key, model_path=picollm_model_path, device=picollm_device)
dialog = pllm.get_dialog()
log.info(f"→ picoLLM V{pllm.version} <{pllm.model}>")
print(f"→ picoLLM V{pllm.version} <{pllm.model}>")

main_connection, orca_process_connection = Pipe()
orca_process = Process(target=orca_worker, args=(access_key, orca_process_connection, orca_warmup_sec))
orca_process.start()
while not main_connection.poll():
time.sleep(0.01)
log.info(f"→ Orca V{main_connection.recv()['version']}")
print(f"→ Orca V{main_connection.recv()['version']}")

mic = PvRecorder(frame_length=porcupine.frame_length)
mic.start()

log.info("\n$ Say `Picovoice` ...")
print(f"\n$ Say {'Picovoice' if keyword_model_path is None else 'the wake word'} ...")

stop = [False]

Expand All @@ -306,24 +282,26 @@ def handler(_, __) -> None:
wake_word_detected = porcupine.process(pcm) == 0
porcupine_profiler.tock(pcm)
if wake_word_detected:
log.debug(f"[Porcupine RTF: {porcupine_profiler.rtf():.3f}]")
log.info("$ Wake word detected, utter your request or question ...\n")
log.info("User > ", end='')
if profile:
print(f"[Porcupine RTF: {porcupine_profiler.rtf():.3f}]")
print("$ Wake word detected, utter your request or question ...\n")
print("User > ", end='', flush=True)
elif not endpoint_reached:
pcm = mic.read()
cheetah_profiler.tick()
partial_transcript, endpoint_reached = cheetah.process(pcm)
cheetah_profiler.tock(pcm)
log.info(partial_transcript, end='')
print(partial_transcript, end='', flush=True)
user_request += partial_transcript
if endpoint_reached:
utterance_end_sec = time.time()
cheetah_profiler.tick()
remaining_transcript = cheetah.flush()
cheetah_profiler.tock()
user_request += remaining_transcript
log.info(remaining_transcript, end='\n\n')
log.debug(f"[Cheetah RTF: {cheetah_profiler.rtf():.3f}]")
print(remaining_transcript, end='\n\n')
if profile:
print(f"[Cheetah RTF: {cheetah_profiler.rtf():.3f}]")
else:
dialog.add_human_request(user_request)

Expand All @@ -333,9 +311,9 @@ def llm_callback(text: str) -> None:
picollm_profiler.tock()
main_connection.send(
{'command': 'synthesize', 'text': text, 'utterance_end_sec': utterance_end_sec})
log.info(text, end='')
print(text, end='', flush=True)

log.info("\nLLM > ", end='')
print("\nLLM > ", end='', flush=True)
res = pllm.generate(
prompt=dialog.prompt(),
completion_token_limit=picollm_completion_token_limit,
Expand All @@ -345,23 +323,25 @@ def llm_callback(text: str) -> None:
top_p=picollm_top_p,
stream_callback=llm_callback)
main_connection.send({'command': 'flush'})
log.info('\n')
print('\n')
dialog.add_llm_response(res.completion)
log.debug(f"[picoLLM TPS: {picollm_profiler.tps():.2f}]")
if profile:
print(f"[picoLLM TPS: {picollm_profiler.tps():.2f}]")

while not main_connection.poll():
time.sleep(0.01)
message = main_connection.recv()
log.debug(f"[Orca RTF: {message['rtf']:.2f}]")
log.debug(f"[Delay: {message['delay']:.2f} sec]")
if profile:
print(f"[Orca RTF: {message['rtf']:.2f}]")
print(f"[Delay: {message['delay']:.2f} sec]")
while not main_connection.poll():
time.sleep(0.01)
assert main_connection.recv()['done']

wake_word_detected = False
user_request = ''
endpoint_reached = False
log.info("\n$ Say `Picovoice` ...")
print(f"\n$ Say {'Picovoice' if keyword_model_path is None else 'the wake word'} ...")
finally:
main_connection.send({'command': 'close'})
mic.delete()
Expand Down

0 comments on commit cb5ff80

Please sign in to comment.