Fixes and logging
NotBioWaste905 committed Nov 15, 2024
Parent: 020a7ef · Commit: c9891f6
Showing 3 changed files with 11 additions and 6 deletions.
chatsky/llm/utils.py (6 changes: 5 additions & 1 deletion)
@@ -1,8 +1,9 @@
import base64
import logging
from chatsky.core.context import Context
from chatsky.core.message import Image, Message
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

# logging.basicConfig(level=logging.DEBUG)

async def message_to_langchain(message: Message, ctx: Context, source: str = "human", max_size: int = 1000):
"""
@@ -65,13 +66,16 @@ async def context_to_history(ctx: Context, length: int, filter_func, model_name:
[ctx.requests[x] for x in range(1, len(ctx.requests) + 1)],
[ctx.responses[x] for x in range(1, len(ctx.responses) + 1)],
)
logging.debug(f"Dialogue turns: {pairs}")
if length != -1:
for req, resp in filter(lambda x: filter_func(ctx, x[0], x[1], model_name), list(pairs)[-length:]):
logging.debug(f"This pair is valid: {req, resp}")
history.append(await message_to_langchain(req, ctx=ctx, max_size=max_size))
history.append(await message_to_langchain(resp, ctx=ctx, source="ai", max_size=max_size))
else:
# TODO: Fix redundant code
for req, resp in filter(lambda x: filter_func(ctx, x[0], x[1], model_name), list(pairs)):
logging.debug(f"This pair is valid: {req, resp}")
history.append(await message_to_langchain(req, ctx=ctx, max_size=max_size))
history.append(await message_to_langchain(resp, ctx=ctx, source="ai", max_size=max_size))
return history
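
The TODO above flags that the two branches share an identical loop body and differ only in whether the pair list is sliced to the last `length` turns. One possible deduplication, sketched here as a suggestion rather than anything in this commit, computes the selection once:

    # Suggested refactor for the TODO (not part of this commit):
    # only the slice differs between the two branches.
    selected = list(pairs) if length == -1 else list(pairs)[-length:]
    for req, resp in filter(lambda x: filter_func(ctx, x[0], x[1], model_name), selected):
        logging.debug(f"This pair is valid: {req, resp}")
        history.append(await message_to_langchain(req, ctx=ctx, max_size=max_size))
        history.append(await message_to_langchain(resp, ctx=ctx, source="ai", max_size=max_size))

Note that logging.basicConfig(level=logging.DEBUG) is left commented out at the top of the module, so the new logging.debug calls stay silent unless the host application configures the logging level itself.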
chatsky/responses/llm.py (9 changes: 5 additions & 4 deletions)
@@ -1,13 +1,13 @@
from typing import Union, Type
import logging
from pydantic import BaseModel, Field
from chatsky.core.message import Message
from chatsky.core.context import Context
from langchain_core.messages import SystemMessage
from chatsky.llm.utils import message_to_langchain, context_to_history
from chatsky.llm.filters import BaseFilter
from pydantic import BaseModel, Field
from chatsky.core.script_function import BaseResponse, AnyResponse


class LLMResponse(BaseResponse):
"""
Basic function for receiving LLM responses.
@@ -60,13 +60,14 @@ async def call(self, ctx: Context) -> Message:
)
)

if await self.prompt(ctx) != Message(text=""):
msg = await self.prompt(ctx)
msg = await self.prompt(ctx)
if msg.text:
history_messages.append(await message_to_langchain(msg, ctx=ctx, source="system"))

history_messages.append(
await message_to_langchain(ctx.last_request, ctx=ctx, source="human", max_size=self.max_size)
)
logging.debug(f"History: {history_messages}")
result = await model.respond(history_messages, message_schema=self.message_schema)

if result.annotations:
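The second hunk fixes a subtle bug: the old code awaited self.prompt(ctx) twice (once for the comparison against Message(text=""), once for the assignment), and the inequality check still passed for a prompt whose text was None, appending an empty system message to the history. The new code awaits once and gates on msg.text, which is falsy for both None and the empty string. A minimal self-contained illustration of the pattern, using hypothetical stand-ins rather than Chatsky's actual classes:

    import asyncio
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Msg:
        text: Optional[str] = None

    async def get_prompt() -> Msg:
        # Stand-in for "await self.prompt(ctx)": a prompt with no text.
        return Msg(text=None)

    async def old_pattern() -> bool:
        # Two awaits; Msg(text=None) != Msg(text="") is True, so the
        # empty prompt is wrongly treated as present.
        if await get_prompt() != Msg(text=""):
            msg = await get_prompt()
            return True
        return False

    async def new_pattern() -> bool:
        # One await; the truthiness check rejects both None and "".
        msg = await get_prompt()
        return bool(msg.text)

    print(asyncio.run(old_pattern()))  # True  (the bug)
    print(asyncio.run(new_pattern()))  # False (the fix)
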
tutorials/llm/3_filtering_history.py (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,6 @@
Transition as Tr,
conditions as cnd,
destinations as dst,
labels as lbl,
)
from chatsky.utils.testing import is_interactive_mode
from chatsky.llm import LLM_API
@@ -85,6 +84,7 @@ def __call__(
"remind_node": {
RESPONSE: LLMResponse(
model_name="assistant_model",
prompt="Create a bullet list from all the previous messages tagged with #important.",
history=15,
filter_func=FilterImportant(),
),
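For context, FilterImportant (whose __call__ definition is truncated in the hunk above) is the tutorial's custom history filter; combined with history=15, it limits the model's context to matching turns. Below is a sketch of what such a filter could look like, assuming BaseFilter subclasses are callables with the (ctx, request, response, model_name) signature that context_to_history passes to filter_func; this is a hypothetical reconstruction, not the tutorial's exact code:

    from chatsky.core.context import Context
    from chatsky.core.message import Message
    from chatsky.llm.filters import BaseFilter

    class FilterImportant(BaseFilter):
        """Keep only dialogue turns where either side is tagged '#important'."""

        def __call__(self, ctx: Context, request: Message, response: Message, model_name: str) -> bool:
            # context_to_history() keeps a (request, response) pair
            # whenever this returns True.
            return "#important" in (request.text or "") or "#important" in (response.text or "")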
