Skip to content

Commit

Permalink
Update chat model output type (langchain-ai#11833)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
nfcampos and baskaryan authored Oct 19, 2023
1 parent ed62984 commit 7db6aab
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 33 deletions.
52 changes: 23 additions & 29 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
List,
Optional,
Sequence,
Union,
cast,
)

Expand All @@ -38,12 +37,10 @@
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import (
AIMessage,
AnyMessage,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
Expand Down Expand Up @@ -79,7 +76,7 @@ async def _agenerate_from_stream(
return ChatResult(generations=[generation])


class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for Chat models."""

cache: Optional[bool] = None
Expand Down Expand Up @@ -116,9 +113,7 @@ class Config:
@property
def OutputType(self) -> Any:
"""Get the output type for this runnable."""
return Union[
HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
]
return AnyMessage

def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
Expand All @@ -140,23 +135,20 @@ def invoke(
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
) -> BaseMessage:
config = config or {}
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message,
)
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message

async def ainvoke(
self,
Expand All @@ -165,7 +157,7 @@ async def ainvoke(
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
) -> BaseMessage:
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
Expand All @@ -176,9 +168,7 @@ async def ainvoke(
run_name=config.get("run_name"),
**kwargs,
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
)
return cast(ChatGeneration, llm_result.generations[0][0]).message

def stream(
self,
Expand All @@ -190,7 +180,9 @@ def stream(
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
Expand Down Expand Up @@ -241,7 +233,9 @@ async def astream(
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream == BaseChatModel._astream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2163,19 +2163,19 @@
dict({
'anyOf': list([
dict({
'$ref': '#/definitions/HumanMessage',
'$ref': '#/definitions/AIMessage',
}),
dict({
'$ref': '#/definitions/AIMessage',
'$ref': '#/definitions/HumanMessage',
}),
dict({
'$ref': '#/definitions/ChatMessage',
}),
dict({
'$ref': '#/definitions/FunctionMessage',
'$ref': '#/definitions/SystemMessage',
}),
dict({
'$ref': '#/definitions/SystemMessage',
'$ref': '#/definitions/FunctionMessage',
}),
]),
'definitions': dict({
Expand Down

0 comments on commit 7db6aab

Please sign in to comment.