Skip to content

Commit

Permalink
fix/feat: langchain templates and models (#293)
Browse files Browse the repository at this point in the history
* fix(langchain): correctly handles prompt_id and model_id
Ref: #285

Signed-off-by: Tomas Dvorak <[email protected]>

* feat(langchain): add prompt template example
Ref: #285

Signed-off-by: Tomas Dvorak <[email protected]>

* feat(langchain): accepts dicts in addition to pydantic models
Ref: #285
  • Loading branch information
Tomas2D authored Jan 29, 2024
1 parent cca79f1 commit e4dcff7
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 4 deletions.
44 changes: 44 additions & 0 deletions examples/extensions/langchain/langchain_generate_with_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Use LangChain generation with a custom template."""

from dotenv import load_dotenv

from genai import Client, Credentials
from genai.extensions.langchain import LangChainInterface
from genai.schema import TextGenerationParameters, TextGenerationReturnOptions

# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint> (optional) DEFAULT_API = "https://bam-api.res.ibm.com"
load_dotenv()


def heading(text: str) -> str:
    """Return *text* wrapped in newlines and centered in an 80-char '=' banner."""
    banner = f" {text} ".center(80, "=")
    return f"\n{banner}\n"


client = Client(credentials=Credentials.from_env())

# Create a server-side prompt template; it is deleted again in the
# `finally` clause below so the example leaves no state behind.
created_prompt = client.prompt.create(
    name="Recipe Generator Prompt",
    model_id="google/flan-t5-xl",
    input="Make a short recipe for {{meal}} (use bullet points)",
)
prompt_id = created_prompt.result.id

try:
    # Generate via the stored prompt: `prompt_id` selects the template,
    # `data` fills in its {{meal}} placeholder.
    llm = LangChainInterface(
        client=client,
        model_id="ibm/granite-13b-instruct-v2",
        prompt_id=prompt_id,
        parameters=TextGenerationParameters(
            min_new_tokens=100,
            max_new_tokens=500,
            return_options=TextGenerationReturnOptions(input_text=False, input_tokens=True),
        ),
        data={"meal": "Lasagne"},
    )
    # Stream the generated tokens to stdout as they arrive.
    for chunk in llm.stream(""):
        print(chunk, end="")
finally:
    # Delete the prompt if you don't need it
    client.prompt.delete(prompt_id)
3 changes: 2 additions & 1 deletion src/genai/extensions/_common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ def _prepare_generation_request(
request["parameters"] = parameters

if request.get("prompt_id") is not None:
request.pop("model_id", None)
request.pop("input", None)
elif request.get("input") is not None:
request.pop("prompt_id", None)

return request

Expand Down
9 changes: 8 additions & 1 deletion src/genai/extensions/langchain/chat_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from typing import Any, Dict, Iterator, Optional, Union

from pydantic import ConfigDict
from pydantic.v1 import validator

from genai import Client
from genai._types import EnumLike
from genai._utils.general import to_model_optional
from genai.extensions._common.utils import (
_prepare_chat_generation_request,
create_generation_info_from_response,
Expand Down Expand Up @@ -86,7 +88,7 @@ class LangChainChatInterface(BaseChatModel):
from genai import Client, Credentials
from genai.extensions.langchain import LangChainChatInterface
from langchain_core.messages import HumanMessage, SystemMessage
from genai.text.generation import TextGenerationParameters
from genai.schema import TextGenerationParameters
client = Client(credentials=Credentials.from_env())
llm = LangChainChatInterface(
Expand Down Expand Up @@ -116,6 +118,11 @@ class LangChainChatInterface(BaseChatModel):
conversation_id: Optional[str] = None
streaming: Optional[bool] = None

@validator("parameters", "moderations", pre=True, always=True)
@classmethod
def validate_data_models(cls, value, values, config, field):
    # Pre-validator: lets callers pass either a plain dict or an existing
    # pydantic model for these fields; `to_model_optional` presumably
    # coerces dicts into `field.type_` and passes models/None through.
    # copy=False — NOTE(review): appears to avoid copying an already-built
    # model instance; confirm against `to_model_optional`'s implementation.
    return to_model_optional(value, Model=field.type_, copy=False)

@classmethod
def is_lc_serializable(cls) -> bool:
return True
Expand Down
10 changes: 8 additions & 2 deletions src/genai/extensions/langchain/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from typing import Any, Iterator, List, Optional, Union

from pydantic import ConfigDict
from pydantic.v1 import validator

from genai import Client
from genai._utils.general import to_model_instance
from genai._utils.general import to_model_instance, to_model_optional
from genai.extensions._common.utils import (
_prepare_generation_request,
create_generation_info,
Expand Down Expand Up @@ -56,7 +57,7 @@ class LangChainInterface(LLM):
from genai import Client, Credentials
from genai.extensions.langchain import LangChainInterface
from genai.text.generation import TextGenerationParameters
from genai.schema import TextGenerationParameters
client = Client(credentials=Credentials.from_env())
llm = LangChainInterface(
Expand All @@ -80,6 +81,11 @@ class LangChainInterface(LLM):
streaming: Optional[bool] = None
execution_options: Optional[CreateExecutionOptions] = None

@validator("parameters", "moderations", "data", "execution_options", pre=True, always=True)
@classmethod
def _validate_data_models(cls, value, values, config, field):
    # Pre-validator: accepts a plain dict in place of the declared pydantic
    # model for these fields; `to_model_optional` presumably converts dicts
    # to `field.type_` and leaves model instances / None untouched.
    # copy=False — NOTE(review): seems intended to reuse an existing model
    # instance without cloning; verify in `genai._utils.general`.
    return to_model_optional(value, Model=field.type_, copy=False)

@property
def _common_identifying_params(self):
return {
Expand Down

0 comments on commit e4dcff7

Please sign in to comment.