Skip to content

Commit

Permalink
Nc/12sep/perf (#1702)
Browse files Browse the repository at this point in the history
* Remove unused fil;e

* Add benchmark-fast command for running locally

* Small improvements to jsonplus serializer

* Don't use PregelNode.mapper when schema is a typed dict

- All it would do is create a new copy of same dict

* Avoid copying checkpoint when fetching at beginning of loop

* Fix needs array

* Update tests
  • Loading branch information
nfcampos authored Sep 12, 2024
1 parent 8e578a0 commit 95b97d6
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 74 deletions.
1 change: 0 additions & 1 deletion .github/workflows/bench.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ jobs:
defaults:
run:
working-directory: libs/langgraph
needs: [baseline]
steps:
- uses: actions/checkout@v4
- id: files
Expand Down
71 changes: 39 additions & 32 deletions libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
IPv6Interface,
IPv6Network,
)
from typing import Any, Optional
from typing import Any, Optional, Sequence
from uuid import UUID

from langchain_core.load.load import Reviver
Expand All @@ -35,66 +35,70 @@ def _encode_constructor_args(
constructor: type[Any],
*,
method: Optional[str] = None,
args: Optional[list[Any]] = None,
args: Optional[Sequence[Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
):
return {
out = {
"lc": 2,
"type": "constructor",
"id": [*constructor.__module__.split("."), constructor.__name__],
"method": method,
"args": args if args is not None else [],
"kwargs": kwargs if kwargs is not None else {},
"id": (*constructor.__module__.split("."), constructor.__name__),
}
if method is not None:
out["method"] = method
if args is not None:
out["args"] = args
if kwargs is not None:
out["kwargs"] = kwargs
return out

def _default(self, obj):
if isinstance(obj, Serializable):
return obj.to_json()
elif hasattr(obj, "model_dump") and callable(obj.model_dump):
return self._encode_constructor_args(
obj.__class__, method=[None, "model_construct"], kwargs=obj.model_dump()
obj.__class__, method=(None, "model_construct"), kwargs=obj.model_dump()
)
elif hasattr(obj, "dict") and callable(obj.dict):
return self._encode_constructor_args(
obj.__class__, method=[None, "construct"], kwargs=obj.dict()
obj.__class__, method=(None, "construct"), kwargs=obj.dict()
)
elif isinstance(obj, pathlib.Path):
return self._encode_constructor_args(pathlib.Path, args=obj.parts)
elif isinstance(obj, re.Pattern):
return self._encode_constructor_args(
re.compile, args=[obj.pattern, obj.flags]
re.compile, args=(obj.pattern, obj.flags)
)
elif isinstance(obj, UUID):
return self._encode_constructor_args(UUID, args=[obj.hex])
return self._encode_constructor_args(UUID, args=(obj.hex,))
elif isinstance(obj, decimal.Decimal):
return self._encode_constructor_args(decimal.Decimal, args=[str(obj)])
return self._encode_constructor_args(decimal.Decimal, args=(str(obj),))
elif isinstance(obj, (set, frozenset, deque)):
return self._encode_constructor_args(type(obj), args=[list(obj)])
return self._encode_constructor_args(type(obj), args=(tuple(obj),))
elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)):
return self._encode_constructor_args(obj.__class__, args=[str(obj)])
return self._encode_constructor_args(obj.__class__, args=(str(obj),))
elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)):
return self._encode_constructor_args(obj.__class__, args=[str(obj)])
return self._encode_constructor_args(obj.__class__, args=(str(obj),))

