cli: standard tests in cli, test that they run, skip vectorstore tests (
efriis authored Dec 5, 2024
1 parent c5acedd commit 43c35d1
Showing 36 changed files with 1,578 additions and 636 deletions.
6 changes: 4 additions & 2 deletions .github/scripts/check_diff.py
@@ -272,15 +272,17 @@ def _get_configs_for_multi_dirs(
# TODO: update to include all packages that rely on standard-tests (all partner packages)
# note: won't run on external repo partners
dirs_to_run["lint"].add("libs/standard-tests")
dirs_to_run["test"].add("libs/standard-tests")
dirs_to_run["test"].add("libs/partners/mistralai")
dirs_to_run["test"].add("libs/partners/openai")
dirs_to_run["test"].add("libs/partners/anthropic")
dirs_to_run["test"].add("libs/partners/fireworks")
dirs_to_run["test"].add("libs/partners/groq")

elif file.startswith("libs/cli"):
# todo: add cli makefile
pass
dirs_to_run["lint"].add("libs/cli")
dirs_to_run["test"].add("libs/cli")

elif file.startswith("libs/partners"):
partner_dir = file.split("/")[2]
if os.path.isdir(f"libs/partners/{partner_dir}") and [
2 changes: 2 additions & 0 deletions libs/cli/.gitignore
@@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.integration_test
51 changes: 45 additions & 6 deletions libs/cli/Makefile
@@ -1,8 +1,47 @@
lint lint_diff:
poetry run poe lint

test:
poetry run poe test
######################
# LINTING AND FORMATTING
######################

format:
poetry run poe format
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/cli --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_cli
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)

test tests: _test _e2e_test

PYTHON = .venv/bin/python

_test:
poetry run pytest tests

# custom integration testing for the cli integration flow
# currently skips the vectorstore tests because the template lacks an implementation
_e2e_test:
rm -rf .integration_test
mkdir .integration_test
cd .integration_test && \
python3 -m venv .venv && \
$(PYTHON) -m pip install --upgrade poetry && \
$(PYTHON) -m pip install -e .. && \
$(PYTHON) -m langchain_cli.cli integration new --name parrot-link --name-class ParrotLink && \
cd langchain-parrot-link && \
poetry install --with lint,typing,test && \
poetry run pip install -e ../../../standard-tests && \
make format lint tests && \
poetry install --with test_integration && \
rm tests/integration_tests/test_vectorstores.py && \
make integration_test
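
For reference, the _e2e_test recipe above can be approximated in plain Python. This is a minimal sketch, not the actual harness: it assumes Python 3.9+, a POSIX venv layout, a system-wide Poetry (the Makefile installs Poetry into the fresh venv instead), and that it is run from libs/cli:

import shutil
import subprocess
from pathlib import Path

def run(cmd: list[str], cwd: Path) -> None:
    # Fail fast on the first error, like the && chain in the Makefile recipe.
    subprocess.run(cmd, cwd=cwd, check=True)

root = Path(".integration_test")
shutil.rmtree(root, ignore_errors=True)
root.mkdir()

run(["python3", "-m", "venv", ".venv"], cwd=root)
python = str((root / ".venv" / "bin" / "python").resolve())

# Install the CLI from the parent directory, then scaffold a new package.
run([python, "-m", "pip", "install", "-e", ".."], cwd=root)
run([python, "-m", "langchain_cli.cli", "integration", "new",
     "--name", "parrot-link", "--name-class", "ParrotLink"], cwd=root)

# Lint and test the generated package.
pkg = root / "langchain-parrot-link"
run(["poetry", "install", "--with", "lint,typing,test"], cwd=pkg)
run(["make", "format", "lint", "tests"], cwd=pkg)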
1 change: 1 addition & 0 deletions libs/cli/langchain_cli/dev_scripts.py
@@ -1,3 +1,4 @@
# type: ignore
"""
Development Scripts for template packages
"""
4 changes: 2 additions & 2 deletions libs/cli/langchain_cli/integration_template/Makefile
@@ -33,13 +33,13 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
poetry run codespell --toml pyproject.toml
libs/cli/langchain_cli/integration_template/integration_template/__init__.py
@@ -1,8 +1,11 @@
from importlib import metadata

from __module_name__.chat_models import Chat__ModuleName__
from __module_name__.document_loaders import __ModuleName__Loader
from __module_name__.embeddings import __ModuleName__Embeddings
from __module_name__.llms import __ModuleName__LLM
from __module_name__.retrievers import __ModuleName__Retriever
from __module_name__.toolkits import __ModuleName__Toolkit
from __module_name__.tools import __ModuleName__Tool
from __module_name__.vectorstores import __ModuleName__VectorStore

try:
@@ -14,8 +17,11 @@

__all__ = [
"Chat__ModuleName__",
"__ModuleName__LLM",
"__ModuleName__VectorStore",
"__ModuleName__Embeddings",
"__ModuleName__Loader",
"__ModuleName__Retriever",
"__ModuleName__Toolkit",
"__ModuleName__Tool",
"__version__",
]
libs/cli/langchain_cli/integration_template/integration_template/chat_models.py
@@ -1,20 +1,28 @@
"""__ModuleName__ chat models."""

from typing import Any, List, Optional
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field


class Chat__ModuleName__(BaseChatModel):
# TODO: Replace all TODOs in docstring. See example docstring:
# https://github.com/langchain-ai/langchain/blob/7ff05357bac6eaedf5058a2af88f23a1817d40fe/libs/partners/openai/langchain_openai/chat_models/base.py#L1120
"""__ModuleName__ chat model integration.
The default implementation echoes the first `parrot_buffer_length` characters of the input.
# TODO: Replace with relevant packages, env vars.
Setup:
Install ``__package_name__`` and set environment variable ``__MODULE_NAME___API_KEY``.
@@ -258,24 +266,138 @@ class Joke(BaseModel):
""" # noqa: E501

# TODO: This method must be implemented to generate chat responses.
model_name: str = Field(alias="model")
"""The name of the model"""
parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed."""
temperature: Optional[float] = None
max_tokens: Optional[int] = None
timeout: Optional[int] = None
stop: Optional[List[str]] = None
max_retries: int = 2

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-__package_name_short__"

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.
This information is used by the LangChain callback system, which
is used for tracing purposes and makes it possible to monitor LLMs.
"""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": self.model_name,
}

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
raise NotImplementedError()
"""Override the _generate method to implement the chat model logic.
This can be a call to an API, a call to a local model, or any other
implementation that generates a response to the input prompt.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
# Replace this with actual logic to generate a response from a list
# of messages.
last_message = messages[-1]
tokens = last_message.content[: self.parrot_buffer_length]
ct_input_tokens = sum(len(message.content) for message in messages)
ct_output_tokens = len(tokens)
message = AIMessage(
content=tokens,
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
},
usage_metadata={
"input_tokens": ct_input_tokens,
"output_tokens": ct_output_tokens,
"total_tokens": ct_input_tokens + ct_output_tokens,
},
)

generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model.
This method should be implemented if the model can generate output
in a streaming fashion. If the model does not support streaming,
do not implement it. In that case streaming requests will be automatically
handled by the _generate method.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
last_message = messages[-1]
tokens = str(last_message.content[: self.parrot_buffer_length])
ct_input_tokens = sum(len(message.content) for message in messages)

for token in tokens:
usage_metadata = UsageMetadata(
{
"input_tokens": ct_input_tokens,
"output_tokens": 1,
"total_tokens": ct_input_tokens + 1,
}
)
ct_input_tokens = 0
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
)

# TODO: Implement if Chat__ModuleName__ supports streaming. Otherwise delete method.
# def _stream(
# self,
# messages: List[BaseMessage],
# stop: Optional[List[str]] = None,
# run_manager: Optional[CallbackManagerForLLMRun] = None,
# **kwargs: Any,
# ) -> Iterator[ChatGenerationChunk]:
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)

yield chunk

# Let's add some other information (e.g., response metadata)
chunk = ChatGenerationChunk(
message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
)
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk

# TODO: Implement if Chat__ModuleName__ supports async streaming. Otherwise delete.
# async def _astream(
@@ -294,8 +416,3 @@ def _generate(
# run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
# **kwargs: Any,
# ) -> ChatResult:

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-__package_name_short__"
libs/cli/langchain_cli/integration_template/integration_template/embeddings.py
@@ -8,7 +8,8 @@ class __ModuleName__Embeddings(Embeddings):
# TODO: Replace with relevant packages, env vars.
Setup:
Install ``__package_name__`` and set environment variable ``__MODULE_NAME___API_KEY``.
Install ``__package_name__`` and set environment variable
``__MODULE_NAME___API_KEY``.
.. code-block:: bash
@@ -70,21 +71,26 @@ class __ModuleName__Embeddings(Embeddings):
"""

def __init__(self, model: str):
self.model = model

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
raise NotImplementedError
return [[0.5, 0.6, 0.7] for _ in texts]

def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
raise NotImplementedError

# only keep aembed_documents and aembed_query if they're implemented!
# delete them otherwise to use the base class' default
# implementation, which calls the sync version in an executor
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError

async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
raise NotImplementedError
return self.embed_documents([text])[0]

# optional: add custom async implementations here
# you can also delete these, and the base class will
# use the default implementation, which calls the sync
# version in an async executor:

# async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
# """Asynchronous Embed search docs."""
# ...

# async def aembed_query(self, text: str) -> List[float]:
# """Asynchronous Embed query text."""
# ...
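
And a matching sketch for the embeddings template, under the same assumption that the scaffold rendered __ModuleName__Embeddings as ParrotLinkEmbeddings. The stub returns fixed vectors until a real backend is wired in:

import asyncio

# ParrotLinkEmbeddings is the assumed rendered name of __ModuleName__Embeddings.
from langchain_parrot_link import ParrotLinkEmbeddings

embeddings = ParrotLinkEmbeddings(model="parrot-embed-001")

print(embeddings.embed_documents(["doc one", "doc two"]))  # [[0.5, 0.6, 0.7], [0.5, 0.6, 0.7]]
print(embeddings.embed_query("a query"))                   # [0.5, 0.6, 0.7]

# With the custom async methods deleted (or left commented out, as above),
# the Embeddings base class falls back to running the sync versions in an
# executor, so the async API still works:
print(asyncio.run(embeddings.aembed_query("a query")))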
