Skip to content

Commit

Permalink
Cache calls to create_model for get_input_schema and get_output_schema (#17755)

Browse files Browse the repository at this point in the history

Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
  • Loading branch information
nfcampos authored Feb 19, 2024
1 parent 5ed16ad commit 07ee41d
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 37 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,6 @@ docs/docs/build
docs/docs/node_modules
docs/docs/yarn.lock
_dist
docs/docs/templates
docs/docs/templates

prof
3 changes: 3 additions & 0 deletions libs/core/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ tests:
test_watch:
poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests

test_profile:
poetry run pytest -vv tests/unit_tests/ --profile-svg

check_imports: $(shell find langchain_core -name '*.py')
poetry run python ./scripts/check_imports.py $^

Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
PromptValue,
StringPromptValue,
)
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.runnables.utils import create_model

if TYPE_CHECKING:
from langchain_core.documents import Document
Expand Down
22 changes: 2 additions & 20 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from langchain_core._api import beta_decorator
from langchain_core.load.dump import dumpd
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import BaseConfig, BaseModel, Field, create_model
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import (
RunnableConfig,
acall_func_with_variable_args,
Expand All @@ -65,6 +65,7 @@
accepts_config,
accepts_context,
accepts_run_manager,
create_model,
gather_with_concurrency,
get_function_first_arg_dict_keys,
get_function_nonlocals,
Expand Down Expand Up @@ -95,10 +96,6 @@
Other = TypeVar("Other")


class _SchemaConfig(BaseConfig):
arbitrary_types_allowed = True


class Runnable(Generic[Input, Output], ABC):
"""A unit of work that can be invoked, batched, streamed, transformed and composed.
Expand Down Expand Up @@ -301,7 +298,6 @@ def get_input_schema(
return create_model(
self.get_name("Input"),
__root__=(root_type, None),
__config__=_SchemaConfig,
)

@property
Expand Down Expand Up @@ -334,7 +330,6 @@ def get_output_schema(
return create_model(
self.get_name("Output"),
__root__=(root_type, None),
__config__=_SchemaConfig,
)

@property
Expand Down Expand Up @@ -371,15 +366,13 @@ def config_schema(
)
for spec in config_specs
},
__config__=_SchemaConfig,
)
if config_specs
else None
)

return create_model( # type: ignore[call-overload]
self.get_name("Config"),
__config__=_SchemaConfig,
**({"configurable": (configurable, None)} if configurable else {}),
**{
field_name: (field_type, None)
Expand Down Expand Up @@ -1691,7 +1684,6 @@ def _seq_input_schema(
for k, v in next_input_schema.__fields__.items()
if k not in first.mapper.steps
},
__config__=_SchemaConfig,
)
elif isinstance(first, RunnablePick):
return _seq_input_schema(steps[1:], config)
Expand Down Expand Up @@ -1724,7 +1716,6 @@ def _seq_output_schema(
for k, v in mapper_output_schema.__fields__.items()
},
},
__config__=_SchemaConfig,
)
elif isinstance(last, RunnablePick):
prev_output_schema = _seq_output_schema(steps[:-1], config)
Expand All @@ -1738,14 +1729,12 @@ def _seq_output_schema(
for k, v in prev_output_schema.__fields__.items()
if k in last.keys
},
__config__=_SchemaConfig,
)
else:
field = prev_output_schema.__fields__[last.keys]
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
__root__=(field.annotation, field.default),
__config__=_SchemaConfig,
)

return last.get_output_schema(config)
Expand Down Expand Up @@ -2598,7 +2587,6 @@ def get_input_schema(
for k, v in step.get_input_schema(config).__fields__.items()
if k != "__root__"
},
__config__=_SchemaConfig,
)

return super().get_input_schema(config)
Expand All @@ -2610,7 +2598,6 @@ def get_output_schema(
return create_model( # type: ignore[call-overload]
self.get_name("Output"),
**{k: (v.OutputType, None) for k, v in self.steps.items()},
__config__=_SchemaConfig,
)

@property
Expand Down Expand Up @@ -3250,13 +3237,11 @@ def get_input_schema(
return create_model(
self.get_name("Input"),
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
__config__=_SchemaConfig,
)
else:
return create_model(
self.get_name("Input"),
__root__=(List[Any], None),
__config__=_SchemaConfig,
)

if self.InputType != Any:
Expand All @@ -3266,7 +3251,6 @@ def get_input_schema(
return create_model(
self.get_name("Input"),
**{key: (Any, None) for key in dict_keys}, # type: ignore
__config__=_SchemaConfig,
)

return super().get_input_schema(config)
Expand Down Expand Up @@ -3756,7 +3740,6 @@ def get_input_schema(
List[self.bound.get_input_schema(config)], # type: ignore
None,
),
__config__=_SchemaConfig,
)

@property
Expand All @@ -3773,7 +3756,6 @@ def get_output_schema(
List[schema], # type: ignore
None,
),
__config__=_SchemaConfig,
)

@property
Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/runnables/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.config import run_in_executor
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
create_model,
get_unique_config_specs,
)

Expand Down
8 changes: 6 additions & 2 deletions libs/core/langchain_core/runnables/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
cast,
)

from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
Other,
Runnable,
Expand All @@ -36,7 +36,11 @@
patch_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
from langchain_core.runnables.utils import (
AddableDict,
ConfigurableFieldSpec,
create_model,
)
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee

Expand Down
32 changes: 32 additions & 0 deletions libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import inspect
import textwrap
from functools import lru_cache
from inspect import signature
from itertools import groupby
from typing import (
Expand All @@ -21,10 +22,13 @@
Protocol,
Sequence,
Set,
Type,
TypeVar,
Union,
)

from langchain_core.pydantic_v1 import BaseConfig, BaseModel
from langchain_core.pydantic_v1 import create_model as _create_model_base
from langchain_core.runnables.schema import StreamEvent

Input = TypeVar("Input", contravariant=True)
Expand Down Expand Up @@ -489,3 +493,31 @@ def include_event(self, event: StreamEvent, root_type: str) -> bool:
)

return include


class _SchemaConfig(BaseConfig):
    """Pydantic v1 config applied to all dynamically created schema models.

    ``arbitrary_types_allowed`` lets schema fields use types pydantic has no
    validator for; ``frozen`` disables mutation of model instances, which
    keeps the cached model classes safe to share across callers.
    """

    arbitrary_types_allowed = True
    frozen = True


def create_model(
    __model_name: str,
    **field_definitions: Any,
) -> Type[BaseModel]:
    """Build a pydantic model class, reusing a cached class when possible.

    Args:
        __model_name: Name of the generated model class.
        **field_definitions: Field specs forwarded to pydantic's
            ``create_model`` (e.g. ``name=(type, default)``).

    Returns:
        A ``BaseModel`` subclass configured with ``_SchemaConfig``.
    """
    try:
        # Fast path: identical (name, fields) requests share one class.
        model = _create_model_cached(__model_name, **field_definitions)
    except TypeError:
        # A field definition is unhashable, so lru_cache cannot key it;
        # fall back to building an uncached model.
        model = _create_model_base(
            __model_name, __config__=_SchemaConfig, **field_definitions
        )
    return model


@lru_cache(maxsize=256)
def _create_model_cached(
    __model_name: str,
    **field_definitions: Any,
) -> Type[BaseModel]:
    """Memoized pydantic ``create_model`` (bounded at 256 entries).

    Raises ``TypeError`` before the body runs if any field-definition value
    is unhashable, since ``lru_cache`` keys on the keyword arguments; the
    public ``create_model`` wrapper catches that and builds uncached.
    """
    return _create_model_base(
        __model_name, __config__=_SchemaConfig, **field_definitions
    )
55 changes: 53 additions & 2 deletions libs/core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"


[tool.poetry.group.test_integration]
Expand Down
Loading

0 comments on commit 07ee41d

Please sign in to comment.