elif isinstance(obj, datetime):
return self._encode_constructor_args(
datetime, method="fromisoformat", args=[obj.isoformat()]
datetime, method="fromisoformat", args=(obj.isoformat(),)
)
elif isinstance(obj, timezone):
return self._encode_constructor_args(timezone, args=obj.__getinitargs__())
elif isinstance(obj, ZoneInfo):
return self._encode_constructor_args(ZoneInfo, args=[obj.key])
return self._encode_constructor_args(ZoneInfo, args=(obj.key,))
elif isinstance(obj, timedelta):
return self._encode_constructor_args(
timedelta, args=[obj.days, obj.seconds, obj.microseconds]
timedelta, args=(obj.days, obj.seconds, obj.microseconds)
)
elif isinstance(obj, date):
return self._encode_constructor_args(
date, args=[obj.year, obj.month, obj.day]
date, args=(obj.year, obj.month, obj.day)
)
elif isinstance(obj, time):
return self._encode_constructor_args(
time,
args=[obj.hour, obj.minute, obj.second, obj.microsecond, obj.tzinfo],
args=(obj.hour, obj.minute, obj.second, obj.microsecond, obj.tzinfo),
kwargs={"fold": obj.fold},
)
elif dataclasses.is_dataclass(obj):
Expand All @@ -106,14 +110,14 @@ def _default(self, obj):
},
)
elif isinstance(obj, Enum):
return self._encode_constructor_args(obj.__class__, args=[obj.value])
return self._encode_constructor_args(obj.__class__, args=(obj.value,))
elif isinstance(obj, SendProtocol):
return self._encode_constructor_args(
obj.__class__, kwargs={"node": obj.node, "arg": obj.arg}
)
elif isinstance(obj, (bytes, bytearray)):
return self._encode_constructor_args(
obj.__class__, method="fromhex", args=[obj.hex()]
obj.__class__, method="fromhex", args=(obj.hex(),)
)
elif isinstance(obj, BaseException):
return repr(obj)
Expand All @@ -136,25 +140,28 @@ def _reviver(self, value: dict[str, Any]) -> Any:
# Import class
cls = getattr(mod, name)
# Instantiate class
if isinstance(value["method"], str):
methods = [getattr(cls, value["method"])]
elif isinstance(value["method"], list):
method = value.get("method")
if isinstance(method, str):
methods = [getattr(cls, method)]
elif isinstance(method, list):
methods = [
cls if method is None else getattr(cls, method)
for method in value["method"]
for method in method
]
else:
methods = [cls]
args = value.get("args")
kwargs = value.get("kwargs")
for method in methods:
try:
if isclass(method) and issubclass(method, BaseException):
return None
if value["args"] and value["kwargs"]:
return method(*value["args"], **value["kwargs"])
elif value["args"]:
return method(*value["args"])
elif value["kwargs"]:
return method(**value["kwargs"])
if args and kwargs:
return method(*args, **kwargs)
elif args:
return method(*args)
elif kwargs:
return method(**kwargs)
else:
return method()
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint/tests/test_jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_serde_jsonplus() -> None:

