lint
NotBioWaste905 committed Nov 28, 2024
Commit 2cd5d41, parent 513eb19
Showing 4 changed files with 24 additions and 31 deletions.
chatsky/slots/llm.py: 15 changes (3 additions, 12 deletions)
@@ -12,13 +12,7 @@

from pydantic import BaseModel, Field, create_model

-from chatsky.slots.slots import (
-    ValueSlot,
-    SlotNotExtracted,
-    GroupSlot,
-    ExtractedGroupSlot,
-    ExtractedValueSlot
-)
+from chatsky.slots.slots import ValueSlot, SlotNotExtracted, GroupSlot, ExtractedGroupSlot, ExtractedValueSlot

if TYPE_CHECKING:
    from chatsky.core import Context
@@ -70,9 +64,7 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot:
        flat_items = self._flatten_llm_group_slot(self)
        captions = {}
        for child_name, slot_item in flat_items.items():
-            captions[child_name] = (slot_item.return_type,
-                                    Field(description=slot_item.caption,
-                                          default=None))
+            captions[child_name] = (slot_item.return_type, Field(description=slot_item.caption, default=None))

        logger.debug(f"Flattened group slot: {flat_items}")
        DynamicGroupModel = create_model("DynamicGroupModel", **captions)
@@ -85,8 +77,7 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot:
        logger.debug(f"Result JSON: {result_json}")

        res = {
-            name: ExtractedValueSlot.model_construct(is_slot_extracted=True,
-                                                     extracted_value=result_json[name])
+            name: ExtractedValueSlot.model_construct(is_slot_extracted=True, extracted_value=result_json[name])
            for name in result_json
            if result_json[name] is not None or not self.allow_partial_extraction
        }
tutorials/llm/3_filtering_history.py: 10 changes (6 additions, 4 deletions)
@@ -88,10 +88,12 @@ def __call__(
            ],
        },
        "main_node": {
-            RESPONSE: Message("Hi! I am your note taking assistant. "
-                              "Just send me your thoughts and if you need to "
-                              "rewind a bit just send /remind and I will send "
-                              "you a summary of your #important notes."),
+            RESPONSE: Message(
+                "Hi! I am your note taking assistant. "
+                "Just send me your thoughts and if you need to "
+                "rewind a bit just send /remind and I will send "
+                "you a summary of your #important notes."
+            ),
            TRANSITIONS: [
                Tr(dst="remind_node", cnd=cnd.ExactMatch("/remind")),
                Tr(dst=dst.Current()),
tutorials/llm/4_structured_output.py: 20 changes (11 additions, 9 deletions)
@@ -30,29 +30,31 @@
from pydantic import BaseModel, Field



load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")

# Initialize our models
movie_model = LLM_API(
-    ChatAnthropic(model="claude-3.5-sonnet", api_key=anthropic_api_key),
-    temperature=0
+    ChatAnthropic(model="claude-3.5-sonnet", api_key=anthropic_api_key),
+    temperature=0,
)
review_model = LLM_API(
    ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0),
)


# Define structured output schemas
class Movie(BaseModel):
    name: str = Field(description="Name of the movie")
    genre: str = Field(description="Genre of the movie")
    plot: str = Field(description="Plot of the movie in chapters")
    cast: list = Field(description="List of the actors")


class MovieReview(Message):
    """Schema for movie reviews (uses Message.misc for metadata)"""

    text: str = Field(description="The actual review text")
    misc: dict = Field(
        description="A dictionary with the following keys and values:"
@@ -66,7 +68,10 @@ class MovieReview(Message):
script = {
    GLOBAL: {
        TRANSITIONS: [
-            Tr(dst=("greeting_flow", "start_node"), cnd=cnd.ExactMatch("/start")),
+            Tr(
+                dst=("greeting_flow", "start_node"),
+                cnd=cnd.ExactMatch("/start"),
+            ),
            Tr(dst=("movie_flow", "create"), cnd=cnd.ExactMatch("/create")),
            Tr(dst=("movie_flow", "review"), cnd=cnd.Regexp("/review \w*")),
        ]
@@ -102,7 +107,7 @@ class MovieReview(Message):
                message_schema=MovieReview,
            ),
            TRANSITIONS: [Tr(dst=("greeting_flow", "start_node"))],
-        }
+        },
    },
}

@@ -111,10 +116,7 @@ class MovieReview(Message):
    script=script,
    start_label=("greeting_flow", "start_node"),
    fallback_label=("greeting_flow", "fallback_node"),
-    models={
-        "movie_model": movie_model,
-        "review_model": review_model
-    },
+    models={"movie_model": movie_model, "review_model": review_model},
)

if __name__ == "__main__":
tutorials/llm/5_llm_slots.py: 10 changes (4 additions, 6 deletions)
@@ -54,9 +54,9 @@
"""

# %%
slot_model = LLM_API(ChatOpenAI(
model="gpt-4o-mini", api_key=openai_api_key, temperature=0
))
slot_model = LLM_API(
ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0)
)

SLOTS = {
"person": LLMGroupSlot(
@@ -108,9 +108,7 @@
    start_label=("user_flow", "start"),
    fallback_label=("user_flow", "repeat_question"),
    slots=SLOTS,
-    models={
-        "slot_model": slot_model
-    }
+    models={"slot_model": slot_model},
)


