Remove audio resampling and add timing context support #33

Merged 4 commits on Nov 21, 2023
27 changes: 25 additions & 2 deletions neon_iris/client.py
@@ -41,6 +41,7 @@
from ovos_utils.json_helper import merge_dict
from pika.exceptions import StreamLostError
from neon_utils.configuration_utils import get_neon_user_config
from neon_utils.metrics_utils import Stopwatch
from neon_utils.mq_utils import NeonMQHandler
from neon_utils.socket_utils import b64_to_dict
from neon_utils.file_utils import decode_base64_string_to_file, \
@@ -49,6 +50,8 @@
from ovos_utils.xdg_utils import xdg_config_home, xdg_cache_home
from ovos_config.config import Configuration

_stopwatch = Stopwatch()


class NeonAIClient:
def __init__(self, mq_config: dict = None, config_dir: str = None):
@@ -128,10 +131,24 @@ def handle_neon_response(self, channel, method, _, body):
Override this method to handle Neon Responses
"""
channel.basic_ack(delivery_tag=method.delivery_tag)
response = b64_to_dict(body)
recv_time = time()
with _stopwatch:
response = b64_to_dict(body)
LOG.debug(f"Message deserialized in {_stopwatch.time}s")
message = Message(response.get('msg_type'), response.get('data'),
response.get('context'))
LOG.info(message.msg_type)

# Get timing data and log
message.context.setdefault("timing", {})
resp_time = message.context['timing'].get('response_sent', recv_time)
if recv_time != resp_time:
transit_time = recv_time - resp_time
message.context['timing']['client_from_core'] = transit_time
LOG.debug(f"Response MQ transit time={transit_time}")
handling_time = recv_time - message.context['timing'].get('client_sent',
recv_time)
LOG.info(f"{message.msg_type} handled in {handling_time}")
LOG.debug(f"{pformat(message.context['timing'])}")
if message.msg_type == "klat.response":
LOG.info("Handling klat response event")
self.handle_klat_response(message)
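
The `Stopwatch` used above is imported from `neon_utils.metrics_utils` and is not defined in this diff. As a rough illustration of the usage shown (`with _stopwatch:` around the deserialization, then reading `_stopwatch.time`), a minimal stand-in could look like the sketch below; the real class may differ in detail:

```python
from time import time


class Stopwatch:
    """Minimal stand-in for neon_utils.metrics_utils.Stopwatch (illustration only)."""

    def __init__(self):
        self.time = None    # elapsed seconds of the most recent timed block
        self._start = None

    def __enter__(self):
        self._start = time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.time = time() - self._start


_stopwatch = Stopwatch()
with _stopwatch:
    payload = {"msg_type": "klat.response"}     # stand-in for b64_to_dict(body)
print(f"Message deserialized in {_stopwatch.time}s")
```
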
@@ -267,6 +284,7 @@ def _build_message(self, msg_type: str, data: dict,
"ident": ident or str(time()),
"username": username,
"user_profiles": user_profiles,
"timing": {},
"mq": {"routing_key": self.uid,
"message_id": self.connection.create_unique_id()}
})
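
With `"timing": {}` seeded here, messages built by this client always carry the key; the `setdefault("timing", {})` call in `handle_neon_response` covers responses whose context was built without it. A small illustration of that interplay:

```python
# Response context without a timing dict (e.g. from an older sender):
context = {"mq": {"routing_key": "abc"}}
context.setdefault("timing", {})              # adds the key only if it is missing
context["timing"]["client_from_core"] = 0.05  # now safe to write either way
```
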
@@ -305,6 +323,11 @@ def _send_audio(self, audio_file: str, lang: str,

def _send_serialized_message(self, serialized: dict):
try:
serialized['context']['timing']['client_sent'] = time()
if serialized['context']['timing'].get('gradio_sent'):
serialized['context']['timing']['iris_input_handling'] = \
serialized['context']['timing']['client_sent'] - \
serialized['context']['timing']['gradio_sent']
self.connection.emit_mq_message(
self._connection.connection,
queue="neon_chat_api_request",
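
Combined with the `web_client.py` changes below, the timing context ends up carrying roughly the following keys by the time `handle_neon_response` pretty-prints it. The key names come from this diff; the `response_sent` timestamp is assumed to be stamped by the core service before it replies, and the values are purely illustrative:

```python
# Illustrative message.context["timing"] after one full round trip
# (timestamps are epoch seconds, the remaining entries are durations in seconds)
timing = {
    "wait_in_queue": 0.02,           # web_client: time spent waiting on the previous request
    "gradio_sent": 1700000000.00,    # web_client: just before send_utterance()/send_audio()
    "iris_input_handling": 0.01,     # client: client_sent - gradio_sent
    "client_sent": 1700000000.01,    # client: stamped in _send_serialized_message()
    "response_sent": 1700000001.20,  # assumed: stamped by the core service on reply
    "client_from_core": 0.05,        # client: recv_time - response_sent (MQ transit)
}
# handle_neon_response then logs: handling_time = recv_time - timing["client_sent"]
```
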
55 changes: 12 additions & 43 deletions neon_iris/web_client.py
@@ -27,7 +27,7 @@
from os import makedirs
from os.path import isfile, join, isdir
from time import time
from typing import List, Optional, Dict
from typing import List, Dict
from uuid import uuid4

import gradio
@@ -118,55 +118,18 @@ def update_profile(self, stt_lang: str, tts_lang: str, tts_lang_2: str,
LOG.info(f"Updated profile for: {session_id}")
return session_id

def send_audio(self, audio_file: str, lang: str = "en-us",
username: Optional[str] = None,
user_profiles: Optional[list] = None,
context: Optional[dict] = None):
"""
@param audio_file: path to wav audio file to send to speech module
@param lang: language code associated with request
@param username: username associated with request
@param user_profiles: user profiles expecting a response
"""
# TODO: Audio conversion is really slow here. check ovos-stt-http-server
audio_file = self.convert_audio(audio_file)
self._send_audio(audio_file, lang, username, user_profiles, context)

def convert_audio(self, audio_file: str, target_sr=16000, target_channels=1,
dtype='int16') -> str:
"""
@param audio_file: path to audio file to convert for speech model
@returns: path to converted audio file
"""
# Load the audio file
y, sr = librosa.load(audio_file, sr=None, mono=False) # Load without changing sample rate or channels

# If the file has more than one channel, mix it down to one channel
if y.ndim > 1 and target_channels == 1:
y = librosa.to_mono(y)

# Resample the audio to the target sample rate
y_resampled = librosa.resample(y, orig_sr=sr, target_sr=target_sr)

# Ensure the audio array is in the correct format (int16 for 2-byte samples)
y_resampled = (y_resampled * (2 ** (8 * 2 - 1))).astype(dtype)

output_path = join(self._audio_path, f"{time()}.wav")
# Save the audio file with the new sample rate and sample width
sf.write(output_path, y_resampled, target_sr, format='WAV', subtype='PCM_16')
LOG.info(f"Converted audio file to {output_path}")
return output_path

def on_user_input(self, utterance: str, *args, **kwargs) -> str:
"""
Callback to handle textual user input
@param utterance: String utterance submitted by the user
@returns: String response from Neon (or "ERROR")
"""
input_time = time()
LOG.debug(f"Input received")
if not self._await_response.wait(30):
LOG.error("Previous response not completed after 30 seconds")
LOG.debug(f"args={args}|kwargs={kwargs}")
in_queue = time() - input_time
self._await_response.clear()
self._response = None
gradio_id = args[2]
@@ -175,13 +138,19 @@ def on_user_input(self, utterance: str, *args, **kwargs) -> str:
LOG.info(f"Sending utterance: {utterance} with lang: {lang}")
self.send_utterance(utterance, lang, username=gradio_id,
user_profiles=[self._profiles[gradio_id]],
context={"gradio": {"session": gradio_id}})
context={"gradio": {"session": gradio_id},
"timing": {"wait_in_queue": in_queue,
"gradio_sent": time()}})
else:
LOG.info(f"Sending audio: {args[1]} with lang: {lang}")
self.send_audio(args[1], lang, username=gradio_id,
user_profiles=[self._profiles[gradio_id]],
context={"gradio": {"session": gradio_id}})
self._await_response.wait(30)
context={"gradio": {"session": gradio_id},
"timing": {"wait_in_queue": in_queue,
"gradio_sent": time()}})
if not self._await_response.wait(30):
LOG.error("No response received after 30s")
self._await_response.set()
self._response = self._response or "ERROR"
LOG.info(f"Got response={self._response}")
return self._response
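
The new guard above relies on standard `threading.Event` semantics: `wait(30)` returns `False` if no response handler set the event within 30 seconds, in which case the event is re-set so later requests are not blocked and `"ERROR"` is returned in place of a response. A stripped-down sketch of that pattern, with a hypothetical `handle_response` callback standing in for the MQ response handler:

```python
import threading
from typing import Optional


class ResponseWaiter:
    """Sketch of the wait/timeout pattern used in on_user_input (illustration only)."""

    def __init__(self):
        self._await_response = threading.Event()
        self._await_response.set()              # nothing in flight initially
        self._response: Optional[str] = None

    def handle_response(self, text: str):
        """Hypothetical callback fired when a response arrives over MQ."""
        self._response = text
        self._await_response.set()

    def ask(self, utterance: str) -> str:
        if not self._await_response.wait(30):   # previous request still pending
            print("Previous response not completed after 30 seconds")
        self._await_response.clear()
        self._response = None
        # ... send `utterance` over MQ here; handle_response() runs on another thread ...
        if not self._await_response.wait(30):   # False means the 30s timeout expired
            self._await_response.set()          # unblock future requests
            self._response = self._response or "ERROR"
        return self._response
```
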