Skip to content

Commit

Permalink
Support Audio Input, Update docs and logging (#46)
Browse files Browse the repository at this point in the history
* Add validation template for `neon.audio_input` Message
Add logging around backwards-compat and default value handling
Update message validation to use `msg_type` directly
Update STT response handling to return more data to caller
Update formatting

* Troubleshooting validation errors
  • Loading branch information
NeonDaniel authored Sep 13, 2023
1 parent 8312deb commit f67c1fd
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 74 deletions.
11 changes: 7 additions & 4 deletions neon_messagebus_mq_connector/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,20 @@
import sys

from typing import Optional
from neon_utils.logger import LOG
from ovos_utils.log import LOG

from neon_messagebus_mq_connector.config import Configuration
from neon_messagebus_mq_connector import ChatAPIProxy


def _get_default_config() -> dict:
    """
    Load legacy connector configuration from disk, if a legacy file exists.

    The file path is read from the `CHAT_API_PROXY_CONFIG` environment
    variable and falls back to `config.json`.
    :returns: `config_data` parsed from the legacy file; an empty dict when
        the file does not exist or loading it raised an exception
    """
    legacy_config_file = os.environ.get('CHAT_API_PROXY_CONFIG',
                                        'config.json')
    try:
        # Only build a `Configuration` when the legacy file is really there
        if os.path.isfile(legacy_config_file):
            LOG.info(f"Using legacy configuration from {legacy_config_file}")
            return Configuration(
                from_files=[legacy_config_file]).config_data
    except Exception as e:
        # Best-effort: a load failure is logged and treated as "no config"
        LOG.error(e)
    return dict()
Expand Down
173 changes: 103 additions & 70 deletions neon_messagebus_mq_connector/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,29 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import time
import pika

from typing import List, Type, Tuple

import pika
from ovos_bus_client.client import MessageBusClient
from ovos_bus_client.message import Message

from ovos_utils.log import LOG, log_deprecation
from neon_utils.socket_utils import b64_to_dict
from ovos_config.config import Configuration
from neon_mq_connector.connector import MQConnector
from pika.channel import Channel
from pydantic import ValidationError

from .enums import NeonResponseTypes
from .messages import templates, BaseModel
from neon_messagebus_mq_connector.enums import NeonResponseTypes
from neon_messagebus_mq_connector.messages import templates, BaseModel


class ChatAPIProxy(MQConnector):
"""
Proxy module for establishing connection between PyKlatchat and neon chat api"""
Proxy module for establishing connection between Neon Core and an MQ Broker
"""

def __init__(self, config: dict, service_name: str):
config = config or Configuration()
Expand All @@ -55,7 +57,7 @@ def __init__(self, config: dict, service_name: str):
if config.get("MESSAGEBUS"):
log_deprecation("MESSAGEBUS config is deprecated. use `websocket`",
"1.0.0")
self.bus_config = config.get("websocket")
self.bus_config = config.get("MESSAGEBUS")
self._vhost = '/neon_chat_api'
self._bus = None
self.connect_bus()
Expand Down Expand Up @@ -89,9 +91,8 @@ def register_bus_handlers(self):

def connect_bus(self, refresh: bool = False):
"""
Convenience method for establishing connection to message bus
:param refresh: To refresh existing connection
Convenience method for establishing connection to message bus
:param refresh: To refresh existing connection
"""
if not self._bus or refresh:
self._bus = MessageBusClient(host=self.bus_config['host'],
Expand All @@ -105,36 +106,40 @@ def connect_bus(self, refresh: bool = False):
@property
def bus(self) -> MessageBusClient:
    """
    Connects to Message Bus if no connection was established
    :return: connected message bus client instance
    """
    if not self._bus:
        # Lazily (re)create the client; `connect_bus` assigns `self._bus`
        self.connect_bus()
    return self._bus

def handle_neon_message(self, message: Message):
"""
Handles responses from Neon Core
:param message: Received Message object
Handles responses from Neon Core, optionally reformatting response data
before forwarding to the MQ bus.
:param message: Received Message object
"""

if not message.data:
message.data['msg'] = 'Failed to get response from Neon'
message.context.setdefault('klat_data', {})
if message.msg_type == 'neon.get_tts.response':
body = self.format_response(response_type=NeonResponseTypes.TTS, message=message)
message.context['klat_data'].setdefault('routing_key', 'neon_tts_response')
body = self.format_response(response_type=NeonResponseTypes.TTS,
message=message)
message.context['klat_data'].setdefault('routing_key',
'neon_tts_response')
elif message.msg_type == 'neon.get_stt.response':
body = self.format_response(response_type=NeonResponseTypes.STT, message=message)
message.context['klat_data'].setdefault('routing_key', 'neon_stt_response')
body = self.format_response(response_type=NeonResponseTypes.STT,
message=message)
message.context['klat_data'].setdefault('routing_key',
'neon_stt_response')
else:
body = {'msg_type': message.msg_type,
'data': message.data, 'context': message.context}
LOG.debug(f'Received neon response body: {body}')
if not body:
LOG.warning('Something went wrong while formatting - received empty body')
LOG.warning('Something went wrong while formatting - '
f'received empty body for {message.msg_type}')
else:
routing_key = message.context.get("mq",
{}).get("routing_key",
Expand All @@ -157,67 +162,92 @@ def handle_neon_profile_update(self, message: Message):
f"user={message.data['profile']['user']['username']}")

@staticmethod
def __validate_message_templates(
        msg_data: dict,
        message_templates: List[Type[BaseModel]] = None) \
        -> Tuple[str, dict]:
    """
    Validate message data against each of the selected pydantic templates

    Templates are applied in order; the dict produced by one successful
    validation is the input to the next template.
    :param msg_data: Message data to validate
    :param message_templates: list of pydantic templates to apply to the data
    :returns tuple containing 2 values:
        1) validation error string (empty string if no error was detected);
        2) message data as produced by the last successful validation
    """
    if not message_templates:
        LOG.warning('No matching templates found, '
                    'skipping template fetching')
        return '', msg_data

    LOG.debug(f'Initiating template validation with {message_templates}')
    for message_template in message_templates:
        try:
            msg_data = message_template(**msg_data).dict()
        except (ValueError, ValidationError) as err:
            # Log the message type rather than the whole payload; `.get` so
            # a payload without `msg_type` cannot raise from this handler
            LOG.error(f'Failed to validate {msg_data.get("msg_type")} '
                      f'with template = {message_template.__name__}, '
                      f'exception={err}')
            return str(err), msg_data
    LOG.debug('Template validation completed successfully')
    return '', msg_data

@classmethod
def validate_request(cls, msg_data: dict):
    """
    Fetches the relevant template models and validates provided message data
    iteratively through them

    The template is selected from `msg_type` when it is a known type.
    Otherwise the templates named in `context["request_skills"]` are used.
    :param msg_data: message data for validation
    :return: validation details (empty string if validation passed),
        input data with proper data types and filled default fields
    :raises ValueError: if `msg_type` is not a known type and the context
        does not list any `request_skills`
    """
    # Template key for every message type that maps to exactly one template
    template_by_msg_type = {
        "neon.get_stt": "stt",
        "neon.audio_input": "audio_input",
        "recognizer_loop:utterance": "recognizer",
        "neon.get_tts": "tts",
    }
    msg_type = msg_data.get('msg_type')
    # Guard the lookup: a non-string `msg_type` may be unhashable
    template_key = template_by_msg_type.get(msg_type) \
        if isinstance(msg_type, str) else None

    if template_key:
        message_templates = [templates.get(template_key)]
    elif msg_data.get("context", {}).get("request_skills"):
        # Backwards-compat: fall back to templates named by the request
        LOG.warning(f"Unknown input message type: {msg_type}")
        message_templates = []
        for requested_template in msg_data["context"]["request_skills"]:
            matching_template_model = templates.get(requested_template)
            if not matching_template_model:
                LOG.warning(f'Template under keyword '
                            f'"{requested_template}" does not exist')
            else:
                message_templates.append(matching_template_model)
    else:
        raise ValueError(f"Unable to validate input message: {msg_data}")

    detected_error, msg_data = cls.__validate_message_templates(
        msg_data=msg_data, message_templates=message_templates)
    return detected_error, msg_data

@staticmethod
def validate_message_context(message: Message) -> bool:
    """
    Validates message context so its relevant data could be fetched once
    a response is received

    On success the context is updated in place: `created_on` is set to the
    current epoch time and, for `neon.get_stt` messages, a context `lang`
    that differs from the data `lang` is overwritten with the data value.
    :param message: Message whose context is checked and updated
    :returns: True if `message.context["mq"]["message_id"]` is set,
        False otherwise
    """
    message_id = message.context.get('mq', {}).get('message_id')
    if not message_id:
        LOG.warning('Message context validation failed - '
                    'message.context["mq"]["message_id"] is None')
        return False

    message.context['created_on'] = int(time.time())
    if message.msg_type == 'neon.get_stt':
        data_lang = message.data.get('lang')
        if message.context.get('lang') != data_lang:
            LOG.warning("Context lang does not match data!")
            message.context['lang'] = data_lang
    return True

def handle_user_message(self,
Expand Down Expand Up @@ -245,8 +275,13 @@ def handle_user_message(self,
dict_data["context"].setdefault("mq",
{**mq_context, **klat_context})
dict_data["context"].setdefault("klat_data", klat_context)

validation_error, dict_data = self.validate_request(dict_data)
# TODO: Consider merging this context instead of `setdefault` so
# expected keys are always present
try:
validation_error, dict_data = self.validate_request(dict_data)
except ValueError as e:
LOG.error(e)
validation_error = True
if validation_error:
LOG.error(f"Validation failed with: {validation_error}")
# Don't deserialize since this Message may be malformed
Expand All @@ -258,21 +293,23 @@ def handle_user_message(self,
'neon_chat_api_error')
self.handle_neon_message(response)
else:
# dict_data["context"].setdefault('ident', f"{dict_data['msg_type']}.response")
message = Message(**dict_data)
is_context_valid = self.validate_message_context(message)
if is_context_valid:
self.bus.emit(message)
else:
LOG.error(f'Message context is invalid - {message} is not emitted')
LOG.error(f'Message context is invalid - '
f'{message.context["mq"]["message_id"]} '
f'is not emitted')
else:
channel.basic_nack()
raise TypeError(f'Invalid body received, expected: bytes string;'
f' got: {type(body)}')

def format_response(self, response_type: NeonResponseTypes, message: Message) -> dict:
def format_response(self, response_type: NeonResponseTypes,
message: Message) -> dict:
"""
Formats received response by Neon API based on type
Reformat received response by Neon API for Klat based on type
:param response_type: response type from NeonResponseTypes Enum
:param message: Neon MessageBus Message object
Expand Down Expand Up @@ -303,19 +340,15 @@ def format_response(self, response_type: NeonResponseTypes, message: Message) ->
}
elif response_type == NeonResponseTypes.STT:
transcripts = message.data.get('transcripts', [''])
if transcripts and transcripts[0]:
LOG.info(f'transcript candidates received - {transcripts}')
response_data = {
'transcript': transcripts[0],
'other_transcripts': [transcript for transcript in
transcripts if
transcript != transcripts[0]],
'lang': message.context.get('lang', 'en-us'),
'context': message.context
}
else:
LOG.error('No transcripts received')
response_data = {}
LOG.info(f'transcript candidates received - {transcripts}')
response_data = {
'transcript': transcripts[0],
'other_transcripts': [transcript for transcript in
transcripts if
transcript != transcripts[0]],
'lang': message.context.get('lang', 'en-us'),
'context': message.context
}
else:
LOG.warning(f'Failed to response response type -> '
f'{response_type}')
Expand Down
17 changes: 17 additions & 0 deletions neon_messagebus_mq_connector/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ class RecognizerMessage(BaseModel):
)


class AudioInput(BaseModel):
    """
    Validation template for `neon.audio_input` messages
    """
    msg_type: str = "neon.audio_input"
    # `audio_data` and `lang` have no defaults, so both must be supplied
    data: create_model(
        "Data",
        audio_data=(str, ...),
        lang=(str, ...),
        __base__=BaseModel,
    )
    # Each context field falls back to the default given here
    context: create_model(
        "Context",
        source=(str, "mq_api"),
        destination=(list, ["speech"]),
        username=(str, "guest"),
        user_profiles=(list, []),
        __base__=BaseModel,
    )


class STTMessage(BaseModel):
msg_type: str = "neon.get_stt"
data: create_model("Data",
Expand Down Expand Up @@ -90,5 +106,6 @@ class TTSMessage(BaseModel):
# Validation models keyed by the template name used to look them up
templates = dict(
    stt=STTMessage,
    tts=TTSMessage,
    audio_input=AudioInput,
    recognizer=RecognizerMessage,
)

0 comments on commit f67c1fd

Please sign in to comment.