diff --git a/neon_data_models/models/__init__.py b/neon_data_models/models/__init__.py
index 8d9d8f6..3ca2d28 100644
--- a/neon_data_models/models/__init__.py
+++ b/neon_data_models/models/__init__.py
@@ -23,3 +23,9 @@
 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from neon_data_models.models.user import *
+User.rebuild_model()
+
+from neon_data_models.models.client import *
+from neon_data_models.models.api import *
diff --git a/neon_data_models/models/api/__init__.py b/neon_data_models/models/api/__init__.py
index 5881279..ba7a6af 100644
--- a/neon_data_models/models/api/__init__.py
+++ b/neon_data_models/models/api/__init__.py
@@ -25,5 +25,6 @@
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from neon_data_models.models.api.node_v1 import *
-from neon_data_models.models.api.mq import *
 from neon_data_models.models.api.jwt import *
+from neon_data_models.models.api.llm import *
+from neon_data_models.models.api.mq import *
diff --git a/neon_data_models/models/api/llm.py b/neon_data_models/models/api/llm.py
new file mode 100644
index 0000000..5b5ab58
--- /dev/null
+++ b/neon_data_models/models/api/llm.py
@@ -0,0 +1,186 @@
+# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System
+# All trademark and other rights reserved by their respective owners
+# Copyright 2008-2024 Neongecko.com Inc.
+# BSD-3
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from this
+#    software without specific prior written permission.
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from typing import List, Tuple, Optional, Literal
+from pydantic import Field, model_validator, computed_field
+
+from neon_data_models.models.base import BaseModel
+
+
+_DEFAULT_MQ_TO_ROLE = {"user": "user", "llm": "assistant"}
+
+
+class LLMPersona(BaseModel):
+    name: str = Field(description="Unique name for this persona")
+    description: Optional[str] = Field(
+        None, description="Human-readable description of this persona")
+    system_prompt: Optional[str] = Field(
+        None, description="System prompt associated with this persona. "
" + "If None, `description` will be used.") + enabled: bool = Field( + True, description="Flag used to mark a defined persona as " + "available for use.") + user_id: Optional[str] = Field( + None, description="`user_id` of the user who created this persona.") + + @model_validator(mode='after') + def validate_request(self): + assert any((self.description, self.system_prompt)) + if self.system_prompt is None: + self.system_prompt = self.description + return self + + @computed_field + @property + def id(self) -> str: + persona_id = self.name + if self.user_id: + persona_id += f"_{self.user_id}" + return persona_id + + +class LLMRequest(BaseModel): + query: str = Field(description="Incoming user prompt") + # TODO: History may support more options in the future + history: List[Tuple[Literal["user", "llm"], str]] = Field( + description="OpenAI-formatted chat history (excluding system prompt)") + persona: LLMPersona = Field( + description="Requested persona to respond to this message") + model: str = Field(description="Model to request") + max_tokens: int = Field( + default=512, ge=64, le=2048, + description="Maximum number of tokens to include in the response") + temperature: float = Field( + default=0.0, ge=0.0, le=1.0, + description="Temperature of response. 0 guarantees reproducibility, " + "higher values increase variability.") + repetition_penalty: float = Field( + default=1.0, ge=1.0, le=2.0, + description="Repetition penalty. Higher values limit repeated " + "information in responses") + stream: bool = Field( + default=None, description="Enable streaming responses. " + "Mutually exclusive with `beam_search`.") + best_of: int = Field( + default=1, + description="Number of beams to use if `beam_search` is enabled.") + beam_search: bool = Field( + default=None, description="Enable beam search. 
" + "Mutually exclusive with `stream`.") + max_history: int = Field( + default=2, description="Maximum number of user/assistant " + "message pairs to include in history context.") + + @model_validator(mode='before') + @classmethod + def validate_inputs(cls, values): + # Neon modules previously defined `user` and `llm` keys, but Open AI + # specifies `assistant` in place of `llm` and is the de-facto standard + for idx, itm in enumerate(values.get('history', [])): + if itm[0] == "assistant": + values['history'][idx] = ("llm", itm[1]) + return values + + @model_validator(mode='after') + def validate_request(self): + # If beams are specified, make sure valid `stream` and `beam_search` + # values are specified + if self.best_of > 1: + if self.stream is True: + raise ValueError("Cannot stream with a `best_of` value " + "greater than 1") + if self.beam_search is False: + raise ValueError("Cannot have a `best_of` value other than 1 " + "if `beam_search` is False") + self.stream = False + self.beam_search = True + # If streaming, beam_search must be False + if self.stream is True: + if self.beam_search is True: + raise ValueError("Cannot enable both `stream` and " + "`beam_search`") + self.beam_search = False + # If beam search is enabled, streaming must be False + if self.beam_search is True: + if self.stream is True: + raise ValueError("Cannot enable both `stream` and " + "`beam_search`") + self.stream = False + if self.stream is None and self.beam_search is None: + self.stream = True + self.beam_search = False + assert isinstance(self.stream, bool) + assert isinstance(self.beam_search, bool) + return self + + @property + def messages(self) -> List[dict]: + """ + Get chat history as a list of dict messages + """ + return [{"role": m[0], "content": m[1]} for m in self.history] + + def to_completion_kwargs(self, mq2role: dict = None) -> dict: + """ + Get kwargs to pass to an OpenAI completion request. + @param mq2role: dict mapping `llm` and `user` keys to `role` values to + use in message history. 
+ """ + mq2role = mq2role or _DEFAULT_MQ_TO_ROLE + history = self.messages[-2*self.max_history:] + for msg in history: + msg["role"] = mq2role.get(msg["role"]) or msg["role"] + history.insert(0, {"role": "system", + "content": self.persona.system_prompt}) + return {"model": self.model, + "messages": history, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "stream": self.stream, + "extra_body": {"add_special_tokens": True, + "repetition_penalty": self.repetition_penalty, + "use_beam_search": self.beam_search, + "best_of": self.best_of}} + + +class LLMResponse(BaseModel): + response: str = Field(description="LLM Response to the input query") + history: List[Tuple[Literal["user", "llm"], str]] = Field( + description="List of (role, content) tuples in chronological order " + "(`response` is in the last list element)") + + @model_validator(mode='before') + @classmethod + def validate_inputs(cls, values): + # Neon modules previously defined `user` and `llm` keys, but Open AI + # specifies `assistant` in place of `llm` and is the de-facto standard + for idx, itm in enumerate(values.get('history', [])): + if itm[0] == "assistant": + values['history'][idx] = ("llm", itm[1]) + return values + + +__all__ = [LLMPersona.__name__, LLMRequest.__name__, LLMResponse.__name__] diff --git a/neon_data_models/models/api/mq/__init__.py b/neon_data_models/models/api/mq/__init__.py new file mode 100644 index 0000000..9fafcf2 --- /dev/null +++ b/neon_data_models/models/api/mq/__init__.py @@ -0,0 +1,28 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2024 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+from neon_data_models.models.api.mq.llm import *
+from neon_data_models.models.api.mq.users import *
diff --git a/neon_data_models/models/api/mq/llm.py b/neon_data_models/models/api/mq/llm.py
new file mode 100644
index 0000000..b0b9717
--- /dev/null
+++ b/neon_data_models/models/api/mq/llm.py
@@ -0,0 +1,71 @@
+# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System
+# All trademark and other rights reserved by their respective owners
+# Copyright 2008-2024 Neongecko.com Inc.
+# BSD-3
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from this
+#    software without specific prior written permission.
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from typing import Optional, Dict, List
+from pydantic import Field
+
+from neon_data_models.models.api.llm import LLMRequest, LLMPersona
+from neon_data_models.models.base.contexts import MQContext
+
+
+class LLMProposeRequest(MQContext, LLMRequest):
+    model: Optional[str] = Field(
+        default=None,
+        description="MQ implementation defines `model` as optional because the "
+                    "queue defines the requested model in most cases.")
+    persona: Optional[LLMPersona] = Field(
+        default=None,
+        description="MQ implementation defines `persona` as an optional "
+                    "parameter, with default behavior hard-coded into each "
+                    "LLM module.")
+
+
+class LLMProposeResponse(MQContext):
+    response: str = Field(description="LLM response to the prompt")
+
+
+class LLMDiscussRequest(LLMProposeRequest):
+    options: Dict[str, str] = Field(
+        description="Mapping of participant name to response to be discussed.")
+
+
+class LLMDiscussResponse(MQContext):
+    opinion: str = Field(description="LLM response to the available options.")
+
+
+class LLMVoteRequest(LLMProposeRequest):
+    responses: List[str] = Field(
+        description="List of responses to choose from.")
+
+
+class LLMVoteResponse(MQContext):
+    sorted_answer_indexes: List[int] = Field(
+        description="Indices of `responses` ordered high to low by preference.")
+
+
+__all__ = [LLMProposeRequest.__name__, LLMProposeResponse.__name__,
+           LLMDiscussRequest.__name__, LLMDiscussResponse.__name__,
+           LLMVoteRequest.__name__, LLMVoteResponse.__name__]
diff --git a/neon_data_models/models/api/mq.py b/neon_data_models/models/api/mq/users.py
similarity index 99%
rename from neon_data_models/models/api/mq.py
rename to neon_data_models/models/api/mq/users.py
index dc1dd53..3d7b1cb 100644
--- a/neon_data_models/models/api/mq.py
+++ b/neon_data_models/models/api/mq/users.py
@@ -25,7 +25,6 @@
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from typing import Literal, Optional, Annotated, Union
-
 from pydantic import Field, TypeAdapter, model_validator
 
 from neon_data_models.models.api.jwt import HanaToken
diff --git a/neon_data_models/models/user/database.py b/neon_data_models/models/user/database.py
index 4e7e885..699c08f 100644
--- a/neon_data_models/models/user/database.py
+++ b/neon_data_models/models/user/database.py
@@ -181,11 +181,12 @@ class TokenConfig(BaseModel):
 
 
 class User(BaseModel):
-    def __init__(self, **kwargs):
+
+    @classmethod
+    def rebuild_model(cls):
         # Ensure `HanaToken` is populated from the import space
         from neon_data_models.models.api.jwt import HanaToken
-        self.model_rebuild()
-        BaseModel.__init__(self, **kwargs)
+        cls.model_rebuild()
 
     username: str
     password_hash: Optional[str] = None
diff --git a/tests/models/api/test_llm.py b/tests/models/api/test_llm.py
new file mode 100644
index 0000000..2548c27
--- /dev/null
+++ b/tests/models/api/test_llm.py
@@ -0,0 +1,153 @@
+# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System
+# All trademark and other rights reserved by their respective owners
+# Copyright 2008-2024 Neongecko.com Inc.
+# BSD-3
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from this
+#    software without specific prior written permission.
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from unittest import TestCase
+from pydantic import ValidationError
+
+
+class TestLLM(TestCase):
+    def test_llm_persona(self):
+        from neon_data_models.models.api.llm import LLMPersona
+        legacy_mq_persona = LLMPersona(name="my persona",
+                                       description="You are a helpful chatbot.")
+        self.assertEqual(legacy_mq_persona.system_prompt,
+                         legacy_mq_persona.description)
+
+        legacy_bf_persona = LLMPersona(name="neon",
+                                       system_prompt="You are NeonLLM.")
+        self.assertIsNone(legacy_bf_persona.description)
+        self.assertEqual(legacy_bf_persona.system_prompt,
+                         "You are NeonLLM.")
+
+        full_persona = LLMPersona(name="custom chatbot",
+                                  description="A customized bot for something",
+                                  system_prompt="You are a custom chatbot.")
+        self.assertEqual(full_persona.system_prompt,
+                         "You are a custom chatbot.")
+        self.assertEqual(full_persona.description,
+                         "A customized bot for something")
+
+        with self.assertRaises(ValidationError):
+            LLMPersona(name="underdefined persona")
+
+    def test_llm_request(self):
+        from neon_data_models.models.api.llm import LLMRequest, LLMPersona
+        test_query = "how are you?"
+ test_history = [("user", "hello"), + ("assistant", "Hi, how can I help you today?"), + ("user", "I am well, how about you?"), + ("assistant", "As a large language model, I do not feel")] + test_persona = {"name": "Test Bot", + "system_prompt": "This is the system prompt."} + test_model = "my_model_spec" + # Minimal definition + valid_request = LLMRequest(query=test_query, history=test_history, + persona=test_persona, model=test_model) + self.assertIsInstance(valid_request.persona, LLMPersona) + self.assertTrue(valid_request.stream) + self.assertFalse(valid_request.beam_search) + self.assertEqual(len(valid_request.history), len(test_history)) + self.assertEqual(len(valid_request.to_completion_kwargs()['messages']), + 2 * valid_request.max_history + 1) + + # Valid explicit streaming + streaming_request = LLMRequest(query=test_query, history=test_history, + persona=test_persona, model=test_model, + stream=True) + self.assertEqual(streaming_request, valid_request) + + # Valid explicit beam search + beam_search_request = LLMRequest(query=test_query, history=test_history, + persona=test_persona, model=test_model, + beam_search=True) + self.assertTrue(beam_search_request.beam_search) + self.assertFalse(beam_search_request.stream) + + # Valid best_of, implied beam search + best_of_request = LLMRequest(query=test_query, history=test_history, + persona=test_persona, model=test_model, + best_of=3) + self.assertTrue(best_of_request.beam_search) + self.assertFalse(best_of_request.stream) + + # Valid explicitly disable streaming and beam search + valid_no_stream = LLMRequest(query=test_query, history=test_history, + persona=test_persona, model=test_model, + stream=False, beam_search=False) + self.assertFalse(valid_no_stream.beam_search) + self.assertFalse(valid_no_stream.stream) + + # Validate `llm` history input + old_history = [("user", "hello"), + ("llm", "Hi, how can I help you today?"), + ("user", "I am well, how about you?"), + ("llm", "As a large language model, I do not feel")] + validated = LLMRequest(query=test_query, history=old_history, + persona=test_persona, model=test_model) + self.assertEqual(validated.history, test_history) + + # Invalid streaming with beam search + with self.assertRaises(ValidationError): + LLMRequest(query=test_query, history=test_history, + persona=test_persona, model=test_model, stream=True, + beam_search=True) + # Invalid streaming with best_of > 1 + with self.assertRaises(ValidationError): + LLMRequest(query=test_query, history=test_history, + persona=test_persona, model=test_model, stream=True, + best_of=2) + + # Invalid history + test_history.append(("invalid_key", "okay")) + with self.assertRaises(ValidationError): + LLMRequest(query=test_query, history=test_history, + persona=test_persona, model=test_model) + test_history.pop() + + def test_llm_response(self): + from neon_data_models.models.api.llm import LLMResponse + valid_response = "hello" + valid_history = [("user", "hello"), ("assistant", "How can I help?")] + legacy_history = [("user", "hello"), ("llm", "How can I help?")] + + # Valid response with valid history + response = LLMResponse(response=valid_response, history=valid_history) + self.assertEqual(response.response, valid_response) + self.assertEqual(response.history, valid_history) + + # Valid response with legacy history + response = LLMResponse(response=valid_response, history=legacy_history) + self.assertEqual(response.response, valid_response) + self.assertEqual(response.history, valid_history) + + # Invalid response + with 
+            LLMResponse(response=None, history=valid_history)
+
+        # Invalid history
+        valid_history.append(("invalid", "response"))
+        with self.assertRaises(ValidationError):
+            LLMResponse(response=valid_response, history=valid_history)
diff --git a/tests/models/api/test_mq.py b/tests/models/api/test_mq.py
index 4768ad9..fc97716 100644
--- a/tests/models/api/test_mq.py
+++ b/tests/models/api/test_mq.py
@@ -27,12 +27,10 @@
 from unittest import TestCase
 from pydantic import ValidationError
 
-from neon_data_models.models.api.mq import UserDbRequest
-
 
 class TestMQ(TestCase):
     def test_create_user_db_request(self):
-        from neon_data_models.models.api.mq import CreateUserRequest
+        from neon_data_models.models.api.mq.users import UserDbRequest, CreateUserRequest
         # Test create user valid
         valid_kwargs = {"message_id": "test_id",
                         "operation": "create",
@@ -50,7 +48,7 @@ def test_create_user_db_request(self):
                           message_id="test0")
 
     def test_read_user_db_request(self):
-        from neon_data_models.models.api.mq import ReadUserRequest
+        from neon_data_models.models.api.mq.users import UserDbRequest, ReadUserRequest
         # Test read user valid
         valid_kwargs = {"message_id": "test_id",
                         "operation": "read",
@@ -68,7 +66,7 @@
                           message_id="test0")
 
     def test_update_user_db_request(self):
-        from neon_data_models.models.api.mq import UpdateUserRequest
+        from neon_data_models.models.api.mq.users import UserDbRequest, UpdateUserRequest
         # Test update user valid
         valid_kwargs = {"message_id": "test_id",
                         "operation": "update",
@@ -95,7 +93,8 @@ def test_update_user_db_request(self):
         update = UpdateUserRequest(message_id="test_id", operation="update",
                                    user={"username": "user",
                                          "password_hash": "password"},
-                                   auth_username="admin", auth_password="admin_pass")
+                                   auth_username="admin",
+                                   auth_password="admin_pass")
         self.assertEqual(update.user.username, "user")
         self.assertEqual(update.user.password_hash, "password")
 
@@ -110,7 +109,7 @@ def test_update_user_db_request(self):
                           message_id="test0")
 
     def test_delete_user_db_request(self):
-        from neon_data_models.models.api.mq import DeleteUserRequest
+        from neon_data_models.models.api.mq.users import UserDbRequest, DeleteUserRequest
         # Test delete user valid
         valid_kwargs = {"message_id": "test_id",
                         "operation": "delete",
@@ -126,3 +125,152 @@
         with self.assertRaises(ValidationError):
             UserDbRequest(operation="delete", username="test_user",
                           message_id="test0")
+
+    def test_mq_llm_propose_request(self):
+        from neon_data_models.models.api.mq.llm import LLMProposeRequest
+        from neon_data_models.models.api.llm import LLMRequest
+        from neon_data_models.models.base.contexts import MQContext
+
+        query = "who are you"
+        history = []
+        model_name = "test_model"
+        persona = {"name": "test_persona", "system_prompt": "Test prompt."}
+        message_id = "test_mid"
+
+        # Valid fully-defined
+        valid_request = LLMProposeRequest(query=query, history=history,
+                                          persona=persona, model=model_name,
+                                          message_id=message_id)
+        self.assertIsInstance(valid_request, LLMProposeRequest)
+        self.assertIsInstance(valid_request, LLMRequest)
+        self.assertIsInstance(valid_request, MQContext)
+
+        # Valid backwards-compat (no model or persona)
+        backwards_compat = LLMProposeRequest(query=query, history=history,
+                                             message_id=message_id)
+        self.assertIsInstance(backwards_compat, LLMProposeRequest)
+        self.assertIsInstance(backwards_compat, LLMRequest)
+        self.assertIsInstance(backwards_compat, MQContext)
+        self.assertIsNone(backwards_compat.model)
+        self.assertIsNone(backwards_compat.persona)
+
+        # Invalid Persona defined
+        with self.assertRaises(ValidationError):
+            LLMProposeRequest(query=query, history=history,
+                              message_id=message_id, persona={})
+
+        # Invalid MQ Context
+        with self.assertRaises(ValidationError):
+            LLMProposeRequest(query=query, history=history)
+
+        # Invalid LLM Request
+        with self.assertRaises(ValidationError):
+            LLMProposeRequest(history=history, message_id=message_id)
+
+    def test_mq_llm_propose_response(self):
+        from neon_data_models.models.api.mq.llm import LLMProposeResponse
+
+        # Valid response
+        self.assertIsInstance(LLMProposeResponse(response="test response",
+                                                 message_id=""),
+                              LLMProposeResponse)
+
+        # Missing MQ required data
+        with self.assertRaises(ValidationError):
+            LLMProposeResponse(response="test response")
+
+        # Missing response required data
+        with self.assertRaises(ValidationError):
+            LLMProposeResponse(message_id="")
+
+    def test_mq_llm_discuss_request(self):
+        from neon_data_models.models.api.mq.llm import LLMDiscussRequest
+        query = "who are you"
+        history = []
+        message_id = "test_mid"
+        opts = {"bot 1": "resp 1", "bot 2": "resp 2"}
+        invalid_opts = {"bot 1": "resp 1", "bot 2": None}
+
+        # Valid request
+        valid_request = LLMDiscussRequest(query=query, history=history,
+                                          message_id=message_id, options=opts)
+        self.assertIsInstance(valid_request, LLMDiscussRequest)
+
+        # Invalid options
+        with self.assertRaises(ValidationError):
+            LLMDiscussRequest(query=query, history=history,
+                              message_id=message_id, options=invalid_opts)
+
+        # Invalid MQ Context
+        with self.assertRaises(ValidationError):
+            LLMDiscussRequest(query=query, history=history, options=opts)
+
+        # Invalid LLM Request
+        with self.assertRaises(ValidationError):
+            LLMDiscussRequest(query=query, message_id=message_id, options=opts)
+
+    def test_mq_llm_discuss_response(self):
+        from neon_data_models.models.api.mq.llm import LLMDiscussResponse
+
+        # Valid response
+        self.assertIsInstance(LLMDiscussResponse(opinion="test opinion",
+                                                 message_id=""),
+                              LLMDiscussResponse)
+
+        # Missing MQ required data
+        with self.assertRaises(ValidationError):
+            LLMDiscussResponse(opinion="test opinion")
+
+        # Missing response required data
+        with self.assertRaises(ValidationError):
+            LLMDiscussResponse(message_id="")
+
+    def test_mq_llm_vote_request(self):
+        from neon_data_models.models.api.mq.llm import LLMVoteRequest
+        query = "who are you"
+        history = []
+        message_id = "test_mid"
+        responses = ["resp 1", "resp 2"]
+        invalid_responses = ["resp 1", "resp 2", None]
+
+        # Valid request
+        valid_request = LLMVoteRequest(query=query, history=history,
+                                       message_id=message_id,
+                                       responses=responses)
+        self.assertIsInstance(valid_request, LLMVoteRequest)
+
+        # Invalid options
+        with self.assertRaises(ValidationError):
+            LLMVoteRequest(query=query, history=history, message_id=message_id,
+                           responses=invalid_responses)
+
+        # Invalid MQ Context
+        with self.assertRaises(ValidationError):
+            LLMVoteRequest(query=query, history=history, responses=responses)
+
+        # Invalid LLM Request
+        with self.assertRaises(ValidationError):
+            LLMVoteRequest(query=query, message_id=message_id,
+                           responses=responses)
+
+    def test_mq_llm_vote_response(self):
+        from neon_data_models.models.api.mq.llm import LLMVoteResponse
+
+        # Valid response
+        self.assertIsInstance(LLMVoteResponse(sorted_answer_indexes=[2, 0, 1],
+                                              message_id=""),
+                              LLMVoteResponse)
+
+        # Missing MQ required data
+        with self.assertRaises(ValidationError):
+            LLMVoteResponse(sorted_answer_indexes=[2, 0, 1])
+
+        # Missing response required data
+        with self.assertRaises(ValidationError):
+            LLMVoteResponse(message_id="")
+
+        # Invalid response data
+        with self.assertRaises(ValidationError):
+            LLMVoteResponse(sorted_answer_indexes=[2, 0, 1, "invalid"],
+                            message_id="")