Skip to content

Commit

Permalink
refactor(api): streamline message file handling
Browse files Browse the repository at this point in the history
- Removed redundant 'attribute' arguments from 'message_files' fields for cleaner schema definitions.
- Updated logic to retrieve message files directly from the database, enhancing data consistency.
- Enhanced type annotations for clarity and better code maintainability.
- Improved file handling logic within `Message` class properties, ensuring correct URL generation.
- Replaced `.value` extraction with direct use of `FileType` and `FileTransferMethod` enum members to simplify the codebase.
- Broadened type annotations from `list[dict]` to `Sequence[Mapping[str, Any]]` for improved flexibility in file handling.
- Added constructor to `MessageFile` class for better encapsulation and object initialization.
- Removed unused imports and redundant code from various unit tests for improved efficiency.
  • Loading branch information
laipz8200 committed Oct 9, 2024
1 parent 0b29520 commit bfd0022
Show file tree
Hide file tree
Showing 18 changed files with 160 additions and 134 deletions.
1 change: 0 additions & 1 deletion api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class SecurityConfig(BaseSettings):
description="Secret key for secure session cookie signing."
"Make sure you are changing this key for your deployment with a strong key."
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
default=None,
)

RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/explore/saved_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
Expand Down
4 changes: 2 additions & 2 deletions api/controllers/service_api/app/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MessageListApi(Resource):
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"message_files": fields.List(fields.String, attribute="files"),
"message_files": fields.List(fields.String),
}

message_fields = {
Expand All @@ -58,7 +58,7 @@ class MessageListApi(Resource):
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/web/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class MessageListApi(WebApiResource):
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/web/saved_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
Expand Down
4 changes: 2 additions & 2 deletions api/core/agent/base_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -495,7 +495,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
return result

def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = message.message_files
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

Expand Down
23 changes: 18 additions & 5 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
Expand Down Expand Up @@ -53,8 +54,8 @@
from enums.workflow_nodes import NodeType
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowRunStatus,
Expand Down Expand Up @@ -494,17 +495,29 @@ def _process_stream_response(
self._conversation_name_generate_thread.join()

def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._refetch_message()

self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message_files = [
MessageFile(
message_id=self._message.id,
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
upload_file_id=file["related_id"],
created_by_role="account"
if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else "end_user",
created_by=self._message.from_account_id or self._message.from_end_user_id or "",
)
for file in self._recorded_files
]
db.session.add_all(message_files)

if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
Expand Down
2 changes: 1 addition & 1 deletion api/core/app/apps/completion/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def generate_more_like_this(
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=message.files,
mappings=message.message_files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
Expand Down
6 changes: 3 additions & 3 deletions api/core/app/apps/message_based_app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,13 @@ def _init_generate_records(
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id,
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)
db.session.commit()
Expand Down
4 changes: 2 additions & 2 deletions api/core/app/entities/task_entities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional

Expand Down Expand Up @@ -120,7 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = {}
files: Optional[list[Mapping[str, Any]]] = None
files: Optional[Sequence[Mapping[str, Any]]] = None


class MessageFileStreamResponse(StreamResponse):
Expand Down
20 changes: 10 additions & 10 deletions api/core/app/task_pipeline/workflow_cycle_manage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Any, Optional, Union, cast

Expand Down Expand Up @@ -607,7 +607,7 @@ def _workflow_iteration_completed_to_stream_response(
),
)

def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
Expand All @@ -624,7 +624,7 @@ def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:

return files

def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
Expand All @@ -636,17 +636,17 @@ def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dic
files = []
if isinstance(value, list):
for item in value:
file_var = self._get_file_var_from_value(item)
if file_var:
files.append(file_var)
file = self._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(value, dict):
file_var = self._get_file_var_from_value(value)
if file_var:
files.append(file_var)
file = self._get_file_var_from_value(value)
if file:
files.append(file)

return files

def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, str | int | None] | None:
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
Expand Down
8 changes: 4 additions & 4 deletions api/core/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Optional, Union, cast

from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
Expand Down Expand Up @@ -274,7 +274,7 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str
user=user,
)

def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
"""
Invoke large language tts model
Expand All @@ -298,7 +298,7 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option
voice=voice,
)

def _round_robin_invoke(self, function: Callable, *args, **kwargs):
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
"""
Round-robin invoke
:param function: function to invoke
Expand Down
75 changes: 47 additions & 28 deletions api/core/model_runtime/model_providers/__base/tts_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
import re
from abc import abstractmethod
from typing import Optional
from collections.abc import Iterable
from typing import Any, Optional

from pydantic import ConfigDict

Expand All @@ -22,8 +23,14 @@ class TTSModel(AIModel):
model_config = ConfigDict(protected_namespaces=())

def invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
):
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> Iterable[bytes]:
"""
Invoke large language model
Expand All @@ -50,8 +57,14 @@ def invoke(

@abstractmethod
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
):
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> Iterable[bytes]:
"""
Invoke large language model
Expand All @@ -68,27 +81,27 @@ def _invoke(

def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
"""
Get voice for given tts model voices
Retrieves the list of voices supported by a given text-to-speech (TTS) model.
:param language: tts language
:param model: model name
:param credentials: model credentials
:return: voices lists
:param language: The language for which the voices are requested.
:param model: The name of the TTS model.
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
model_schema = self.get_model_schema(model, credentials)

if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
if language:
return [
{"name": d["name"], "value": d["mode"]}
for d in voices
if language and language in d.get("language")
]
else:
return [{"name": d["name"], "value": d["mode"]} for d in voices]

def _get_model_default_voice(self, model: str, credentials: dict) -> any:
if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
raise ValueError("this model does not support voice")

voices = model_schema.model_properties[ModelPropertyKey.VOICES]
if language:
return [
{"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language")
]
else:
return [{"name": d["name"], "value": d["mode"]} for d in voices]

def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
"""
Get voice for given tts model
Expand All @@ -111,8 +124,10 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str:
"""
model_schema = self.get_model_schema(model, credentials)

if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
raise ValueError("this model does not support audio type")

return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]

def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""
Expand All @@ -121,8 +136,10 @@ def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""
model_schema = self.get_model_schema(model, credentials)

if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
raise ValueError("this model does not support word limit")

return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]

def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""
Expand All @@ -131,8 +148,10 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""
model_schema = self.get_model_schema(model, credentials)

if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
raise ValueError("this model does not support max workers")

return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]

@staticmethod
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
Expand Down
Loading

0 comments on commit bfd0022

Please sign in to comment.