Skip to content

Commit

Permalink
Merge branch 'main' into brace/agent-survey-docs
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul authored Sep 6, 2024
2 parents 5b511a2 + a424eb4 commit 8f91ad0
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 62 deletions.
8 changes: 4 additions & 4 deletions docs/docs/cloud/deployment/setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ A LangGraph application must be configured with a [LangGraph API configuration f
This walkthrough is based on [this repository](https://github.com/langchain-ai/langgraph-example), which you can play around with to learn more about how to setup your LangGraph application for deployment.

!!! tip "Setup with pyproject.toml"
If you prefer using poetry for dependency management, check out [this how-to guide](./setup_pyproject.md) on using `pyproject.toml` for LangGraph Cloud.
If you prefer using poetry for dependency management, check out [this how-to guide](./setup_pyproject.md) on using `pyproject.toml` for LangGraph Cloud.

!!! tip "Setup with a Monorepo"
If you are interested in deploying a graph located inside a monorepo, take a look at [this](https://github.com/langchain-ai/langgraph-example-monorepo) repository for an example of how to do so.
If you are interested in deploying a graph located inside a monorepo, take a look at [this](https://github.com/langchain-ai/langgraph-example-monorepo) repository for an example of how to do so.

The final repo structure will look something like this:

Expand Down Expand Up @@ -129,7 +129,7 @@ graph = workflow.compile()
```

!!! warning "Assign `CompiledGraph` to Variable"
The build process for LangGraph Cloud requires that the `CompiledGraph` object be assigned to a variable at the top-level of a Python module (alternatively, you can provide [a function that creates a graph](./graph_rebuild.md)).
The build process for LangGraph Cloud requires that the `CompiledGraph` object be assigned to a variable at the top-level of a Python module (alternatively, you can provide [a function that creates a graph](./graph_rebuild.md)).

Example file directory:

Expand Down Expand Up @@ -166,7 +166,7 @@ Example `langgraph.json` file:
Note that the variable name of the `CompiledGraph` appears at the end of the value of each subkey in the top-level `graphs` key (i.e. `:<variable_name>`).

!!! warning "Configuration Location"
The LangGraph API configuration file must be placed in a directory that is at the same level or higher than the Python files that contain compiled graphs and associated dependencies.
The LangGraph API configuration file must be placed in a directory that is at the same level or higher than the Python files that contain compiled graphs and associated dependencies.

Example file directory:

Expand Down
2 changes: 1 addition & 1 deletion docs/docs/cloud/deployment/setup_pyproject.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ A LangGraph application must be configured with a [LangGraph API configuration f
This walkthrough is based on [this repository](https://github.com/langchain-ai/langgraph-example-pyproject), which you can play around with to learn more about how to setup your LangGraph application for deployment.

!!! tip "Setup with requirements.txt"
If you prefer using `requirements.txt` for dependency management, check out [this how-to guide](./setup.md).
If you prefer using `requirements.txt` for dependency management, check out [this how-to guide](./setup.md).

!!! tip "Setup with a Monorepo"
If you are interested in deploying a graph located inside a monorepo, take a look at [this](https://github.com/langchain-ai/langgraph-example-monorepo) repository for an example of how to do so.
Expand Down
95 changes: 57 additions & 38 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def __init__(
self.input = input
self.output = output
self._add_schema(state_schema)
self._add_schema(input)
self._add_schema(output)
self._add_schema(input, allow_managed=False)
self._add_schema(output, allow_managed=False)
self.config_schema = config_schema
self.waiting_edges: set[tuple[tuple[str, ...], str]] = set()

Expand All @@ -162,10 +162,17 @@ def _all_edges(self) -> set[tuple[str, str]]:
(start, end) for starts, end in self.waiting_edges for start in starts
}

def _add_schema(self, schema: Type[Any]) -> None:
def _add_schema(self, schema: Type[Any], /, allow_managed: bool = True) -> None:
if schema not in self.schemas:
_warn_invalid_state_schema(schema)
channels, managed = _get_channels(schema)
if managed and not allow_managed:
names = ", ".join(managed)
schema_name = getattr(schema, "__name__", "")
raise ValueError(
f"Invalid managed channels detected in {schema_name}: {names}."
" Managed channels are not permitted in Input/Output schema."
)
self.schemas[schema] = {**channels, **managed}
for key, channel in channels.items():
if key in self.channels:
Expand Down Expand Up @@ -474,45 +481,22 @@ class CompiledStateGraph(CompiledGraph):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
if isclass(self.builder.input) and issubclass(
self.builder.input, (BaseModel, BaseModelV1)
):
return self.builder.input
else:
keys = list(self.builder.schemas[self.builder.input].keys())
if len(keys) == 1 and keys[0] == "__root__":
return create_model( # type: ignore[call-overload]
self.get_name("Input"),
__root__=(self.channels[keys[0]].UpdateType, None),
)
else:
return create_model( # type: ignore[call-overload]
self.get_name("Input"),
**{
k: (
self.channels[k].UpdateType,
(
get_field_default(
k,
self.channels[k].UpdateType,
self.builder.input,
)
),
)
for k in self.builder.schemas[self.builder.input]
if isinstance(self.channels[k], BaseChannel)
},
)
return _get_schema(
typ=self.builder.input,
schemas=self.builder.schemas,
channels=self.builder.channels,
name=self.get_name("Input"),
)

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
if isclass(self.builder.output) and issubclass(
self.builder.output, (BaseModel, BaseModelV1)
):
return self.builder.output

return super().get_output_schema(config)
return _get_schema(
typ=self.builder.output,
schemas=self.builder.schemas,
channels=self.builder.channels,
name=self.get_name("Output"),
)

def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None:
if key == START:
Expand Down Expand Up @@ -772,3 +756,38 @@ def _is_field_managed_value(name: str, typ: Type[Any]) -> Optional[ManagedValueS
return decoration

return None


def _get_schema(
typ: Type,
schemas: dict,
channels: dict,
name: str,
) -> type[BaseModel]:
if isclass(typ) and issubclass(typ, (BaseModel, BaseModelV1)):
return typ
else:
keys = list(schemas[typ].keys())
if len(keys) == 1 and keys[0] == "__root__":
return create_model( # type: ignore[call-overload]
name,
__root__=(channels[keys[0]].UpdateType, None),
)
else:
return create_model( # type: ignore[call-overload]
name,
**{
k: (
channels[k].UpdateType,
(
get_field_default(
k,
channels[k].UpdateType,
typ,
)
),
)
for k in schemas[typ]
if k in channels and isinstance(channels[k], BaseChannel)
},
)
21 changes: 14 additions & 7 deletions libs/langgraph/langgraph/utils/fields.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import dataclasses
from typing import Any, Optional, Type, Union

from typing_extensions import (
Annotated,
NotRequired,
ReadOnly,
Required,
get_origin,
)
from typing_extensions import Annotated, NotRequired, ReadOnly, Required, get_origin


def _is_optional_type(type_: Any) -> bool:
Expand Down Expand Up @@ -92,6 +87,18 @@ def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any:
return ...
# Handle NotRequired[<type>] for earlier versions of python
return None
if dataclasses.is_dataclass(schema):
field_info = next(
(f for f in dataclasses.fields(schema) if f.name == name), None
)
if field_info:
if (
field_info.default is not dataclasses.MISSING
and field_info.default is not ...
):
return field_info.default
elif field_info.default_factory is not dataclasses.MISSING:
return field_info.default_factory()
# Note, we ignore ReadOnly attributes,
# as they don't make much sense. (we don't care if you mutate the state in your node)
# and mutating state in your node has no effect on our graph state.
Expand Down
16 changes: 8 additions & 8 deletions libs/langgraph/tests/__snapshots__/test_pregel.ambr

Large diffs are not rendered by default.

140 changes: 138 additions & 2 deletions libs/langgraph/tests/test_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
import warnings
from dataclasses import dataclass, field
from typing import Annotated as Annotated2
from typing import Any, Optional

Expand All @@ -8,6 +10,7 @@
from typing_extensions import Annotated, NotRequired, Required, TypedDict

from langgraph.graph.state import StateGraph, _warn_invalid_state_schema
from langgraph.managed.shared_value import SharedValue


class State(BaseModel):
Expand Down Expand Up @@ -104,14 +107,25 @@ class InputState(SomeParentState, total=total_): # type: ignore
val5: Annotated[Required[str], "foo"]
val6: Annotated[NotRequired[str], "bar"]

class OutputState(SomeParentState, total=total_): # type: ignore
out_val1: str
out_val2: Optional[str]
out_val3: Required[str]
out_val4: NotRequired[dict]
out_val5: Annotated[Required[str], "foo"]
out_val6: Annotated[NotRequired[str], "bar"]

class State(InputState): # this would be ignored
val4: dict
some_shared_channel: Annotated[str, SharedValue.on("assistant_id")] = field(
default="foo"
)

builder = StateGraph(State, input=InputState)
builder = StateGraph(State, input=InputState, output=OutputState)
builder.add_node("n", lambda x: x)
builder.add_edge("__start__", "n")
graph = builder.compile()
model = graph.input_schema
model = graph.get_input_schema()
json_schema = model.schema()

if total_ is False:
Expand All @@ -130,3 +144,125 @@ class State(InputState): # this would be ignored
assert (
set(json_schema["properties"].keys()) == expected_required | expected_optional
)

# Check output schema. Should be the same process
output_schema = graph.get_output_schema().schema()
if total_ is False:
expected_required = set()
expected_optional = {"out_val2", "out_val1"}
else:
expected_required = {"out_val1"}
expected_optional = {"out_val2"}

expected_required |= {"val0a", "out_val3", "out_val5"}
expected_optional |= {"val0b", "out_val4", "out_val6"}

assert set(output_schema.get("required", set())) == expected_required
assert (
set(output_schema["properties"].keys()) == expected_required | expected_optional
)


@pytest.mark.parametrize("kw_only_", [False, True])
def test_state_schema_default_values(kw_only_: bool):
kwargs = {}
if "kw_only" in inspect.signature(dataclass).parameters:
kwargs = {"kw_only": kw_only_}

@dataclass(**kwargs)
class InputState:
val1: str
val2: Optional[int]
val3: Annotated[Optional[float], "optional annotated"]
val4: Optional[str] = None
val5: list[int] = field(default_factory=lambda: [1, 2, 3])
val6: dict[str, int] = field(default_factory=lambda: {"a": 1})
val7: str = field(default=...)
val8: Annotated[int, "some metadata"] = 42
val9: Annotated[str, "more metadata"] = field(default="some foo")
val10: str = "default"
val11: Annotated[list[str], "annotated list"] = field(
default_factory=lambda: ["a", "b"]
)
some_shared_channel: Annotated[str, SharedValue.on("assistant_id")] = field(
default="foo"
)

builder = StateGraph(InputState)
builder.add_node("n", lambda x: x)
builder.add_edge("__start__", "n")
graph = builder.compile()
for model in [graph.get_input_schema(), graph.get_output_schema()]:
json_schema = model.schema()

expected_required = {"val1", "val7"}
expected_optional = {
"val2",
"val3",
"val4",
"val5",
"val6",
"val8",
"val9",
"val10",
"val11",
}

assert set(json_schema.get("required", set())) == expected_required
assert (
set(json_schema["properties"].keys()) == expected_required | expected_optional
)


def test_raises_invalid_managed():
class BadInputState(TypedDict):
some_thing: str
some_input_channel: Annotated[str, SharedValue.on("assistant_id")]

class InputState(TypedDict):
some_thing: str
some_input_channel: str

class BadOutputState(TypedDict):
some_thing: str
some_output_channel: Annotated[str, SharedValue.on("assistant_id")]

class OutputState(TypedDict):
some_thing: str
some_output_channel: str

class State(TypedDict):
some_thing: str
some_channel: Annotated[str, SharedValue.on("assistant_id")]

# All OK
StateGraph(State, input=InputState, output=OutputState)
StateGraph(State)
StateGraph(State, input=State, output=State)
StateGraph(State, input=InputState)
StateGraph(State, input=InputState)

bad_input_examples = [
(State, BadInputState, OutputState),
(State, BadInputState, BadOutputState),
(State, BadInputState, State),
(State, BadInputState, None),
]
for _state, _inp, _outp in bad_input_examples:
with pytest.raises(
ValueError,
match="Invalid managed channels detected in BadInputState: some_input_channel. Managed channels are not permitted in Input/Output schema.",
):
StateGraph(_state, input=_inp, output=_outp)
bad_output_examples = [
(State, InputState, BadOutputState),
(None, InputState, BadOutputState),
(None, State, BadOutputState),
(State, None, BadOutputState),
]
for _state, _inp, _outp in bad_output_examples:
with pytest.raises(
ValueError,
match="Invalid managed channels detected in BadOutputState: some_output_channel. Managed channels are not permitted in Input/Output schema.",
):
StateGraph(_state, input=_inp, output=_outp)
Loading

0 comments on commit 8f91ad0

Please sign in to comment.