Skip to content

Commit

Permalink
together[minor]: update to pydantic v2, langchain-core v0.3, release …
Browse files Browse the repository at this point in the history
…0.2.0 (#6)

* together[major]: update to pydantic v2, langchain-core v0.3

* add tests

* update deps and increment version to 0.2.0.dev1

* set protected namespaces on TogetherEmbeddings

* update release workflows to support releases from non-main branches

* support dev and rc releases in min version checks

* update

---------

Co-authored-by: Chester Curme <[email protected]>
  • Loading branch information
baskaryan and ccurme authored Sep 13, 2024
1 parent c609001 commit 96d41ee
Show file tree
Hide file tree
Showing 12 changed files with 614 additions and 594 deletions.
40 changes: 30 additions & 10 deletions .github/scripts/get_min_versions.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,45 @@
import sys

import tomllib
if sys.version_info >= (3, 11):
import tomllib
else:
# for Python 3.10 and below, which doesn't have stdlib tomllib
import tomli as tomllib

from packaging.version import parse as parse_version
import re

MIN_VERSION_LIBS = ["langchain-core"]

SKIP_IF_PULL_REQUEST = ["langchain-core"]


def get_min_version(version: str) -> str:
    """Extract the minimum supported version from a poetry-style constraint.

    Supports three constraint shapes: caret (``^x.y.z``), bounded range
    (``>=x.y.z,<a.b.c``), and an exact pin (``x.y.z``). Pre-release and
    post-release suffixes (``a``/``b``/``rc``/``.post``/``.dev`` + digits)
    are accepted per PEP 440 public version identifiers.

    Args:
        version: The constraint string from pyproject.toml.

    Returns:
        The minimum version as a plain version string.

    Raises:
        ValueError: If the constraint does not match any supported shape.
    """
    # base regex for x.x.x with cases for rc/post/etc
    # valid strings: https://peps.python.org/pep-0440/#public-version-identifiers
    vstring = r"\d+(?:\.\d+){0,2}(?:(?:a|b|rc|\.post|\.dev)\d+)?"

    # case ^x.x.x — caret constraints: the stated version is the minimum
    _match = re.match(f"^\\^({vstring})$", version)
    if _match:
        return _match.group(1)

    # case >=x.x.x,<y.y.y — the lower bound is the minimum
    _match = re.match(f"^>=({vstring}),<({vstring})$", version)
    if _match:
        _min = _match.group(1)
        _max = _match.group(2)
        # sanity-check the range is non-empty
        assert parse_version(_min) < parse_version(_max)
        return _min

    # case x.x.x — an exact pin is its own minimum
    _match = re.match(f"^({vstring})$", version)
    if _match:
        return _match.group(1)

    raise ValueError(f"Unrecognized version format: {version}")


def get_min_version_from_toml(toml_path: str):
def get_min_version_from_toml(toml_path: str, versions_for: str):
# Parse the TOML file
with open(toml_path, "rb") as file:
toml_data = tomllib.load(file)
Expand All @@ -42,11 +52,18 @@ def get_min_version_from_toml(toml_path: str):

# Iterate over the libs in MIN_VERSION_LIBS
for lib in MIN_VERSION_LIBS:
if versions_for == "pull_request" and lib in SKIP_IF_PULL_REQUEST:
# some libs only get checked on release because of simultaneous
# changes
continue
# Check if the lib is present in the dependencies
if lib in dependencies:
# Get the version string
version_string = dependencies[lib]

if isinstance(version_string, dict):
version_string = version_string["version"]

# Use parse_version to get the minimum supported version from version_string
min_version = get_min_version(version_string)

Expand All @@ -56,10 +73,13 @@ def get_min_version_from_toml(toml_path: str):
return min_versions


if __name__ == "__main__":
    # Get the TOML file path from the command line argument
    toml_file = sys.argv[1]
    # Second argument selects the check context; libs listed in
    # SKIP_IF_PULL_REQUEST are only verified on release.
    versions_for = sys.argv[2]
    assert versions_for in ["release", "pull_request"]

    # Call the function to get the minimum versions
    min_versions = get_min_version_from_toml(toml_file, versions_for)

    # Emit space-separated pip-style pins, e.g. "langchain-core==0.3.0"
    print(" ".join([f"{lib}=={version}" for lib, version in min_versions.items()]))
12 changes: 9 additions & 3 deletions .github/workflows/_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@ on:
required: true
type: string
default: 'libs/together'
dangerous-nonmaster-release:
required: false
type: boolean
default: false
description: "Release from a non-master branch (danger!)"

env:
PYTHON_VERSION: "3.11"
POETRY_VERSION: "1.7.1"

jobs:
build:
if: github.ref == 'refs/heads/main'
if: github.ref == 'refs/heads/main' || inputs.dangerous-nonmaster-release
runs-on: ubuntu-latest

outputs:
Expand Down Expand Up @@ -75,6 +80,7 @@ jobs:
permissions: write-all
with:
working-directory: ${{ inputs.working-directory }}
dangerous-nonmaster-release: ${{ inputs.dangerous-nonmaster-release }}
secrets: inherit

pre-release-checks:
Expand Down Expand Up @@ -168,7 +174,7 @@ jobs:
id: min-version
run: |
poetry run pip install packaging
min_versions="$(poetry run python $GITHUB_WORKSPACE/.github/scripts/get_min_versions.py pyproject.toml)"
min_versions="$(poetry run python $GITHUB_WORKSPACE/.github/scripts/get_min_versions.py pyproject.toml release)"
echo "min-versions=$min_versions" >> "$GITHUB_OUTPUT"
echo "min-versions=$min_versions"
Expand Down Expand Up @@ -262,4 +268,4 @@ jobs:
draft: false
generateReleaseNotes: true
tag: ${{ inputs.working-directory }}/v${{ needs.build.outputs.version }}
commit: main
commit: ${{ github.sha }}
7 changes: 6 additions & 1 deletion .github/workflows/_test_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ on:
required: true
type: string
description: "From which folder this pipeline executes"
dangerous-nonmaster-release:
required: false
type: boolean
default: false
description: "Release from a non-master branch (danger!)"

env:
POETRY_VERSION: "1.7.1"
PYTHON_VERSION: "3.10"

jobs:
build:
if: github.ref == 'refs/heads/main'
if: github.ref == 'refs/heads/main' || inputs.dangerous-nonmaster-release
runs-on: ubuntu-latest

outputs:
Expand Down
57 changes: 27 additions & 30 deletions libs/together/langchain_together/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@

import openai
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
from_env,
secret_from_env,
)
from langchain_core.utils import from_env, secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self


