diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml index b5af5ea01..9b874bd3d 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml @@ -14,7 +14,6 @@ jobs: defaults: run: working-directory: libs/langgraph - needs: [baseline] steps: - uses: actions/checkout@v4 - id: files diff --git a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py index e777a8176..e91aa425d 100644 --- a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py +++ b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py @@ -16,7 +16,7 @@ IPv6Interface, IPv6Network, ) -from typing import Any, Optional +from typing import Any, Optional, Sequence from uuid import UUID from langchain_core.load.load import Reviver @@ -35,66 +35,70 @@ def _encode_constructor_args( constructor: type[Any], *, method: Optional[str] = None, - args: Optional[list[Any]] = None, + args: Optional[Sequence[Any]] = None, kwargs: Optional[dict[str, Any]] = None, ): - return { + out = { "lc": 2, "type": "constructor", - "id": [*constructor.__module__.split("."), constructor.__name__], - "method": method, - "args": args if args is not None else [], - "kwargs": kwargs if kwargs is not None else {}, + "id": (*constructor.__module__.split("."), constructor.__name__), } + if method is not None: + out["method"] = method + if args is not None: + out["args"] = args + if kwargs is not None: + out["kwargs"] = kwargs + return out def _default(self, obj): if isinstance(obj, Serializable): return obj.to_json() elif hasattr(obj, "model_dump") and callable(obj.model_dump): return self._encode_constructor_args( - obj.__class__, method=[None, "model_construct"], kwargs=obj.model_dump() + obj.__class__, method=(None, "model_construct"), kwargs=obj.model_dump() ) elif hasattr(obj, "dict") and callable(obj.dict): return self._encode_constructor_args( - obj.__class__, method=[None, "construct"], kwargs=obj.dict() + obj.__class__, method=(None, "construct"), kwargs=obj.dict() ) elif isinstance(obj, pathlib.Path): return self._encode_constructor_args(pathlib.Path, args=obj.parts) elif isinstance(obj, re.Pattern): return self._encode_constructor_args( - re.compile, args=[obj.pattern, obj.flags] + re.compile, args=(obj.pattern, obj.flags) ) elif isinstance(obj, UUID): - return self._encode_constructor_args(UUID, args=[obj.hex]) + return self._encode_constructor_args(UUID, args=(obj.hex,)) elif isinstance(obj, decimal.Decimal): - return self._encode_constructor_args(decimal.Decimal, args=[str(obj)]) + return self._encode_constructor_args(decimal.Decimal, args=(str(obj),)) elif isinstance(obj, (set, frozenset, deque)): - return self._encode_constructor_args(type(obj), args=[list(obj)]) + return self._encode_constructor_args(type(obj), args=(tuple(obj),)) elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)): - return self._encode_constructor_args(obj.__class__, args=[str(obj)]) + return self._encode_constructor_args(obj.__class__, args=(str(obj),)) elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)): - return self._encode_constructor_args(obj.__class__, args=[str(obj)]) + return self._encode_constructor_args(obj.__class__, args=(str(obj),)) elif isinstance(obj, datetime): return self._encode_constructor_args( - datetime, method="fromisoformat", args=[obj.isoformat()] + datetime, method="fromisoformat", args=(obj.isoformat(),) ) elif isinstance(obj, timezone): return self._encode_constructor_args(timezone, args=obj.__getinitargs__()) elif isinstance(obj, ZoneInfo): - return self._encode_constructor_args(ZoneInfo, args=[obj.key]) + return self._encode_constructor_args(ZoneInfo, args=(obj.key,)) elif isinstance(obj, timedelta): return self._encode_constructor_args( - timedelta, args=[obj.days, obj.seconds, obj.microseconds] + timedelta, args=(obj.days, obj.seconds, obj.microseconds) ) elif isinstance(obj, date): return self._encode_constructor_args( - date, args=[obj.year, obj.month, obj.day] + date, args=(obj.year, obj.month, obj.day) ) elif isinstance(obj, time): return self._encode_constructor_args( time, - args=[obj.hour, obj.minute, obj.second, obj.microsecond, obj.tzinfo], + args=(obj.hour, obj.minute, obj.second, obj.microsecond, obj.tzinfo), kwargs={"fold": obj.fold}, ) elif dataclasses.is_dataclass(obj): @@ -106,14 +110,14 @@ def _default(self, obj): }, ) elif isinstance(obj, Enum): - return self._encode_constructor_args(obj.__class__, args=[obj.value]) + return self._encode_constructor_args(obj.__class__, args=(obj.value,)) elif isinstance(obj, SendProtocol): return self._encode_constructor_args( obj.__class__, kwargs={"node": obj.node, "arg": obj.arg} ) elif isinstance(obj, (bytes, bytearray)): return self._encode_constructor_args( - obj.__class__, method="fromhex", args=[obj.hex()] + obj.__class__, method="fromhex", args=(obj.hex(),) ) elif isinstance(obj, BaseException): return repr(obj) @@ -136,25 +140,28 @@ def _reviver(self, value: dict[str, Any]) -> Any: # Import class cls = getattr(mod, name) # Instantiate class - if isinstance(value["method"], str): - methods = [getattr(cls, value["method"])] - elif isinstance(value["method"], list): + method = value.get("method") + if isinstance(method, str): + methods = [getattr(cls, method)] + elif isinstance(method, list): methods = [ cls if method is None else getattr(cls, method) - for method in value["method"] + for method in method ] else: methods = [cls] + args = value.get("args") + kwargs = value.get("kwargs") for method in methods: try: if isclass(method) and issubclass(method, BaseException): return None - if value["args"] and value["kwargs"]: - return method(*value["args"], **value["kwargs"]) - elif value["args"]: - return method(*value["args"]) - elif value["kwargs"]: - return method(**value["kwargs"]) + if args and kwargs: + return method(*args, **kwargs) + elif args: + return method(*args) + elif kwargs: + return method(**kwargs) else: return method() except Exception: diff --git a/libs/checkpoint/tests/test_jsonplus.py b/libs/checkpoint/tests/test_jsonplus.py index 0ead0ed53..22241a44a 100644 --- a/libs/checkpoint/tests/test_jsonplus.py +++ b/libs/checkpoint/tests/test_jsonplus.py @@ -122,7 +122,7 @@ def test_serde_jsonplus() -> None: assert dumped == ( "json", - b"""{"path": {"lc": 2, "type": "constructor", "id": ["pathlib", "Path"], "method": null, "args": ["foo", "bar"], "kwargs": {}}, "re": {"lc": 2, "type": "constructor", "id": ["re", "compile"], "method": null, "args": ["foo", 48], "kwargs": {}}, "decimal": {"lc": 2, "type": "constructor", "id": ["decimal", "Decimal"], "method": null, "args": ["1.10101"], "kwargs": {}}, "ip4": {"lc": 2, "type": "constructor", "id": ["ipaddress", "IPv4Address"], "method": null, "args": ["192.168.0.1"], "kwargs": {}}, "deque": {"lc": 2, "type": "constructor", "id": ["collections", "deque"], "method": null, "args": [[1, 2, 3]], "kwargs": {}}, "tzn": {"lc": 2, "type": "constructor", "id": ["zoneinfo", "ZoneInfo"], "method": null, "args": ["America/New_York"], "kwargs": {}}, "date": {"lc": 2, "type": "constructor", "id": ["datetime", "date"], "method": null, "args": [2024, 4, 19], "kwargs": {}}, "time": {"lc": 2, "type": "constructor", "id": ["datetime", "time"], "method": null, "args": [23, 4, 57, 51022, {"lc": 2, "type": "constructor", "id": ["datetime", "timezone"], "method": null, "args": [{"lc": 2, "type": "constructor", "id": ["datetime", "timedelta"], "method": null, "args": [0, 86340, 0], "kwargs": {}}], "kwargs": {}}], "kwargs": {"fold": 0}}, "uid": {"lc": 2, "type": "constructor", "id": ["uuid", "UUID"], "method": null, "args": ["00000000000000000000000000000001"], "kwargs": {}}, "timestamp": {"lc": 2, "type": "constructor", "id": ["datetime", "datetime"], "method": "fromisoformat", "args": ["2024-04-19T23:04:57.051022+23:59"], "kwargs": {}}, "my_slotted_class": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclassWSlots"], "method": null, "args": [], "kwargs": {"foo": "bar", "bar": 2}}, "my_dataclass": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclass"], "method": null, "args": [], "kwargs": {"foo": "foo", "bar": 1}}, "my_enum": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyEnum"], "method": null, "args": ["foo"], "kwargs": {}}, "my_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydantic"], "method": [null, "model_construct"], "args": [], "kwargs": {"foo": "foo", "bar": 1}}, "my_funny_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyFunnyPydantic"], "method": [null, "construct"], "args": [], "kwargs": {"foo": "foo", "bar": 1}}, "person": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "Person"], "method": null, "args": [], "kwargs": {"name": "foo"}}, "a_bool": true, "a_none": null, "a_str": "foo", "a_str_nuc": "foo\\u0000", "a_str_uc": "foo \xe2\x9b\xb0\xef\xb8\x8f", "a_str_ucuc": "foo \xe2\x9b\xb0\xef\xb8\x8f\\u0000", "a_str_ucucuc": "foo \\\\u26f0\\\\ufe0f", "text": ["Hello", "Python", "Surrogate", "Example", "String", "With", "Surrogates", "Embedded", "In", "The", "Text", "\xe6\x94\xb6\xe8\x8a\xb1\xf0\x9f\x99\x84\xc2\xb7\xe5\x88\xb0"], "an_int": 1, "a_float": 1.1, "runnable_map": {"lc": 1, "type": "constructor", "id": ["langchain", "schema", "runnable", "RunnableParallel"], "kwargs": {"steps__": {}}, "name": "RunnableParallel<>", "graph": {"nodes": [{"id": 0, "type": "schema", "data": "Parallel<>Input"}, {"id": 1, "type": "schema", "data": "Parallel<>Output"}], "edges": []}}}""", + b"""{"path": {"lc": 2, "type": "constructor", "id": ["pathlib", "Path"], "args": ["foo", "bar"]}, "re": {"lc": 2, "type": "constructor", "id": ["re", "compile"], "args": ["foo", 48]}, "decimal": {"lc": 2, "type": "constructor", "id": ["decimal", "Decimal"], "args": ["1.10101"]}, "ip4": {"lc": 2, "type": "constructor", "id": ["ipaddress", "IPv4Address"], "args": ["192.168.0.1"]}, "deque": {"lc": 2, "type": "constructor", "id": ["collections", "deque"], "args": [[1, 2, 3]]}, "tzn": {"lc": 2, "type": "constructor", "id": ["zoneinfo", "ZoneInfo"], "args": ["America/New_York"]}, "date": {"lc": 2, "type": "constructor", "id": ["datetime", "date"], "args": [2024, 4, 19]}, "time": {"lc": 2, "type": "constructor", "id": ["datetime", "time"], "args": [23, 4, 57, 51022, {"lc": 2, "type": "constructor", "id": ["datetime", "timezone"], "args": [{"lc": 2, "type": "constructor", "id": ["datetime", "timedelta"], "args": [0, 86340, 0]}]}], "kwargs": {"fold": 0}}, "uid": {"lc": 2, "type": "constructor", "id": ["uuid", "UUID"], "args": ["00000000000000000000000000000001"]}, "timestamp": {"lc": 2, "type": "constructor", "id": ["datetime", "datetime"], "method": "fromisoformat", "args": ["2024-04-19T23:04:57.051022+23:59"]}, "my_slotted_class": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclassWSlots"], "kwargs": {"foo": "bar", "bar": 2}}, "my_dataclass": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclass"], "kwargs": {"foo": "foo", "bar": 1}}, "my_enum": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyEnum"], "args": ["foo"]}, "my_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydantic"], "method": [null, "model_construct"], "kwargs": {"foo": "foo", "bar": 1}}, "my_funny_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyFunnyPydantic"], "method": [null, "construct"], "kwargs": {"foo": "foo", "bar": 1}}, "person": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "Person"], "kwargs": {"name": "foo"}}, "a_bool": true, "a_none": null, "a_str": "foo", "a_str_nuc": "foo\\u0000", "a_str_uc": "foo \xe2\x9b\xb0\xef\xb8\x8f", "a_str_ucuc": "foo \xe2\x9b\xb0\xef\xb8\x8f\\u0000", "a_str_ucucuc": "foo \\\\u26f0\\\\ufe0f", "text": ["Hello", "Python", "Surrogate", "Example", "String", "With", "Surrogates", "Embedded", "In", "The", "Text", "\xe6\x94\xb6\xe8\x8a\xb1\xf0\x9f\x99\x84\xc2\xb7\xe5\x88\xb0"], "an_int": 1, "a_float": 1.1, "runnable_map": {"lc": 1, "type": "constructor", "id": ["langchain", "schema", "runnable", "RunnableParallel"], "kwargs": {"steps__": {}}, "name": "RunnableParallel<>", "graph": {"nodes": [{"id": 0, "type": "schema", "data": "Parallel<>Input"}, {"id": 1, "type": "schema", "data": "Parallel<>Output"}], "edges": []}}}""", ) assert serde.loads_typed(dumped) == { diff --git a/libs/langgraph/Makefile b/libs/langgraph/Makefile index 4f5b9ad82..0173351d6 100644 --- a/libs/langgraph/Makefile +++ b/libs/langgraph/Makefile @@ -13,8 +13,14 @@ OUTPUT ?= out/benchmark.json benchmark: mkdir -p out + rm -f $(OUTPUT) poetry run python -m bench -o $(OUTPUT) --rigorous +benchmark-fast: + mkdir -p out + rm -f $(OUTPUT) + poetry run python -m bench -o $(OUTPUT) --fast + GRAPH ?= bench/fanout_to_subgraph.py profile: diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index bfa9792ee..809531f5f 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -559,7 +559,9 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any: channels=(list(input_values) if is_single_input else input_values), # coerce state dict to schema class (eg. pydantic model) mapper=( - None if is_single_input else partial(_coerce_state, input_schema) + None + if is_single_input or issubclass(input_schema, dict) + else partial(_coerce_state, input_schema) ), writers=[ # publish to this channel and state keys @@ -669,7 +671,11 @@ def _get_state_reader( select=select[0] if select == ["__root__"] else select, fresh=True, # coerce state dict to schema class (eg. pydantic model) - mapper=(None if state_keys == ["__root__"] else partial(_coerce_state, schema)), + mapper=( + None + if state_keys == ["__root__"] or issubclass(schema, dict) + else partial(_coerce_state, schema) + ), ) diff --git a/libs/langgraph/langgraph/pregel/get_state.py b/libs/langgraph/langgraph/pregel/get_state.py deleted file mode 100644 index 79e5f0bf8..000000000 --- a/libs/langgraph/langgraph/pregel/get_state.py +++ /dev/null @@ -1,36 +0,0 @@ -from langgraph.constants import NS_SEP -from langgraph.pregel.types import StateSnapshot - - -def assemble_state_snapshot_hierarchy( - root_checkpoint_ns: str, - checkpoint_ns_to_state_snapshots: dict[str, StateSnapshot], -) -> StateSnapshot: - checkpoint_ns_list_to_visit = sorted( - checkpoint_ns_to_state_snapshots.keys(), - key=lambda x: len(x.split(NS_SEP)), - ) - while checkpoint_ns_list_to_visit: - checkpoint_ns = checkpoint_ns_list_to_visit.pop() - state_snapshot = checkpoint_ns_to_state_snapshots[checkpoint_ns] - *path, subgraph_node = checkpoint_ns.split(NS_SEP) - parent_checkpoint_ns = NS_SEP.join(path) - if subgraph_node and ( - parent_state_snapshot := checkpoint_ns_to_state_snapshots.get( - parent_checkpoint_ns - ) - ): - parent_subgraph_snapshots = { - **(parent_state_snapshot.subgraphs or {}), - subgraph_node: state_snapshot, - } - checkpoint_ns_to_state_snapshots[parent_checkpoint_ns] = ( - checkpoint_ns_to_state_snapshots[ - parent_checkpoint_ns - ]._replace(subgraphs=parent_subgraph_snapshots) - ) - - state_snapshot = checkpoint_ns_to_state_snapshots.pop(root_checkpoint_ns, None) - if state_snapshot is None: - raise ValueError(f"Missing checkpoint for checkpoint NS '{root_checkpoint_ns}'") - return state_snapshot diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 2cf6fab5e..1b4077bd6 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -659,7 +659,7 @@ def __enter__(self) -> Self: **saved.config.get("configurable", {}), }, } - self.checkpoint = copy_checkpoint(saved.checkpoint) + self.checkpoint = saved.checkpoint self.checkpoint_metadata = saved.metadata self.checkpoint_pending_writes = ( [(str(tid), k, v) for tid, k, v in saved.pending_writes] @@ -777,7 +777,7 @@ async def __aenter__(self) -> Self: **saved.config.get("configurable", {}), }, } - self.checkpoint = copy_checkpoint(saved.checkpoint) + self.checkpoint = saved.checkpoint self.checkpoint_metadata = saved.metadata self.checkpoint_pending_writes = ( [(str(tid), k, v) for tid, k, v in saved.pending_writes]