Skip to content

Commit

Permalink
Merge pull request #1641 from langchain-ai/wfh/output_schema
Browse files Browse the repository at this point in the history
Support optional fields in the output schema when defining as Dataclass or TypedDict
  • Loading branch information
nfcampos authored Sep 6, 2024
2 parents 40dd40a + 4aa28fb commit a424eb4
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 60 deletions.
82 changes: 47 additions & 35 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,45 +481,22 @@ class CompiledStateGraph(CompiledGraph):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
if isclass(self.builder.input) and issubclass(
self.builder.input, (BaseModel, BaseModelV1)
):
return self.builder.input
else:
keys = list(self.builder.schemas[self.builder.input].keys())
if len(keys) == 1 and keys[0] == "__root__":
return create_model( # type: ignore[call-overload]
self.get_name("Input"),
__root__=(self.channels[keys[0]].UpdateType, None),
)
else:
return create_model( # type: ignore[call-overload]
self.get_name("Input"),
**{
k: (
self.channels[k].UpdateType,
(
get_field_default(
k,
self.channels[k].UpdateType,
self.builder.input,
)
),
)
for k in self.builder.schemas[self.builder.input]
if isinstance(self.channels[k], BaseChannel)
},
)
return _get_schema(
typ=self.builder.input,
schemas=self.builder.schemas,
channels=self.builder.channels,
name=self.get_name("Input"),
)

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
if isclass(self.builder.output) and issubclass(
self.builder.output, (BaseModel, BaseModelV1)
):
return self.builder.output

return super().get_output_schema(config)
return _get_schema(
typ=self.builder.output,
schemas=self.builder.schemas,
channels=self.builder.channels,
name=self.get_name("Output"),
)

def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None:
if key == START:
Expand Down Expand Up @@ -779,3 +756,38 @@ def _is_field_managed_value(name: str, typ: Type[Any]) -> Optional[ManagedValueS
return decoration

return None


def _get_schema(
typ: Type,
schemas: dict,
channels: dict,
name: str,
) -> type[BaseModel]:
if isclass(typ) and issubclass(typ, (BaseModel, BaseModelV1)):
return typ
else:
keys = list(schemas[typ].keys())
if len(keys) == 1 and keys[0] == "__root__":
return create_model( # type: ignore[call-overload]
name,
__root__=(channels[keys[0]].UpdateType, None),
)
else:
return create_model( # type: ignore[call-overload]
name,
**{
k: (
channels[k].UpdateType,
(
get_field_default(
k,
channels[k].UpdateType,
typ,
)
),
)
for k in schemas[typ]
if k in channels and isinstance(channels[k], BaseChannel)
},
)
Loading

0 comments on commit a424eb4

Please sign in to comment.