assert dumped == (
"json",
b"""{"path": {"lc": 2, "type": "constructor", "id": ["pathlib", "Path"], "method": null, "args": ["foo", "bar"], "kwargs": {}}, "re": {"lc": 2, "type": "constructor", "id": ["re", "compile"], "method": null, "args": ["foo", 48], "kwargs": {}}, "decimal": {"lc": 2, "type": "constructor", "id": ["decimal", "Decimal"], "method": null, "args": ["1.10101"], "kwargs": {}}, "ip4": {"lc": 2, "type": "constructor", "id": ["ipaddress", "IPv4Address"], "method": null, "args": ["192.168.0.1"], "kwargs": {}}, "deque": {"lc": 2, "type": "constructor", "id": ["collections", "deque"], "method": null, "args": [[1, 2, 3]], "kwargs": {}}, "tzn": {"lc": 2, "type": "constructor", "id": ["zoneinfo", "ZoneInfo"], "method": null, "args": ["America/New_York"], "kwargs": {}}, "date": {"lc": 2, "type": "constructor", "id": ["datetime", "date"], "method": null, "args": [2024, 4, 19], "kwargs": {}}, "time": {"lc": 2, "type": "constructor", "id": ["datetime", "time"], "method": null, "args": [23, 4, 57, 51022, {"lc": 2, "type": "constructor", "id": ["datetime", "timezone"], "method": null, "args": [{"lc": 2, "type": "constructor", "id": ["datetime", "timedelta"], "method": null, "args": [0, 86340, 0], "kwargs": {}}], "kwargs": {}}], "kwargs": {"fold": 0}}, "uid": {"lc": 2, "type": "constructor", "id": ["uuid", "UUID"], "method": null, "args": ["00000000000000000000000000000001"], "kwargs": {}}, "timestamp": {"lc": 2, "type": "constructor", "id": ["datetime", "datetime"], "method": "fromisoformat", "args": ["2024-04-19T23:04:57.051022+23:59"], "kwargs": {}}, "my_slotted_class": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclassWSlots"], "method": null, "args": [], "kwargs": {"foo": "bar", "bar": 2}}, "my_dataclass": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclass"], "method": null, "args": [], "kwargs": {"foo": "foo", "bar": 1}}, "my_enum": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyEnum"], "method": null, "args": ["foo"], "kwargs": {}}, "my_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydantic"], "method": [null, "model_construct"], "args": [], "kwargs": {"foo": "foo", "bar": 1}}, "my_funny_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyFunnyPydantic"], "method": [null, "construct"], "args": [], "kwargs": {"foo": "foo", "bar": 1}}, "person": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "Person"], "method": null, "args": [], "kwargs": {"name": "foo"}}, "a_bool": true, "a_none": null, "a_str": "foo", "a_str_nuc": "foo\\u0000", "a_str_uc": "foo \xe2\x9b\xb0\xef\xb8\x8f", "a_str_ucuc": "foo \xe2\x9b\xb0\xef\xb8\x8f\\u0000", "a_str_ucucuc": "foo \\\\u26f0\\\\ufe0f", "text": ["Hello", "Python", "Surrogate", "Example", "String", "With", "Surrogates", "Embedded", "In", "The", "Text", "\xe6\x94\xb6\xe8\x8a\xb1\xf0\x9f\x99\x84\xc2\xb7\xe5\x88\xb0"], "an_int": 1, "a_float": 1.1, "runnable_map": {"lc": 1, "type": "constructor", "id": ["langchain", "schema", "runnable", "RunnableParallel"], "kwargs": {"steps__": {}}, "name": "RunnableParallel<>", "graph": {"nodes": [{"id": 0, "type": "schema", "data": "Parallel<>Input"}, {"id": 1, "type": "schema", "data": "Parallel<>Output"}], "edges": []}}}""",
b"""{"path": {"lc": 2, "type": "constructor", "id": ["pathlib", "Path"], "args": ["foo", "bar"]}, "re": {"lc": 2, "type": "constructor", "id": ["re", "compile"], "args": ["foo", 48]}, "decimal": {"lc": 2, "type": "constructor", "id": ["decimal", "Decimal"], "args": ["1.10101"]}, "ip4": {"lc": 2, "type": "constructor", "id": ["ipaddress", "IPv4Address"], "args": ["192.168.0.1"]}, "deque": {"lc": 2, "type": "constructor", "id": ["collections", "deque"], "args": [[1, 2, 3]]}, "tzn": {"lc": 2, "type": "constructor", "id": ["zoneinfo", "ZoneInfo"], "args": ["America/New_York"]}, "date": {"lc": 2, "type": "constructor", "id": ["datetime", "date"], "args": [2024, 4, 19]}, "time": {"lc": 2, "type": "constructor", "id": ["datetime", "time"], "args": [23, 4, 57, 51022, {"lc": 2, "type": "constructor", "id": ["datetime", "timezone"], "args": [{"lc": 2, "type": "constructor", "id": ["datetime", "timedelta"], "args": [0, 86340, 0]}]}], "kwargs": {"fold": 0}}, "uid": {"lc": 2, "type": "constructor", "id": ["uuid", "UUID"], "args": ["00000000000000000000000000000001"]}, "timestamp": {"lc": 2, "type": "constructor", "id": ["datetime", "datetime"], "method": "fromisoformat", "args": ["2024-04-19T23:04:57.051022+23:59"]}, "my_slotted_class": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclassWSlots"], "kwargs": {"foo": "bar", "bar": 2}}, "my_dataclass": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclass"], "kwargs": {"foo": "foo", "bar": 1}}, "my_enum": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyEnum"], "args": ["foo"]}, "my_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydantic"], "method": [null, "model_construct"], "kwargs": {"foo": "foo", "bar": 1}}, "my_funny_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyFunnyPydantic"], "method": [null, "construct"], "kwargs": {"foo": "foo", "bar": 1}}, "person": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "Person"], "kwargs": {"name": "foo"}}, "a_bool": true, "a_none": null, "a_str": "foo", "a_str_nuc": "foo\\u0000", "a_str_uc": "foo \xe2\x9b\xb0\xef\xb8\x8f", "a_str_ucuc": "foo \xe2\x9b\xb0\xef\xb8\x8f\\u0000", "a_str_ucucuc": "foo \\\\u26f0\\\\ufe0f", "text": ["Hello", "Python", "Surrogate", "Example", "String", "With", "Surrogates", "Embedded", "In", "The", "Text", "\xe6\x94\xb6\xe8\x8a\xb1\xf0\x9f\x99\x84\xc2\xb7\xe5\x88\xb0"], "an_int": 1, "a_float": 1.1, "runnable_map": {"lc": 1, "type": "constructor", "id": ["langchain", "schema", "runnable", "RunnableParallel"], "kwargs": {"steps__": {}}, "name": "RunnableParallel<>", "graph": {"nodes": [{"id": 0, "type": "schema", "data": "Parallel<>Input"}, {"id": 1, "type": "schema", "data": "Parallel<>Output"}], "edges": []}}}""",
)

