From f0adc03de309c9549272043b33ac2bdcf2c1f510 Mon Sep 17 00:00:00 2001 From: Richard Kuo Date: Mon, 28 Oct 2024 17:22:21 -0700 Subject: [PATCH] backport structured json fix --- ...533f0_make_last_attempt_status_nullable.py | 6 +++ backend/danswer/llm/answering/answer.py | 6 ++- .../tests/dev_apis/test_simple_chat_api.py | 47 +++++++++---------- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py b/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py index a6938e365c6..db7b330c3e0 100644 --- a/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py +++ b/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py @@ -31,6 +31,12 @@ def upgrade() -> None: def downgrade() -> None: + # First, update any null values to a default value + op.execute( + "UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL" + ) + + # Then, make the column non-nullable op.alter_column( "connector_credential_pair", "last_attempt_status", diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 4648e0fe821..d2aeb1b14c4 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -237,6 +237,7 @@ def _raw_output_for_explicit_tool_calling_llms( prompt=prompt, tools=final_tool_definitions if final_tool_definitions else None, tool_choice="required" if self.force_use_tool.force_use else None, + structured_response_format=self.answer_style_config.structured_response_format, ): if isinstance(message, AIMessageChunk) and ( message.tool_call_chunks or message.tool_calls @@ -331,7 +332,10 @@ def _process_llm_stream( tool_choice: ToolChoiceOptions | None = None, ) -> Iterator[str | StreamStopInfo]: for message in self.llm.stream( - prompt=prompt, tools=tools, tool_choice=tool_choice + prompt=prompt, + tools=tools, + tool_choice=tool_choice, + structured_response_format=self.answer_style_config.structured_response_format, ): if isinstance(message, AIMessageChunk): if message.content: diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index fd7db7098bd..c37d1a6235d 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -154,42 +154,38 @@ def test_send_message_simple_with_history_strict_json( new_admin_user: DATestUser | None, ) -> None: # create connectors - cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch( - user_performing_action=new_admin_user, - ) - api_key: DATestAPIKey = APIKeyManager.create( - user_performing_action=new_admin_user, - ) LLMProviderManager.create(user_performing_action=new_admin_user) - cc_pair_1.documents = DocumentManager.seed_dummy_docs( - cc_pair=cc_pair_1, - num_docs=NUM_DOCS, - api_key=api_key, - ) response = requests.post( f"{API_SERVER_URL}/chat/send-message-simple-with-history", json={ + # intentionally not relevant prompt to ensure that the + # structured response format is actually used "messages": [ { - "message": "List the names of the first three US presidents in JSON format", + "message": "What is green?", "role": MessageType.USER.value, } ], "persona_id": 0, "prompt_id": 0, "structured_response_format": { - "type": "json_object", - "schema": { - "type": "object", - "properties": { - "presidents": { - "type": "array", - "items": {"type": "string"}, - "description": "List of the first three US presidents", - } + "type": "json_schema", + "json_schema": { + "name": "presidents", + "schema": { + "type": "object", + "properties": { + "presidents": { + "type": "array", + "items": {"type": "string"}, + "description": "List of the first three US presidents", + } + }, + "required": ["presidents"], + "additionalProperties": False, }, - "required": ["presidents"], + "strict": True, }, }, }, @@ -211,14 +207,17 @@ def clean_json_string(json_string: str) -> str: try: clean_answer = clean_json_string(response_json["answer"]) parsed_answer = json.loads(clean_answer) + + # NOTE: do not check content, just the structure assert isinstance(parsed_answer, dict) assert "presidents" in parsed_answer assert isinstance(parsed_answer["presidents"], list) - assert len(parsed_answer["presidents"]) == 3 for president in parsed_answer["presidents"]: assert isinstance(president, str) except json.JSONDecodeError: - assert False, "The answer is not a valid JSON object" + assert ( + False + ), f"The answer is not a valid JSON object - '{response_json['answer']}'" # Check that the answer_citationless is also valid JSON assert "answer_citationless" in response_json