class ChatTogether(BaseChatOpenAI):
Expand Down Expand Up @@ -135,7 +133,7 @@ class ChatTogether(BaseChatOpenAI):
Tool calling:
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
# Only certain models support tool calling, check the together website to confirm compatibility
llm = ChatTogether(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
Expand Down Expand Up @@ -183,7 +181,7 @@ class GetPopulation(BaseModel):
from typing import Optional
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class Joke(BaseModel):
Expand Down Expand Up @@ -325,40 +323,39 @@ def _get_ls_params(
alias="base_url",
)

class Config:
"""Pydantic config."""

allow_population_by_field_name = True
model_config = ConfigDict(
populate_by_name=True,
)

@model_validator(mode="after")
def validate_environment(self) -> Self:
    """Validate that api key and python package exists in environment."""
    # n < 1 is never meaningful; n > 1 cannot be combined with streaming
    # because the streaming API yields a single completion stream.
    if self.n < 1:
        raise ValueError("n must be at least 1.")
    if self.n > 1 and self.streaming:
        raise ValueError("n must be 1 when streaming.")

    # Shared constructor kwargs for both the sync and async OpenAI clients.
    client_params: dict = {
        "api_key": (
            self.together_api_key.get_secret_value()
            if self.together_api_key
            else None
        ),
        "base_url": self.together_api_base,
        "timeout": self.request_timeout,
        "max_retries": self.max_retries,
        "default_headers": self.default_headers,
        "default_query": self.default_query,
    }

    # Only build clients that the caller did not inject explicitly.
    if not (self.client or None):
        sync_specific: dict = {"http_client": self.http_client}
        self.client = openai.OpenAI(
            **client_params, **sync_specific
        ).chat.completions
    if not (self.async_client or None):
        async_specific: dict = {"http_client": self.http_async_client}
        self.async_client = openai.AsyncOpenAI(
            **client_params, **async_specific
        ).chat.completions
    return self
69 changes: 33 additions & 36 deletions libs/together/langchain_together/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,15 @@

import openai
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import (
from_env,
get_pydantic_field_names,
secret_from_env,
model_validator,
)
from typing_extensions import Self

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -174,14 +172,15 @@ class TogetherEmbeddings(BaseModel, Embeddings):
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""

class Config:
"""Configuration for this pydantic object."""

extra = "forbid"
allow_population_by_field_name = True
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
protected_namespaces=(),
)

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
Expand All @@ -206,38 +205,36 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@model_validator(mode="after")
def post_init(self) -> Self:
    """Logic that will post Pydantic initialization."""
    # Shared constructor kwargs for both the sync and async OpenAI clients.
    client_params: dict = {
        "api_key": (
            self.together_api_key.get_secret_value()
            if self.together_api_key
            else None
        ),
        "base_url": self.together_api_base,
        "timeout": self.request_timeout,
        "max_retries": self.max_retries,
        "default_headers": self.default_headers,
        "default_query": self.default_query,
    }
    # Only build clients that the caller did not inject explicitly; omit
    # http_client kwargs entirely when no custom client was supplied.
    if not (self.client or None):
        sync_specific: dict = (
            {"http_client": self.http_client} if self.http_client else {}
        )
        self.client = openai.OpenAI(**client_params, **sync_specific).embeddings
    if not (self.async_client or None):
        async_specific: dict = (
            {"http_client": self.http_async_client}
            if self.http_async_client
            else {}
        )
        self.async_client = openai.AsyncOpenAI(
            **client_params, **async_specific
        ).embeddings
    return self

@property
def _invocation_params(self) -> Dict[str, Any]:
Expand Down
20 changes: 9 additions & 11 deletions libs/together/langchain_together/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
secret_from_env,
)
from langchain_core.utils import secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,14 +76,14 @@ class Together(LLM):
the response for each token generation step.
"""

class Config:
"""Configuration for this pydantic object."""

extra = "forbid"
allow_population_by_field_name = True
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)

@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validate that api key exists in environment."""
if values.get("max_tokens") is None:
warnings.warn(
Expand Down
Loading

0 comments on commit 96d41ee

Please sign in to comment.