Skip to content

Commit

Permalink
Merge branch 'main' into cache-key-node
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Sep 8, 2024
2 parents cacf54f + 282562f commit 76681af
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- "3.11"
- "3.12"
core-version:
- ">=0.3.0.dev1,<0.4.0"
- ">=0.3.0.dev4,<0.4.0"
- "latest"

name: "test #${{ matrix.python-version }} (langchain-core: ${{ matrix.core-version }})"
Expand Down
20 changes: 10 additions & 10 deletions libs/langgraph/tests/__snapshots__/test_pregel.ambr

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion libs/langgraph/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

DEFAULT_POSTGRES_URI = "postgres://postgres:postgres@localhost:5442/"
# TODO: fix this once core is released
SHOULD_CHECK_SNAPSHOTS = version.parse(core_version) >= version.parse("0.3.0.dev0")
IS_LANGCHAIN_CORE_030_OR_GREATER = version.parse(core_version) >= version.parse(
"0.3.0.dev0"
)
SHOULD_CHECK_SNAPSHOTS = IS_LANGCHAIN_CORE_030_OR_GREATER


@pytest.fixture
Expand Down
43 changes: 43 additions & 0 deletions libs/langgraph/tests/test_messages_state.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from typing import Annotated
from uuid import UUID

import pytest
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
RemoveMessage,
SystemMessage,
)
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1

from langgraph.graph import add_messages
from langgraph.graph.message import MessagesState
from langgraph.graph.state import END, START, StateGraph
from tests.conftest import IS_LANGCHAIN_CORE_030_OR_GREATER
from tests.messages import _AnyIdHumanMessage


def test_add_single_message():
Expand Down Expand Up @@ -135,3 +143,38 @@ def test_delete_all():
result = add_messages(left, right)
expected_result = []
assert result == expected_result


MESSAGES_STATE_SCHEMAS = [MessagesState]
if IS_LANGCHAIN_CORE_030_OR_GREATER:

class MessagesStatePydantic(BaseModel):
messages: Annotated[list[AnyMessage], add_messages]

MESSAGES_STATE_SCHEMAS.append(MessagesStatePydantic)
else:

class MessagesStatePydanticV1(BaseModelV1):
messages: Annotated[list[AnyMessage], add_messages]

MESSAGES_STATE_SCHEMAS.append(MessagesStatePydanticV1)


@pytest.mark.parametrize("state_schema", MESSAGES_STATE_SCHEMAS)
def test_messages_state(state_schema):
def foo(state):
return {"messages": [HumanMessage("foo")]}

graph = StateGraph(state_schema)
graph.add_edge(START, "foo")
graph.add_edge("foo", END)
graph.add_node(foo)

app = graph.compile()

assert app.invoke({"messages": [("user", "meow")]}) == {
"messages": [
_AnyIdHumanMessage(content="meow"),
_AnyIdHumanMessage(content="foo"),
]
}

0 comments on commit 76681af

Please sign in to comment.