
Commit c5f7d65

laipz8200 and iamjoel authored
feat: Allow using file variables directly in the LLM node and support more file types. (langgenius#10679)
Co-authored-by: Joel <iamjoel007@gmail.com>
1 parent 535c72c commit c5f7d65

File tree

36 files changed: +1036 −268 lines changed

api/configs/app_config.py

−1

@@ -27,7 +27,6 @@ class DifyConfig(
         # read from dotenv format config file
         env_file=".env",
         env_file_encoding="utf-8",
-        frozen=True,
         # ignore extra attributes
         extra="ignore",
     )
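
Context for the change above (illustrative, not part of the commit): with Pydantic v2 settings, frozen=True makes the config object immutable, so any later attribute assignment raises. Dropping the flag presumably allows the config to be mutated after load. A minimal sketch, assuming pydantic-settings v2:

    from pydantic import ValidationError
    from pydantic_settings import BaseSettings, SettingsConfigDict

    class FrozenSettings(BaseSettings):
        # Same pattern app_config.py used before this commit.
        model_config = SettingsConfigDict(frozen=True)
        debug: bool = False

    settings = FrozenSettings()
    try:
        settings.debug = True  # frozen=True turns assignment into an error
    except ValidationError as exc:
        print(exc)  # reports a "frozen_instance" error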

api/core/app/app_config/easy_ui_based_app/model_config/converter.py

+19 −23

@@ -11,7 +11,7 @@
 
 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -38,27 +38,23 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) ->
         )
 
         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
-
-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
-
-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
-
-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
         completion_params = model_config.parameters
@@ -76,7 +72,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) ->
 
         model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
 
-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")
 
         return ModelConfigWithCredentialsEntity(
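
With skip_check removed, convert() always validates credentials, model existence, and quota. A hedged sketch of the resulting call-site contract (exception names come from the diff; the surrounding handler is hypothetical):

    # Hypothetical caller: app_config comes from an EasyUI-based app.
    try:
        model_config = ModelConfigConverter.convert(app_config)
    except ProviderTokenNotInitError:
        # Previously reachable only when skip_check was False; now the only
        # behavior. Callers that passed skip_check=True must handle (or
        # avoid) these errors instead of receiving empty credentials.
        raise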

api/core/app/task_pipeline/workflow_cycle_manage.py

+4 −1

@@ -217,9 +217,12 @@ def _handle_workflow_run_failed(
         ).total_seconds()
         db.session.commit()
 
-        db.session.refresh(workflow_run)
         db.session.close()
 
+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)
+
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
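
The refresh moves out of the scoped db.session into a short-lived session created with expire_on_commit=False, so workflow_run's attributes stay loaded after the session exits and the object remains readable while detached (for example by the trace task below). A minimal sketch of the pattern, assuming a SQLAlchemy 2.x engine and an existing workflow_run instance:

    from sqlalchemy.orm import Session

    # expire_on_commit=False: attributes are not expired when the session
    # commits or closes, so the instance stays populated once detached.
    with Session(engine, expire_on_commit=False) as session:
        session.add(workflow_run)      # attach the existing instance
        session.refresh(workflow_run)  # re-read current DB state

    print(workflow_run.status)  # safe to read outside the session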

api/core/file/file_manager.py

+25 −44

@@ -3,7 +3,12 @@
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 
@@ -29,43 +34,25 @@ def get_attr(*, file: File, attr: FileAttribute):
             return file.remote_url
         case FileAttribute.EXTENSION:
             return file.extension
-        case _:
-            raise ValueError(f"Invalid file attribute: {attr}")
 
 
 def to_prompt_message_content(
     f: File,
     /,
     *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
-    """
     match f.type:
         case FileType.IMAGE:
+            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
             if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)
 
             return ImagePromptMessageContent(data=data, detail=image_detail_config)
         case FileType.AUDIO:
-            encoded_string = _file_to_encoded_string(f)
+            encoded_string = _get_encoded_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@@ -74,9 +61,20 @@ def to_prompt_message_content(
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
+        case FileType.DOCUMENT:
+            data = _get_encoded_string(f)
+            if f.mime_type is None:
+                raise ValueError("Missing file mime_type")
+            return DocumentPromptMessageContent(
+                encode_format="base64",
+                mime_type=f.mime_type,
+                data=data,
+            )
         case _:
-            raise ValueError("file type f.type is not supported")
+            raise ValueError(f"file type {f.type} is not supported")
 
 
 def download(f: File, /):
@@ -118,40 +116,23 @@ def _get_encoded_string(f: File, /):
         case FileTransferMethod.REMOTE_URL:
            response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
            response.raise_for_status()
-            content = response.content
-            encoded_string = base64.b64encode(content).decode("utf-8")
-            return encoded_string
+            data = response.content
         case FileTransferMethod.LOCAL_FILE:
             upload_file = file_repository.get_upload_file(session=db.session(), file=f)
             data = _download_file_content(upload_file.key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
         case FileTransferMethod.TOOL_FILE:
             tool_file = file_repository.get_tool_file(session=db.session(), file=f)
             data = _download_file_content(tool_file.file_key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
-        case _:
-            raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
+
+    encoded_string = base64.b64encode(data).decode("utf-8")
+    return encoded_string
 
 
 def _to_base64_data_string(f: File, /):
     encoded_string = _get_encoded_string(f)
     return f"data:{f.mime_type};base64,{encoded_string}"
 
 
-def _file_to_encoded_string(f: File, /):
-    match f.type:
-        case FileType.IMAGE:
-            return _to_base64_data_string(f)
-        case FileType.VIDEO:
-            return _to_base64_data_string(f)
-        case FileType.AUDIO:
-            return _get_encoded_string(f)
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")
-
-
 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:
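
After this change, to_prompt_message_content dispatches on all four media types and raises on anything else. A hedged usage sketch; the File constructor fields shown are assumptions for illustration, not the real signature:

    # Hypothetical file object; real File construction goes through
    # Dify's upload pipeline.
    pdf = File(
        type=FileType.DOCUMENT,
        transfer_method=FileTransferMethod.LOCAL_FILE,
        mime_type="application/pdf",
        # ... remaining required fields elided ...
    )

    content = to_prompt_message_content(pdf)
    # -> DocumentPromptMessageContent(encode_format="base64",
    #        mime_type="application/pdf", data="<base64 payload>")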

api/core/memory/token_buffer_memory.py

+2 −1

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import Optional
 
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -27,7 +28,7 @@ def __init__(self, conversation: Conversation, model_instance: ModelInstance) ->
 
     def get_history_prompt_messages(
         self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit

api/core/model_manager.py

+2 −2

@@ -100,10 +100,10 @@ def _get_load_balancing_manager(
 
     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
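
Widening list[...] to Sequence[...] in these signatures is a variance fix, not a behavior change: list is invariant, so a list of a PromptMessage subclass would be rejected where list[PromptMessage] is expected, while Sequence is covariant and read-only. A short self-contained illustration:

    from collections.abc import Sequence

    class Message: ...
    class UserMessage(Message): ...

    def old_api(items: list[Message]) -> None: ...
    def new_api(items: Sequence[Message]) -> None: ...

    user_messages: list[UserMessage] = [UserMessage()]
    # old_api(user_messages)  # type checkers reject: list is invariant
    new_api(user_messages)    # accepted: Sequence is covariant
    new_api((UserMessage(),)) # tuples also satisfy Sequence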

api/core/model_runtime/callbacks/base_callback.py

+5 −4

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import Optional
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -31,7 +32,7 @@ def on_before_invoke(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -60,7 +61,7 @@ def on_new_chunk(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
@@ -90,7 +91,7 @@ def on_after_invoke(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -120,7 +121,7 @@ def on_invoke_error(
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:

api/core/model_runtime/entities/__init__.py

+2

@@ -2,6 +2,7 @@
 from .message_entities import (
     AssistantPromptMessage,
     AudioPromptMessageContent,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
@@ -37,4 +38,5 @@
     "LLMResultChunk",
     "LLMResultChunkDelta",
     "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
 ]

api/core/model_runtime/entities/message_entities.py

+11 −2

@@ -1,6 +1,7 @@
 from abc import ABC
+from collections.abc import Sequence
 from enum import Enum
-from typing import Optional
+from typing import Literal, Optional
 
 from pydantic import BaseModel, Field, field_validator
 
@@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
     IMAGE = "image"
     AUDIO = "audio"
     VIDEO = "video"
+    DOCUMENT = "document"
 
 
 class PromptMessageContent(BaseModel):
@@ -101,13 +103,20 @@ class DETAIL(str, Enum):
     detail: DETAIL = DETAIL.LOW
 
 
+class DocumentPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    encode_format: Literal["base64"]
+    mime_type: str
+    data: str
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
     """
 
     role: PromptMessageRole
-    content: Optional[str | list[PromptMessageContent]] = None
+    content: Optional[str | Sequence[PromptMessageContent]] = None
     name: Optional[str] = None
 
     def is_empty(self) -> bool:
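
A minimal sketch constructing the new content type (the payload is a placeholder; field names come straight from the class above):

    import base64

    doc = DocumentPromptMessageContent(
        encode_format="base64",
        mime_type="text/plain",
        data=base64.b64encode(b"file variable contents").decode("utf-8"),
    )
    assert doc.type == PromptMessageContentType.DOCUMENT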

api/core/model_runtime/entities/model_entities.py

+3

@@ -87,6 +87,9 @@ class ModelFeature(Enum):
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
     STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"
 
 
 class DefaultParameterName(str, Enum):
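
One plausible way these new feature flags could gate which file types reach a model; this mapping is an assumption for illustration, not code from this commit:

    # Hypothetical helper: map declared model features to accepted inputs.
    FEATURE_TO_FILE_TYPE = {
        ModelFeature.VISION: FileType.IMAGE,
        ModelFeature.DOCUMENT: FileType.DOCUMENT,
        ModelFeature.VIDEO: FileType.VIDEO,
        ModelFeature.AUDIO: FileType.AUDIO,
    }

    def allowed_file_types(features: set[ModelFeature]) -> set[FileType]:
        return {ft for feat, ft in FEATURE_TO_FILE_TYPE.items() if feat in features}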
