Skip to content

Commit

Permalink
Use a less specific return type for | on Runnables (#11762)
Browse files Browse the repository at this point in the history
<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
nfcampos and baskaryan authored Oct 15, 2023
1 parent 6c5bb1b commit 4321d19
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
4 changes: 2 additions & 2 deletions libs/langchain/langchain/chains/sql_database/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output_parser import NoOpOutputParser
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema.runnable import RunnableParallel, RunnableSequence
from langchain.schema.runnable import Runnable, RunnableParallel
from langchain.utilities.sql_database import SQLDatabase


Expand All @@ -30,7 +30,7 @@ def create_sql_query_chain(
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> RunnableSequence[Union[SQLInput, SQLInputWithTables], str]:
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
"""Create a chain that generates SQL queries.
Args:
Expand Down
8 changes: 4 additions & 4 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __or__(
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
) -> Runnable[Input, Other]:
"""Compose this runnable with another object to create a RunnableSequence."""
return RunnableSequence(first=self, last=coerce_to_runnable(other))

Expand All @@ -254,7 +254,7 @@ def __ror__(
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
) -> Runnable[Other, Output]:
"""Compose this runnable with another object to create a RunnableSequence."""
return RunnableSequence(first=coerce_to_runnable(other), last=self)

Expand Down Expand Up @@ -1064,7 +1064,7 @@ def __or__(
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
) -> Runnable[Input, Other]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=self.first,
Expand All @@ -1086,7 +1086,7 @@ def __ror__(
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
) -> Runnable[Other, Output]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=other.first,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from langchain.schema.runnable import (
GetLocalVar,
PutLocalVar,
Runnable,
RunnablePassthrough,
RunnableSequence,
)
Expand Down Expand Up @@ -52,12 +53,12 @@ def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) ->


def test_get_in_map() -> None:
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
runnable: Runnable = PutLocalVar("input") | {"bar": GetLocalVar("input")}
assert runnable.invoke("foo") == {"bar": "foo"}


def test_put_in_map() -> None:
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
runnable: Runnable = {"bar": PutLocalVar("input")} | GetLocalVar("input")
with pytest.raises(KeyError):
runnable.invoke("foo")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1978,7 +1978,7 @@ def test_combining_sequences(
lambda x: {"question": x[0] + x[1]}
)

chain2 = input_formatter | prompt2 | chat2 | parser2
chain2 = cast(RunnableSequence, input_formatter | prompt2 | chat2 | parser2)

assert isinstance(chain, RunnableSequence)
assert chain2.first == input_formatter
Expand All @@ -1987,7 +1987,7 @@ def test_combining_sequences(
if sys.version_info >= (3, 9):
assert dumps(chain2, pretty=True) == snapshot

combined_chain = chain | chain2
combined_chain = cast(RunnableSequence, chain | chain2)

assert combined_chain.first == prompt
assert combined_chain.middle == [
Expand Down Expand Up @@ -2972,7 +2972,7 @@ def llm_with_multi_fallbacks() -> RunnableWithFallbacks:


@pytest.fixture()
def llm_chain_with_fallbacks() -> RunnableSequence:
def llm_chain_with_fallbacks() -> Runnable:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])

Expand Down

0 comments on commit 4321d19

Please sign in to comment.