From 6a1102d4c094b48afe43818a079f5758dc186a98 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Mon, 28 Aug 2023 06:58:33 -0400 Subject: [PATCH 01/10] mypy fixes and formatting --- .../langchain/chains/rl_chain/base.py | 148 ++++++++++-------- .../langchain/chains/rl_chain/metrics.py | 12 +- .../chains/rl_chain/model_repository.py | 6 +- .../chains/rl_chain/pick_best_chain.py | 22 +-- .../langchain/chains/rl_chain/vw_logger.py | 4 +- .../unit_tests/chains/rl_chain/test_utils.py | 2 +- 6 files changed, 107 insertions(+), 87 deletions(-) diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index 28baf898d2cda..d97dd255afe8c 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -3,7 +3,18 @@ import logging import os from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain @@ -26,47 +37,47 @@ class _BasedOn: - def __init__(self, value): + def __init__(self, value: Any): self.value = value - def __str__(self): + def __str__(self) -> str: return str(self.value) __repr__ = __str__ -def BasedOn(anything): +def BasedOn(anything: Any) -> _BasedOn: return _BasedOn(anything) class _ToSelectFrom: - def __init__(self, value): + def __init__(self, value: Any): self.value = value - def __str__(self): + def __str__(self) -> str: return str(self.value) __repr__ = __str__ -def ToSelectFrom(anything): +def ToSelectFrom(anything: Any) -> _ToSelectFrom: if not isinstance(anything, list): raise ValueError("ToSelectFrom must be a list to select from") return _ToSelectFrom(anything) class _Embed: - def __init__(self, value, keep=False): + def __init__(self, value: Any, keep: bool = False): self.value = value self.keep = keep - def __str__(self): + def __str__(self) -> str: return str(self.value) __repr__ = __str__ -def Embed(anything, keep=False): +def Embed(anything: Any, keep: bool = False) -> Any: if isinstance(anything, _ToSelectFrom): return ToSelectFrom(Embed(anything.value, keep=keep)) elif isinstance(anything, _BasedOn): @@ -80,7 +91,7 @@ def Embed(anything, keep=False): return _Embed(anything, keep=keep) -def EmbedAndKeep(anything): +def EmbedAndKeep(anything: Any) -> Any: return Embed(anything, keep=True) @@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp return [parser.parse_line(line) for line in input_str.split("\n")] -def get_based_on_and_to_select_from(inputs: Dict[str, Any]): +def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]: to_select_from = { k: inputs[k].value for k in inputs.keys() @@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]): return based_on, to_select_from -def prepare_inputs_for_autoembed(inputs: Dict[str, Any]): +def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]: """ go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed, then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status @@ -134,29 +145,35 @@ class Selected(ABC): pass -class Event(ABC): +TSelected = TypeVar("TSelected", bound=Selected) + + +class Event(Generic[TSelected], ABC): inputs: Dict[str, Any] - selected: Optional[Selected] + selected: Optional[TSelected] - def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None): + def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None): self.inputs = inputs self.selected = selected +TEvent = TypeVar("TEvent", bound=Event) + + class Policy(ABC): @abstractmethod - def predict(self, event: Event) -> Any: - pass + def predict(self, event: TEvent) -> Any: + ... @abstractmethod - def learn(self, event: Event): - pass + def learn(self, event: TEvent) -> None: + ... @abstractmethod - def log(self, event: Event): - pass + def log(self, event: TEvent) -> None: + ... - def save(self): + def save(self) -> None: pass @@ -164,11 +181,11 @@ class VwPolicy(Policy): def __init__( self, model_repo: ModelRepository, - vw_cmd: Sequence[str], + vw_cmd: List[str], feature_embedder: Embedder, vw_logger: VwLogger, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) self.model_repo = model_repo @@ -176,7 +193,7 @@ def __init__( self.feature_embedder = feature_embedder self.vw_logger = vw_logger - def predict(self, event: Event) -> Any: + def predict(self, event: TEvent) -> Any: import vowpal_wabbit_next as vw text_parser = vw.TextFormatParser(self.workspace) @@ -184,7 +201,7 @@ def predict(self, event: Event) -> Any: parse_lines(text_parser, self.feature_embedder.format(event)) ) - def learn(self, event: Event): + def learn(self, event: TEvent) -> None: import vowpal_wabbit_next as vw vw_ex = self.feature_embedder.format(event) @@ -192,19 +209,19 @@ def learn(self, event: Event): multi_ex = parse_lines(text_parser, vw_ex) self.workspace.learn_one(multi_ex) - def log(self, event: Event): + def log(self, event: TEvent) -> None: if self.vw_logger.logging_enabled(): vw_ex = self.feature_embedder.format(event) self.vw_logger.log(vw_ex) - def save(self): - self.model_repo.save() + def save(self) -> None: + self.model_repo.save(self.workspace) -class Embedder(ABC): +class Embedder(Generic[TEvent], ABC): @abstractmethod - def format(self, event: Event) -> str: - pass + def format(self, event: TEvent) -> str: + ... class SelectionScorer(ABC, BaseModel): @@ -212,7 +229,7 @@ class SelectionScorer(ABC, BaseModel): @abstractmethod def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: - pass + ... class AutoSelectionScorer(SelectionScorer, BaseModel): @@ -243,7 +260,7 @@ def get_default_prompt() -> ChatPromptTemplate: return chat_prompt @root_validator(pre=True) - def set_prompt_and_llm_chain(cls, values): + def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: llm = values.get("llm") prompt = values.get("prompt") scoring_criteria_template_str = values.get("scoring_criteria_template_str") @@ -275,7 +292,7 @@ def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: ) -class RLChain(Chain): +class RLChain(Generic[TEvent], Chain): """ The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning. @@ -305,7 +322,7 @@ class RLChain(Chain): output_key: str = "result" #: :meta private: prompt: BasePromptTemplate selection_scorer: Union[SelectionScorer, None] - policy: Optional[Policy] + policy: Policy auto_embed: bool = True selected_input_key = "rl_chain_selected" selected_based_on_input_key = "rl_chain_selected_based_on" @@ -314,14 +331,14 @@ class RLChain(Chain): def __init__( self, feature_embedder: Embedder, - model_save_dir="./", - reset_model=False, - vw_cmd=None, - policy=VwPolicy, + model_save_dir: str = "./", + reset_model: bool = False, + vw_cmd: Optional[List[str]] = None, + policy: Type[Policy] = VwPolicy, vw_logs: Optional[Union[str, os.PathLike]] = None, - metrics_step=-1, - *args, - **kwargs, + metrics_step: int = -1, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) if self.selection_scorer is None: @@ -374,29 +391,29 @@ def _validate_inputs(self, inputs: Dict[str, Any]) -> None: ) @abstractmethod - def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: - pass + def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent: + ... @abstractmethod def _call_after_predict_before_llm( - self, inputs: Dict[str, Any], event: Event, prediction: Any - ) -> Tuple[Dict[str, Any], Event]: - pass + self, inputs: Dict[str, Any], event: TEvent, prediction: Any + ) -> Tuple[Dict[str, Any], TEvent]: + ... @abstractmethod def _call_after_llm_before_scoring( - self, llm_response: str, event: Event - ) -> Tuple[Dict[str, Any], Event]: - pass + self, llm_response: str, event: TEvent + ) -> Tuple[Dict[str, Any], TEvent]: + ... @abstractmethod def _call_after_scoring_before_learning( - self, event: Event, score: Optional[float] - ) -> Event: - pass + self, event: TEvent, score: Optional[float] + ) -> TEvent: + ... def update_with_delayed_score( - self, score: float, event: Event, force_score=False + self, score: float, event: TEvent, force_score: bool = False ) -> None: """ Updates the learned policy with the score provided. @@ -407,7 +424,8 @@ def update_with_delayed_score( "The selection scorer is set, and force_score was not set to True. \ Please set force_score=True to use this function." ) - self.metrics.on_feedback(score) + if self.metrics: + self.metrics.on_feedback(score) self._call_after_scoring_before_learning(event=event, score=score) self.policy.learn(event=event) self.policy.log(event=event) @@ -422,15 +440,16 @@ def _call( self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: + ) -> Dict[str, Any]: _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() if self.auto_embed: inputs = prepare_inputs_for_autoembed(inputs=inputs) - event = self._call_before_predict(inputs=inputs) + event: TEvent = self._call_before_predict(inputs=inputs) prediction = self.policy.predict(event=event) - self.metrics.on_decision() + if self.metrics: + self.metrics.on_decision() next_chain_inputs, event = self._call_after_predict_before_llm( inputs=inputs, event=event, prediction=prediction @@ -462,7 +481,8 @@ def _call( f"The selection scorer was not able to score, \ and the chain was not able to adjust to this response, error: {e}" ) - self.metrics.on_feedback(score) + if self.metrics: + self.metrics.on_feedback(score) event = self._call_after_scoring_before_learning(score=score, event=event) self.policy.learn(event=event) self.policy.log(event=event) @@ -515,7 +535,7 @@ def embed_string_type( def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]: """Helper function to embed a dictionary item.""" - inner_dict = {} + inner_dict: Dict[str, Union[str, List[str]]] = {} for ns, embed_item in item.items(): if isinstance(embed_item, list): inner_dict[ns] = [] @@ -530,7 +550,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]: def embed_list_type( item: list, model: Any, namespace: Optional[str] = None ) -> List[Dict[str, Union[str, List[str]]]]: - ret_list = [] + ret_list: List[Dict[str, Union[str, List[str]]]] = [] for embed_item in item: if isinstance(embed_item, dict): ret_list.append(embed_dict_type(embed_item, model)) diff --git a/libs/langchain/langchain/chains/rl_chain/metrics.py b/libs/langchain/langchain/chains/rl_chain/metrics.py index b7ec949c9eaa6..4d6306f776013 100644 --- a/libs/langchain/langchain/chains/rl_chain/metrics.py +++ b/libs/langchain/langchain/chains/rl_chain/metrics.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union if TYPE_CHECKING: import pandas as pd @@ -6,11 +6,11 @@ class MetricsTracker: def __init__(self, step: int): - self._history = [] - self._step = step - self._i = 0 - self._num = 0 - self._denom = 0 + self._history: List[Dict[str, Union[int, float]]] = [] + self._step: int = step + self._i: int = 0 + self._num: float = 0 + self._denom: float = 0 @property def score(self) -> float: diff --git a/libs/langchain/langchain/chains/rl_chain/model_repository.py b/libs/langchain/langchain/chains/rl_chain/model_repository.py index eea866d1cf3c4..87f162df0ab77 100644 --- a/libs/langchain/langchain/chains/rl_chain/model_repository.py +++ b/libs/langchain/langchain/chains/rl_chain/model_repository.py @@ -4,7 +4,7 @@ import os import shutil from pathlib import Path -from typing import TYPE_CHECKING, Sequence, Union +from typing import TYPE_CHECKING, List, Union if TYPE_CHECKING: import vowpal_wabbit_next as vw @@ -22,7 +22,7 @@ def __init__( self.folder = Path(folder) self.model_path = self.folder / "latest.vw" self.with_history = with_history - if reset and self.has_history: + if reset and self.has_history(): logger.warning( "There is non empty history which is recommended to be cleaned up" ) @@ -44,7 +44,7 @@ def save(self, workspace: "vw.Workspace") -> None: if self.with_history: # write history shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw") - def load(self, commandline: Sequence[str]) -> "vw.Workspace": + def load(self, commandline: List[str]) -> "vw.Workspace": import vowpal_wabbit_next as vw model_data = None diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index 6e1a1a5eff70b..e60e685a0b5f7 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -17,7 +17,7 @@ SENTINEL = object() -class PickBestFeatureEmbedder(base.Embedder): +class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]): """ Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy @@ -25,7 +25,7 @@ class PickBestFeatureEmbedder(base.Embedder): model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer. """ # noqa E501 - def __init__(self, model: Optional[Any] = None, *args, **kwargs): + def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) if model is None: @@ -88,7 +88,7 @@ def format(self, event: PickBest.Event) -> str: return example_string[:-1] -class PickBest(base.RLChain): +class PickBest(base.RLChain[PickBest.Event]): """ `PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call. @@ -131,7 +131,7 @@ def __init__( self.probability = probability self.score = score - class Event(base.Event): + class Event(base.Event[PickBest.Selected]): def __init__( self, inputs: Dict[str, Any], @@ -146,8 +146,8 @@ def __init__( def __init__( self, feature_embedder: Optional[PickBestFeatureEmbedder] = None, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): vw_cmd = kwargs.get("vw_cmd", []) if not vw_cmd: @@ -170,7 +170,7 @@ def __init__( super().__init__(feature_embedder=feature_embedder, *args, **kwargs) - def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event: + def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: context, actions = base.get_based_on_and_to_select_from(inputs=inputs) if not actions: raise ValueError( @@ -198,7 +198,7 @@ def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event: def _call_after_predict_before_llm( self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]] - ) -> Tuple[Dict[str, Any], PickBest.Event]: + ) -> Tuple[Dict[str, Any], Event]: import numpy as np prob_sum = sum(prob for _, prob in prediction) @@ -218,8 +218,8 @@ def _call_after_predict_before_llm( return next_chain_inputs, event def _call_after_llm_before_scoring( - self, llm_response: str, event: PickBest.Event - ) -> Tuple[Dict[str, Any], PickBest.Event]: + self, llm_response: str, event: Event + ) -> Tuple[Dict[str, Any], Event]: next_chain_inputs = event.inputs.copy() # only one key, value pair in event.to_select_from value = next(iter(event.to_select_from.values())) @@ -232,7 +232,7 @@ def _call_after_llm_before_scoring( return next_chain_inputs, event def _call_after_scoring_before_learning( - self, event: PickBest.Event, score: Optional[float] + self, event: Event, score: Optional[float] ) -> Event: event.selected.score = score return event diff --git a/libs/langchain/langchain/chains/rl_chain/vw_logger.py b/libs/langchain/langchain/chains/rl_chain/vw_logger.py index 4fa471753957c..e8d2e1541f1c7 100644 --- a/libs/langchain/langchain/chains/rl_chain/vw_logger.py +++ b/libs/langchain/langchain/chains/rl_chain/vw_logger.py @@ -9,10 +9,10 @@ def __init__(self, path: Optional[Union[str, PathLike]]): if self.path: self.path.parent.mkdir(parents=True, exist_ok=True) - def log(self, vw_ex: str): + def log(self, vw_ex: str) -> None: if self.path: with open(self.path, "a") as f: f.write(f"{vw_ex}\n\n") - def logging_enabled(self): + def logging_enabled(self) -> bool: return bool(self.path) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py index 6d54d20d9219f..625c37ee00029 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py @@ -1,3 +1,3 @@ class MockEncoder: - def encode(self, to_encode): + def encode(self, to_encode: str) -> str: return "[encoded]" + to_encode From dd6fff1c6209f6b05cfac0f343654947d6bd9e2e Mon Sep 17 00:00:00 2001 From: olgavrou Date: Mon, 28 Aug 2023 08:13:23 -0400 Subject: [PATCH 02/10] no errors in pick best chain --- .../langchain/chains/rl_chain/__init__.py | 2 +- .../chains/rl_chain/pick_best_chain.py | 47 ++++++++----------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/libs/langchain/langchain/chains/rl_chain/__init__.py b/libs/langchain/langchain/chains/rl_chain/__init__.py index e71de1da6ccf8..6d5cfc3e29c78 100644 --- a/libs/langchain/langchain/chains/rl_chain/__init__.py +++ b/libs/langchain/langchain/chains/rl_chain/__init__.py @@ -13,7 +13,7 @@ from langchain.chains.rl_chain.pick_best_chain import PickBest -def configure_logger(): +def configure_logger() -> None: logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) ch = logging.StreamHandler() diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index e60e685a0b5f7..ca920522680fc 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type, Union import langchain.chains.rl_chain.base as base from langchain.base_language import BaseLanguageModel @@ -145,7 +145,6 @@ def __init__( def __init__( self, - feature_embedder: Optional[PickBestFeatureEmbedder] = None, *args: Any, **kwargs: Any, ): @@ -163,12 +162,14 @@ def __init__( raise ValueError( "If vw_cmd is specified, it must include --cb_explore_adf" ) - kwargs["vw_cmd"] = vw_cmd + + feature_embedder = kwargs.get("feature_embedder", None) if not feature_embedder: feature_embedder = PickBestFeatureEmbedder() + kwargs["feature_embedder"] = feature_embedder - super().__init__(feature_embedder=feature_embedder, *args, **kwargs) + super().__init__(*args, **kwargs) def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: context, actions = base.get_based_on_and_to_select_from(inputs=inputs) @@ -223,10 +224,15 @@ def _call_after_llm_before_scoring( next_chain_inputs = event.inputs.copy() # only one key, value pair in event.to_select_from value = next(iter(event.to_select_from.values())) + v = ( + value[event.selected.index] + if event.selected + else event.to_select_from.values() + ) next_chain_inputs.update( { self.selected_based_on_input_key: str(event.based_on), - self.selected_input_key: value[event.selected.index], + self.selected_input_key: v, } ) return next_chain_inputs, event @@ -234,7 +240,8 @@ def _call_after_llm_before_scoring( def _call_after_scoring_before_learning( self, event: Event, score: Optional[float] ) -> Event: - event.selected.score = score + if event.selected: + event.selected.score = score return event def _call( @@ -248,33 +255,19 @@ def _call( def _chain_type(self) -> str: return "rl_chain_pick_best" - @classmethod - def from_chain( - cls, - llm_chain: Chain, - prompt: BasePromptTemplate, - selection_scorer=SENTINEL, - **kwargs: Any, - ): - if selection_scorer is SENTINEL: - selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm) - return PickBest( - llm_chain=llm_chain, - prompt=prompt, - selection_scorer=selection_scorer, - **kwargs, - ) - @classmethod def from_llm( - cls, + cls: Type[PickBest], llm: BaseLanguageModel, prompt: BasePromptTemplate, - selection_scorer=SENTINEL, + selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL, **kwargs: Any, - ): + ) -> PickBest: llm_chain = LLMChain(llm=llm, prompt=prompt) - return PickBest.from_chain( + if selection_scorer is SENTINEL: + selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm) + + return PickBest( llm_chain=llm_chain, prompt=prompt, selection_scorer=selection_scorer, From a11ad11d063e8f5553bd25b8bd74e629e1e31dd6 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 29 Aug 2023 03:59:01 -0400 Subject: [PATCH 03/10] fix all mypy errors --- .../langchain/chains/rl_chain/base.py | 50 ++++++++++--------- .../chains/rl_chain/pick_best_chain.py | 9 +++- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index d97dd255afe8c..22ff60a403e8f 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -161,6 +161,9 @@ def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None) class Policy(ABC): + def __init__(self, **kwargs: Any): + pass + @abstractmethod def predict(self, event: TEvent) -> Any: ... @@ -233,7 +236,7 @@ def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: class AutoSelectionScorer(SelectionScorer, BaseModel): - llm_chain: Union[LLMChain, None] = None + llm_chain: LLMChain prompt: Union[BasePromptTemplate, None] = None scoring_criteria_template_str: Optional[str] = None @@ -309,7 +312,7 @@ class RLChain(Generic[TEvent], Chain): - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory. - reset_model (bool): If set to True, the model starts training from scratch. Default is False. - vw_cmd (List[str], optional): Command line arguments for the VW model. - - policy (VwPolicy): Policy used by the chain. + - policy (Type[VwPolicy]): Policy used by the chain. - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs. - metrics_step (int): Step for the metrics tracker. Default is -1. @@ -322,7 +325,7 @@ class RLChain(Generic[TEvent], Chain): output_key: str = "result" #: :meta private: prompt: BasePromptTemplate selection_scorer: Union[SelectionScorer, None] - policy: Policy + active_policy: Policy auto_embed: bool = True selected_input_key = "rl_chain_selected" selected_based_on_input_key = "rl_chain_selected_based_on" @@ -347,14 +350,17 @@ def __init__( reinforcement learning will be done in the RL chain \ unless update_with_delayed_score is called." ) - self.policy = policy( - model_repo=ModelRepository( - model_save_dir, with_history=True, reset=reset_model - ), - vw_cmd=vw_cmd or [], - feature_embedder=feature_embedder, - vw_logger=VwLogger(vw_logs), - ) + + if self.active_policy is None: + self.active_policy = policy( + model_repo=ModelRepository( + model_save_dir, with_history=True, reset=reset_model + ), + vw_cmd=vw_cmd or [], + feature_embedder=feature_embedder, + vw_logger=VwLogger(vw_logs), + ) + self.metrics = MetricsTracker(step=metrics_step) class Config: @@ -427,8 +433,8 @@ def update_with_delayed_score( if self.metrics: self.metrics.on_feedback(score) self._call_after_scoring_before_learning(event=event, score=score) - self.policy.learn(event=event) - self.policy.log(event=event) + self.active_policy.learn(event=event) + self.active_policy.log(event=event) def set_auto_embed(self, auto_embed: bool) -> None: """ @@ -447,7 +453,7 @@ def _call( inputs = prepare_inputs_for_autoembed(inputs=inputs) event: TEvent = self._call_before_predict(inputs=inputs) - prediction = self.policy.predict(event=event) + prediction = self.active_policy.predict(event=event) if self.metrics: self.metrics.on_decision() @@ -484,8 +490,8 @@ def _call( if self.metrics: self.metrics.on_feedback(score) event = self._call_after_scoring_before_learning(score=score, event=event) - self.policy.learn(event=event) - self.policy.log(event=event) + self.active_policy.learn(event=event) + self.active_policy.log(event=event) return {self.output_key: {"response": output, "selection_metadata": event}} @@ -493,7 +499,7 @@ def save_progress(self) -> None: """ This function should be called to save the state of the learned policy model. """ - self.policy.save() + self.active_policy.save() @property def _chain_type(self) -> str: @@ -509,7 +515,7 @@ def is_stringtype_instance(item: Any) -> bool: def embed_string_type( item: Union[str, _Embed], model: Any, namespace: Optional[str] = None -) -> Dict[str, str]: +) -> Dict[str, Union[str, List[str]]]: """Helper function to embed a string or an _Embed object.""" join_char = "" keep_str = "" @@ -533,9 +539,9 @@ def embed_string_type( return {namespace: keep_str + join_char.join(map(str, encoded))} -def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]: +def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]: """Helper function to embed a dictionary item.""" - inner_dict: Dict[str, Union[str, List[str]]] = {} + inner_dict: Dict[str, Any] = {} for ns, embed_item in item.items(): if isinstance(embed_item, list): inner_dict[ns] = [] @@ -560,9 +566,7 @@ def embed_list_type( def embed( - to_embed: Union[ - Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict] - ], + to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]], model: Any, namespace: Optional[str] = None, ) -> List[Dict[str, Union[str, List[str]]]]: diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index ca920522680fc..691e0a99ce19b 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -54,9 +54,14 @@ def format(self, event: PickBest.Event) -> str: to_select_from_var_name, to_select_from = next( iter(event.to_select_from.items()), (None, None) ) + action_embs = ( - base.embed(to_select_from, self.model, to_select_from_var_name) - if event.to_select_from + ( + base.embed(to_select_from, self.model, to_select_from_var_name) + if event.to_select_from + else None + ) + if to_select_from else None ) From 0b8691c6e5de170bbe574a4cf4f9c1bc68c453d2 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 29 Aug 2023 05:19:19 -0400 Subject: [PATCH 04/10] fix all mypy errors and some renaming and refactoring --- .../langchain/chains/rl_chain/base.py | 18 +++- .../chains/rl_chain/pick_best_chain.py | 83 ++++++++++--------- .../rl_chain/test_pick_best_text_embedder.py | 54 ++++++------ 3 files changed, 86 insertions(+), 69 deletions(-) diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index 22ff60a403e8f..721b7d35de932 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -295,7 +295,7 @@ def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: ) -class RLChain(Generic[TEvent], Chain): +class RLChain(Chain, Generic[TEvent]): """ The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning. @@ -320,12 +320,24 @@ class RLChain(Generic[TEvent], Chain): The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called. """ # noqa: E501 + class _NoOpPolicy(Policy): + """Placeholder policy that does nothing""" + + def predict(self, event: TEvent) -> Any: + return None + + def learn(self, event: TEvent) -> None: + pass + + def log(self, event: TEvent) -> None: + pass + llm_chain: Chain output_key: str = "result" #: :meta private: prompt: BasePromptTemplate selection_scorer: Union[SelectionScorer, None] - active_policy: Policy + active_policy: Policy = _NoOpPolicy() auto_embed: bool = True selected_input_key = "rl_chain_selected" selected_based_on_input_key = "rl_chain_selected_based_on" @@ -351,7 +363,7 @@ def __init__( unless update_with_delayed_score is called." ) - if self.active_policy is None: + if isinstance(self.active_policy, RLChain._NoOpPolicy): self.active_policy = policy( model_repo=ModelRepository( model_save_dir, with_history=True, reset=reset_model diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index 691e0a99ce19b..16e8bf598c9b9 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -17,7 +17,36 @@ SENTINEL = object() -class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]): +class PickBestSelected(base.Selected): + index: Optional[int] + probability: Optional[float] + score: Optional[float] + + def __init__( + self, + index: Optional[int] = None, + probability: Optional[float] = None, + score: Optional[float] = None, + ): + self.index = index + self.probability = probability + self.score = score + + +class PickBestEvent(base.Event[PickBestSelected]): + def __init__( + self, + inputs: Dict[str, Any], + to_select_from: Dict[str, Any], + based_on: Dict[str, Any], + selected: Optional[PickBestSelected] = None, + ): + super().__init__(inputs=inputs, selected=selected) + self.to_select_from = to_select_from + self.based_on = based_on + + +class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): """ Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy @@ -35,7 +64,7 @@ def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any): self.model = model - def format(self, event: PickBest.Event) -> str: + def format(self, event: PickBestEvent) -> str: """ Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW """ @@ -93,7 +122,7 @@ def format(self, event: PickBest.Event) -> str: return example_string[:-1] -class PickBest(base.RLChain[PickBest.Event]): +class PickBest(base.RLChain[PickBestEvent]): """ `PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call. @@ -121,33 +150,6 @@ class PickBest(base.RLChain[PickBest.Event]): feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized. """ # noqa E501 - class Selected(base.Selected): - index: Optional[int] - probability: Optional[float] - score: Optional[float] - - def __init__( - self, - index: Optional[int] = None, - probability: Optional[float] = None, - score: Optional[float] = None, - ): - self.index = index - self.probability = probability - self.score = score - - class Event(base.Event[PickBest.Selected]): - def __init__( - self, - inputs: Dict[str, Any], - to_select_from: Dict[str, Any], - based_on: Dict[str, Any], - selected: Optional[PickBest.Selected] = None, - ): - super().__init__(inputs=inputs, selected=selected) - self.to_select_from = to_select_from - self.based_on = based_on - def __init__( self, *args: Any, @@ -176,7 +178,7 @@ def __init__( super().__init__(*args, **kwargs) - def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: + def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: context, actions = base.get_based_on_and_to_select_from(inputs=inputs) if not actions: raise ValueError( @@ -199,12 +201,15 @@ def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: to base the selected of ToSelectFrom on." ) - event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context) + event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context) return event def _call_after_predict_before_llm( - self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]] - ) -> Tuple[Dict[str, Any], Event]: + self, + inputs: Dict[str, Any], + event: PickBestEvent, + prediction: List[Tuple[int, float]], + ) -> Tuple[Dict[str, Any], PickBestEvent]: import numpy as np prob_sum = sum(prob for _, prob in prediction) @@ -214,7 +219,7 @@ def _call_after_predict_before_llm( sampled_ap = prediction[sampled_index] sampled_action = sampled_ap[0] sampled_prob = sampled_ap[1] - selected = PickBest.Selected(index=sampled_action, probability=sampled_prob) + selected = PickBestSelected(index=sampled_action, probability=sampled_prob) event.selected = selected # only one key, value pair in event.to_select_from @@ -224,8 +229,8 @@ def _call_after_predict_before_llm( return next_chain_inputs, event def _call_after_llm_before_scoring( - self, llm_response: str, event: Event - ) -> Tuple[Dict[str, Any], Event]: + self, llm_response: str, event: PickBestEvent + ) -> Tuple[Dict[str, Any], PickBestEvent]: next_chain_inputs = event.inputs.copy() # only one key, value pair in event.to_select_from value = next(iter(event.to_select_from.values())) @@ -243,8 +248,8 @@ def _call_after_llm_before_scoring( return next_chain_inputs, event def _call_after_scoring_before_learning( - self, event: Event, score: Optional[float] - ) -> Event: + self, event: PickBestEvent, score: Optional[float] + ) -> PickBestEvent: if event.selected: event.selected.score = score return event diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py index d8ea85c6ebcc2..c299b1872032d 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py @@ -11,7 +11,7 @@ def test_pickbest_textembedder_missing_context_throws(): feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_action = {"action": ["0", "1", "2"]} - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_action, based_on={} ) with pytest.raises(ValueError): @@ -21,7 +21,7 @@ def test_pickbest_textembedder_missing_context_throws(): @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_missing_actions_throws(): feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from={}, based_on={"context": "context"} ) with pytest.raises(ValueError): @@ -33,7 +33,7 @@ def test_pickbest_textembedder_no_label_no_emb(): feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on={"context": "context"} ) vw_ex_str = feature_embedder.format(event) @@ -45,8 +45,8 @@ def test_pickbest_textembedder_w_label_no_score_no_emb(): feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on={"context": "context"}, @@ -63,8 +63,8 @@ def test_pickbest_textembedder_w_full_label_no_emb(): expected = ( """shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """ ) - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on={"context": "context"}, @@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb(): named_actions = {"action1": rl_chain.Embed([str1, str2, str3])} context = {"context": rl_chain.Embed(ctx_str_1)} expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep(): named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])} context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)} expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -128,7 +128,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb(): named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context ) vw_ex_str = feature_embedder.format(event) @@ -141,8 +141,8 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb(): named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -155,8 +155,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb(): named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb(): } expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -219,8 +219,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee } expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -253,8 +253,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb(): context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)} expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -290,8 +290,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_ } expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -315,7 +315,7 @@ def test_raw_features_underscored(): expected_no_embed = ( f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """ ) - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context ) vw_ex_str = feature_embedder.format(event) @@ -325,7 +325,7 @@ def test_raw_features_underscored(): named_actions = {"action": rl_chain.Embed([str1])} context = {"context": rl_chain.Embed(ctx_str)} expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """ - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context ) vw_ex_str = feature_embedder.format(event) @@ -335,7 +335,7 @@ def test_raw_features_underscored(): named_actions = {"action": rl_chain.EmbedAndKeep([str1])} context = {"context": rl_chain.EmbedAndKeep(ctx_str)} expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501 - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context ) vw_ex_str = feature_embedder.format(event) From b3c0728de2893b924f34b75f8107430024154669 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 29 Aug 2023 05:28:43 -0400 Subject: [PATCH 05/10] fix mypy errors in tests --- .../rl_chain/test_pick_best_chain_call.py | 34 ++++++++------- .../rl_chain/test_pick_best_text_embedder.py | 34 ++++++++------- .../rl_chain/test_rl_chain_base_embedder.py | 42 +++++++++---------- 3 files changed, 58 insertions(+), 52 deletions(-) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py index 3fad1667d91c2..7bca6b470d88a 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import pytest from test_utils import MockEncoder @@ -10,7 +12,7 @@ @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def setup(): +def setup() -> tuple: _PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm""" PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE) @@ -19,7 +21,7 @@ def setup(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_multiple_ToSelectFrom_throws(): +def test_multiple_ToSelectFrom_throws() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) actions = ["0", "1", "2"] @@ -32,7 +34,7 @@ def test_multiple_ToSelectFrom_throws(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_missing_basedOn_from_throws(): +def test_missing_basedOn_from_throws() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) actions = ["0", "1", "2"] @@ -41,7 +43,7 @@ def test_missing_basedOn_from_throws(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_ToSelectFrom_not_a_list_throws(): +def test_ToSelectFrom_not_a_list_throws() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) actions = {"actions": ["0", "1", "2"]} @@ -53,7 +55,7 @@ def test_ToSelectFrom_not_a_list_throws(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_update_with_delayed_score_with_auto_validator_throws(): +def test_update_with_delayed_score_with_auto_validator_throws() -> None: llm, PROMPT = setup() # this LLM returns a number so that the auto validator will return that auto_val_llm = FakeListChatModel(responses=["3"]) @@ -75,7 +77,7 @@ def test_update_with_delayed_score_with_auto_validator_throws(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_update_with_delayed_score_force(): +def test_update_with_delayed_score_force() -> None: llm, PROMPT = setup() # this LLM returns a number so that the auto validator will return that auto_val_llm = FakeListChatModel(responses=["3"]) @@ -99,7 +101,7 @@ def test_update_with_delayed_score_force(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_update_with_delayed_score(): +def test_update_with_delayed_score() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, selection_scorer=None @@ -117,11 +119,11 @@ def test_update_with_delayed_score(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_user_defined_scorer(): +def test_user_defined_scorer() -> None: llm, PROMPT = setup() class CustomSelectionScorer(rl_chain.SelectionScorer): - def score_response(self, inputs, llm_response: str) -> float: + def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: score = 200 return score @@ -139,7 +141,7 @@ def score_response(self, inputs, llm_response: str) -> float: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_default_embeddings(): +def test_default_embeddings() -> None: llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( @@ -173,7 +175,7 @@ def test_default_embeddings(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_default_embeddings_off(): +def test_default_embeddings_off() -> None: llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( @@ -199,7 +201,7 @@ def test_default_embeddings_off(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_default_embeddings_mixed_w_explicit_user_embeddings(): +def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None: llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( @@ -234,7 +236,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_default_no_scorer_specified(): +def test_default_no_scorer_specified() -> None: _, PROMPT = setup() chain_llm = FakeListChatModel(responses=[100]) chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT) @@ -249,7 +251,7 @@ def test_default_no_scorer_specified(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_explicitly_no_scorer(): +def test_explicitly_no_scorer() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, selection_scorer=None @@ -265,7 +267,7 @@ def test_explicitly_no_scorer(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_auto_scorer_with_user_defined_llm(): +def test_auto_scorer_with_user_defined_llm() -> None: llm, PROMPT = setup() scorer_llm = FakeListChatModel(responses=[300]) chain = pick_best_chain.PickBest.from_llm( @@ -284,7 +286,7 @@ def test_auto_scorer_with_user_defined_llm(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_calling_chain_w_reserved_inputs_throws(): +def test_calling_chain_w_reserved_inputs_throws() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) with pytest.raises(ValueError): diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py index c299b1872032d..acc7491c4024b 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py @@ -8,7 +8,7 @@ @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_missing_context_throws(): +def test_pickbest_textembedder_missing_context_throws() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_action = {"action": ["0", "1", "2"]} event = pick_best_chain.PickBestEvent( @@ -19,7 +19,7 @@ def test_pickbest_textembedder_missing_context_throws(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_missing_actions_throws(): +def test_pickbest_textembedder_missing_actions_throws() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) event = pick_best_chain.PickBestEvent( inputs={}, to_select_from={}, based_on={"context": "context"} @@ -29,7 +29,7 @@ def test_pickbest_textembedder_missing_actions_throws(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_no_label_no_emb(): +def test_pickbest_textembedder_no_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ @@ -41,7 +41,7 @@ def test_pickbest_textembedder_no_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_w_label_no_score_no_emb(): +def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ @@ -57,7 +57,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_w_full_label_no_emb(): +def test_pickbest_textembedder_w_full_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = ( @@ -75,7 +75,7 @@ def test_pickbest_textembedder_w_full_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_w_full_label_w_emb(): +def test_pickbest_textembedder_w_full_label_w_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" str2 = "1" @@ -99,7 +99,7 @@ def test_pickbest_textembedder_w_full_label_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_w_full_label_w_embed_and_keep(): +def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" str2 = "1" @@ -123,7 +123,7 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_no_label_no_emb(): +def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} @@ -136,7 +136,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_label_no_emb(): +def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} @@ -150,7 +150,7 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb(): +def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} @@ -164,7 +164,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb(): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" @@ -195,7 +195,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep(): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> ( + None +): feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" @@ -228,7 +230,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb(): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" @@ -262,7 +264,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep(): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep() -> ( + None +): feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" @@ -299,7 +303,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_ @pytest.mark.requires("vowpal_wabbit_next") -def test_raw_features_underscored(): +def test_raw_features_underscored() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "this is a long string" str1_underscored = str1.replace(" ", "_") diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py index 895fa8ebb6001..c9f8416ceb7ec 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py @@ -7,13 +7,13 @@ @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_context_str_no_emb(): +def test_simple_context_str_no_emb() -> None: expected = [{"a_namespace": "test"}] assert base.embed("test", MockEncoder(), "a_namespace") == expected @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_context_str_w_emb(): +def test_simple_context_str_w_emb() -> None: str1 = "test" encoded_str1 = " ".join(char for char in str1) expected = [{"a_namespace": encoded_text + encoded_str1}] @@ -28,7 +28,7 @@ def test_simple_context_str_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_context_str_w_nested_emb(): +def test_simple_context_str_w_nested_emb() -> None: # nested embeddings, innermost wins str1 = "test" encoded_str1 = " ".join(char for char in str1) @@ -46,13 +46,13 @@ def test_simple_context_str_w_nested_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_context_w_namespace_no_emb(): +def test_context_w_namespace_no_emb() -> None: expected = [{"test_namespace": "test"}] assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected @pytest.mark.requires("vowpal_wabbit_next") -def test_context_w_namespace_w_emb(): +def test_context_w_namespace_w_emb() -> None: str1 = "test" encoded_str1 = " ".join(char for char in str1) expected = [{"test_namespace": encoded_text + encoded_str1}] @@ -67,7 +67,7 @@ def test_context_w_namespace_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_context_w_namespace_w_emb2(): +def test_context_w_namespace_w_emb2() -> None: str1 = "test" encoded_str1 = " ".join(char for char in str1) expected = [{"test_namespace": encoded_text + encoded_str1}] @@ -82,7 +82,7 @@ def test_context_w_namespace_w_emb2(): @pytest.mark.requires("vowpal_wabbit_next") -def test_context_w_namespace_w_some_emb(): +def test_context_w_namespace_w_some_emb() -> None: str1 = "test1" str2 = "test2" encoded_str2 = " ".join(char for char in str2) @@ -111,7 +111,7 @@ def test_context_w_namespace_w_some_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_action_strlist_no_emb(): +def test_simple_action_strlist_no_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -120,7 +120,7 @@ def test_simple_action_strlist_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_action_strlist_w_emb(): +def test_simple_action_strlist_w_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -148,7 +148,7 @@ def test_simple_action_strlist_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_action_strlist_w_some_emb(): +def test_simple_action_strlist_w_some_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -181,7 +181,7 @@ def test_simple_action_strlist_w_some_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_no_emb(): +def test_action_w_namespace_no_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -204,7 +204,7 @@ def test_action_w_namespace_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_w_emb(): +def test_action_w_namespace_w_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -246,7 +246,7 @@ def test_action_w_namespace_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_w_emb2(): +def test_action_w_namespace_w_emb2() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -292,7 +292,7 @@ def test_action_w_namespace_w_emb2(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_w_some_emb(): +def test_action_w_namespace_w_some_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -333,7 +333,7 @@ def test_action_w_namespace_w_some_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict(): +def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -384,7 +384,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict(): @pytest.mark.requires("vowpal_wabbit_next") -def test_one_namespace_w_list_of_features_no_emb(): +def test_one_namespace_w_list_of_features_no_emb() -> None: str1 = "test1" str2 = "test2" expected = [{"test_namespace": [str1, str2]}] @@ -392,7 +392,7 @@ def test_one_namespace_w_list_of_features_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_one_namespace_w_list_of_features_w_some_emb(): +def test_one_namespace_w_list_of_features_w_some_emb() -> None: str1 = "test1" str2 = "test2" encoded_str2 = " ".join(char for char in str2) @@ -404,24 +404,24 @@ def test_one_namespace_w_list_of_features_w_some_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_nested_list_features_throws(): +def test_nested_list_features_throws() -> None: with pytest.raises(ValueError): base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder()) @pytest.mark.requires("vowpal_wabbit_next") -def test_dict_in_list_throws(): +def test_dict_in_list_throws() -> None: with pytest.raises(ValueError): base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder()) @pytest.mark.requires("vowpal_wabbit_next") -def test_nested_dict_throws(): +def test_nested_dict_throws() -> None: with pytest.raises(ValueError): base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder()) @pytest.mark.requires("vowpal_wabbit_next") -def test_list_of_tuples_throws(): +def test_list_of_tuples_throws() -> None: with pytest.raises(ValueError): base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder()) From 8d10a52525e8e85a2b6930be37e44de7f2b9bcbd Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 29 Aug 2023 05:36:45 -0400 Subject: [PATCH 06/10] fix linting complaints --- libs/langchain/langchain/chains/rl_chain/pick_best_chain.py | 1 - .../chains/rl_chain/test_pick_best_text_embedder.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index 16e8bf598c9b9..fa7f18f8fb25d 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -6,7 +6,6 @@ import langchain.chains.rl_chain.base as base from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.prompts import BasePromptTemplate diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py index acc7491c4024b..c49bacac6085c 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py @@ -264,9 +264,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep() -> ( - None -): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" From 44485c2b26a052939ca910c5a80a0e10dc18adaa Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 29 Aug 2023 05:42:45 -0400 Subject: [PATCH 07/10] make input arg type more explicit --- .../unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py index c9f8416ceb7ec..7402c64d381b0 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py @@ -116,7 +116,8 @@ def test_simple_action_strlist_no_emb() -> None: str2 = "test2" str3 = "test3" expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}] - assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected + to_embed: List[str] = [str1, str2, str3] + assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected @pytest.mark.requires("vowpal_wabbit_next") From 758225dc17c90548d205edfafe79d833dedbe825 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 29 Aug 2023 05:44:09 -0400 Subject: [PATCH 08/10] include type --- .../unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py index 7402c64d381b0..4ccc75868d28f 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py @@ -1,5 +1,6 @@ import pytest from test_utils import MockEncoder +from typing import List import langchain.chains.rl_chain.base as base From d50c0f139de710a685c276ec373bb8c4433bfc81 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 29 Aug 2023 05:46:56 -0400 Subject: [PATCH 09/10] re order imports --- .../unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py index 4ccc75868d28f..d0abc97e758cf 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py @@ -1,6 +1,7 @@ +from typing import List + import pytest from test_utils import MockEncoder -from typing import List import langchain.chains.rl_chain.base as base From 4e6e03ef50bf76887776c6a4a2fa2057047b7976 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 29 Aug 2023 05:51:52 -0400 Subject: [PATCH 10/10] fix mypy complaint --- .../unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py index d0abc97e758cf..bd0cc584ef117 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Union import pytest from test_utils import MockEncoder @@ -118,7 +118,7 @@ def test_simple_action_strlist_no_emb() -> None: str2 = "test2" str3 = "test3" expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}] - to_embed: List[str] = [str1, str2, str3] + to_embed: List[Union[str, base._Embed]] = [str1, str2, str3] assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected