Skip to content

Commit

Permalink
Add scoped option
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Dec 6, 2023
1 parent 49e4d08 commit f9750c4
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 11 deletions.
10 changes: 8 additions & 2 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,7 +1402,7 @@ def get_output_schema(

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
from langchain_core.runnables.context import CONTEXT_CONFIG_PREFIX
from langchain_core.runnables.context import CONTEXT_CONFIG_PREFIX, _key_from_id

# get all specs
all_specs = [
Expand All @@ -1423,6 +1423,7 @@ def config_specs(self) -> List[ConfigurableFieldSpec]:
# assign context dependencies
for pos, (spec, idx) in enumerate(all_specs):
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
print(spec.id, deps_by_pos[idx])
all_specs[pos] = (
ConfigurableFieldSpec(
id=spec.id,
Expand All @@ -1431,7 +1432,12 @@ def config_specs(self) -> List[ConfigurableFieldSpec]:
default=spec.default,
description=spec.description,
is_shared=spec.is_shared,
dependencies=list(deps_by_pos[idx]) + (spec.dependencies or []),
dependencies=[
d
for d in deps_by_pos[idx]
if _key_from_id(d) != _key_from_id(spec.id)
]
+ (spec.dependencies or []),
),
idx,
)
Expand Down
59 changes: 50 additions & 9 deletions libs/core/langchain_core/runnables/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def _getter(done: threading.Event, values: Values) -> Any:
return values[done]


def _key_from_id(id_: str) -> str:
wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
else:
raise ValueError(f"Invalid context config id {id_}")


def _config_with_context(
config: RunnableConfig,
steps: List[Runnable],
Expand All @@ -74,12 +84,12 @@ def _config_with_context(
key: list(group)
for key, group in groupby(
sorted(context_specs, key=lambda s: s[0].id),
key=lambda s: s[0].id.split("/")[1],
key=lambda s: _key_from_id(s[0].id),
)
}
deps_by_key = {
key: set(
dep.split("/")[1] for spec in group for dep in (spec[0].dependencies or [])
_key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
)
for key, group in grouped_by_key.items()
}
Expand Down Expand Up @@ -129,15 +139,18 @@ def config_with_context(


class ContextGet(RunnableSerializable):
key: Union[str, List[str]]
prefix: str = ""

def __init__(self, key: Union[str, List[str]]):
super().__init__(key=key)
key: Union[str, List[str]]

@property
def ids(self) -> List[str]:
prefix = self.prefix + "/" if self.prefix else ""
keys = self.key if isinstance(self.key, list) else [self.key]
return [f"{CONTEXT_CONFIG_PREFIX}{k}{CONTEXT_CONFIG_SUFFIX_GET}" for k in keys]
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}"
for k in keys
]

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
Expand Down Expand Up @@ -184,6 +197,8 @@ def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:


class ContextSet(RunnableSerializable):
prefix: str = ""

keys: Mapping[str, Optional[Runnable]]

class Config:
Expand All @@ -193,6 +208,7 @@ def __init__(
self,
key: Optional[str] = None,
value: Optional[SetValue] = None,
prefix: str = "",
**kwargs: SetValue,
):
if key is not None:
Expand All @@ -201,13 +217,15 @@ def __init__(
keys={
k: _coerce_set_value(v) if v is not None else None
for k, v in kwargs.items()
}
},
prefix=prefix,
)

@property
def ids(self) -> List[str]:
prefix = self.prefix + "/" if self.prefix else ""
return [
f"{CONTEXT_CONFIG_PREFIX}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
for key in self.keys
]

Expand Down Expand Up @@ -258,9 +276,13 @@ async def ainvoke(


class Context:
@staticmethod
def create_scope(pefix: str, /) -> "PrefixContext":

Check failure on line 280 in libs/core/langchain_core/runnables/context.py

View workflow job for this annotation

GitHub Actions / Check for spelling errors

pefix ==> prefix
return PrefixContext(prefix=pefix)

Check failure on line 281 in libs/core/langchain_core/runnables/context.py

View workflow job for this annotation

GitHub Actions / Check for spelling errors

pefix ==> prefix

@staticmethod
def getter(key: Union[str, List[str]], /) -> ContextGet:
return ContextGet(key)
return ContextGet(key=key)

@staticmethod
def setter(
Expand All @@ -270,3 +292,22 @@ def setter(
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, **kwargs)


class PrefixContext:
prefix: str = ""

def __init__(self, prefix: str = ""):
self.prefix = prefix

def getter(self, key: Union[str, List[str]], /) -> ContextGet:
return ContextGet(key=key, prefix=self.prefix)

def setter(
self,
_key: Optional[str] = None,
_value: Optional[SetValue] = None,
/,
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, prefix=self.prefix, **kwargs)
57 changes: 57 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,34 @@ def seq_naive_rag_alt() -> Runnable:
)


def seq_naive_rag_scoped() -> Runnable:
context = [
"Hi there!",
"How are you?",
"What's your name?",
]

retriever = RunnableLambda(lambda x: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])

scoped = Context.create_scope("a_scope")

return (
Context.setter("input")
| {
"context": retriever | Context.setter("context"),
"question": RunnablePassthrough(),
"scoped": scoped.setter("context") | scoped.getter("context"),
}
| prompt
| llm
| StrOutputParser()
| Context.setter("result")
| Context.getter(["context", "input", "result"])
)


test_cases = [
(
Context.setter("foo") | Context.getter("foo"),
Expand Down Expand Up @@ -269,6 +297,35 @@ def seq_naive_rag_alt() -> Runnable:
),
),
),
(
seq_naive_rag_scoped,
(
TestCase(
"What up",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "What up",
},
),
TestCase(
"Howdy",
{
"result": "hello",
"context": [
"Hi there!",
"How are you?",
"What's your name?",
],
"input": "Howdy",
},
),
),
),
]


Expand Down

0 comments on commit f9750c4

Please sign in to comment.