assert serde.loads_typed(dumped) == {
Expand Down
6 changes: 6 additions & 0 deletions libs/langgraph/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ OUTPUT ?= out/benchmark.json

benchmark:
mkdir -p out
rm -f $(OUTPUT)
poetry run python -m bench -o $(OUTPUT) --rigorous

benchmark-fast:
mkdir -p out
rm -f $(OUTPUT)
poetry run python -m bench -o $(OUTPUT) --fast

GRAPH ?= bench/fanout_to_subgraph.py

profile:
Expand Down
10 changes: 8 additions & 2 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,9 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
channels=(list(input_values) if is_single_input else input_values),
# coerce state dict to schema class (eg. pydantic model)
mapper=(
None if is_single_input else partial(_coerce_state, input_schema)
None
if is_single_input or issubclass(input_schema, dict)
else partial(_coerce_state, input_schema)
),
writers=[
# publish to this channel and state keys
Expand Down Expand Up @@ -669,7 +671,11 @@ def _get_state_reader(
select=select[0] if select == ["__root__"] else select,
fresh=True,
# coerce state dict to schema class (eg. pydantic model)
mapper=(None if state_keys == ["__root__"] else partial(_coerce_state, schema)),
mapper=(
None
if state_keys == ["__root__"] or issubclass(schema, dict)
else partial(_coerce_state, schema)
),
)


Expand Down
36 changes: 0 additions & 36 deletions libs/langgraph/langgraph/pregel/get_state.py

This file was deleted.

4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def __enter__(self) -> Self:
**saved.config.get("configurable", {}),
},
}
self.checkpoint = copy_checkpoint(saved.checkpoint)
self.checkpoint = saved.checkpoint
self.checkpoint_metadata = saved.metadata
self.checkpoint_pending_writes = (
[(str(tid), k, v) for tid, k, v in saved.pending_writes]
Expand Down Expand Up @@ -777,7 +777,7 @@ async def __aenter__(self) -> Self:
**saved.config.get("configurable", {}),
},
}
self.checkpoint = copy_checkpoint(saved.checkpoint)
self.checkpoint = saved.checkpoint
self.checkpoint_metadata = saved.metadata
self.checkpoint_pending_writes = (
[(str(tid), k, v) for tid, k, v in saved.pending_writes]
Expand Down

0 comments on commit 95b97d6

Please sign in to comment.