From 1cba8410e3c180cb9c9c89e1a398f667a536f718 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Tue, 14 May 2024 18:39:07 -0600 Subject: [PATCH] More tests and interface cleanup --- poetry.lock | 66 +++++- pyproject.toml | 11 + rigging/chat.py | 193 +++++++++++------- rigging/completion.py | 179 ++++++++++------ rigging/generator.py | 32 +-- rigging/tool.py | 3 +- tests/{test_messages.py => test_chat.py} | 109 ++++++++++ tests/test_completion.py | 89 ++++++++ tests/test_generation.py | 57 +++++- ...ator_creation.py => test_generator_ids.py} | 0 tests/test_xml_parsing.py | 2 +- 11 files changed, 567 insertions(+), 174 deletions(-) rename tests/{test_messages.py => test_chat.py} (68%) create mode 100644 tests/test_completion.py rename tests/{test_generator_creation.py => test_generator_ids.py} (100%) diff --git a/poetry.lock b/poetry.lock index f810bee..a9cc946 100644 --- a/poetry.lock +++ b/poetry.lock @@ -473,6 +473,70 @@ traitlets = ">=4" [package.extras] test = ["pytest"] +[[package]] +name = "coverage" +version = "7.5.1" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"}, + {file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"}, + {file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"}, + {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"}, + {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"}, + {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"}, + {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"}, + {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"}, + {file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"}, + {file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"}, + {file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"}, + {file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"}, + {file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"}, + {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"}, + {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"}, + {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"}, + {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"}, + {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"}, + {file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"}, + {file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"}, + {file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"}, + {file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"}, + {file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"}, + {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"}, + {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"}, + {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"}, + {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"}, + {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"}, + {file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"}, + {file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"}, + {file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"}, + {file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"}, + {file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"}, + {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"}, + {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"}, + {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"}, + {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"}, + {file = 
"coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"}, + {file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"}, + {file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"}, + {file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"}, + {file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"}, + {file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"}, + {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"}, + {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"}, + {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"}, + {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"}, + {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"}, + {file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"}, + {file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"}, + {file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"}, + {file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"}, +] + +[package.extras] +toml = ["tomli"] + [[package]] name = "cssselect2" version = "0.7.0" @@ -3045,4 +3109,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "42c0c05546ff2553daf2a0af0b94ef0603daac3a5a241c40260fad030f67e948" +content-hash = "ad27c2a99f274a1564485ea6cf0c7686d68bdbb7a2d9db16642c2087947bcbe5" diff --git a/pyproject.toml b/pyproject.toml index 7ff28d2..a69d3aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ ruff = "^0.1.14" pytest = "^8.0.0" pandas = "^2.2.2" pandas-stubs = "^2.2.1.240316" +coverage = "^7.5.1" [tool.poetry.group.docs.dependencies] mkdocs = "^1.6.0" @@ -42,6 +43,16 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] filterwarnings = ["ignore::DeprecationWarning"] +[tool.coverage.run] +command_line = "-m pytest" + +[tool.coverage.report] +include = ["rigging/*.py"] +show_missing = true + +[tool.coverage.lcov] +output = "lcov.info" + [tool.mypy] plugins = "pydantic.mypy" strict = true diff --git a/rigging/chat.py b/rigging/chat.py index f22752d..0cb3029 100644 --- a/rigging/chat.py +++ b/rigging/chat.py @@ -22,6 +22,7 @@ ) from rigging.error import ExhaustedMaxRoundsError +from rigging.generator import GenerateParams, 
Generator, get_generator from rigging.message import Message, MessageDict, Messages from rigging.model import ( Model, @@ -32,9 +33,6 @@ from rigging.prompt import system_tool_extension from rigging.tool import Tool, ToolCalls, ToolDescriptionList, ToolResult, ToolResults -if t.TYPE_CHECKING: - from rigging.generator import GenerateParams, Generator - DEFAULT_MAX_ROUNDS = 5 @@ -61,7 +59,8 @@ class Chat(BaseModel): params: t.Optional["GenerateParams"] = Field(None, exclude=True, repr=False) """Any additional generation params used for this chat.""" - @computed_field(repr=False) + @computed_field(repr=False) # type: ignore[misc] + @property def generator_id(self) -> str | None: """The identifier of the generator used to create the chat""" if self.generator is not None: @@ -84,7 +83,6 @@ def __init__( generator: The generator associated with this chat. **kwargs: Additional keyword arguments (typically used for deserialization) """ - from rigging.generator import get_generator if "generator_id" in kwargs and generator is None: # TODO: Should we move params to self.params? @@ -126,14 +124,14 @@ def conversation(self) -> str: return "\n\n".join([str(m) for m in self.all]) @property - def message_dicts(self) -> list[dict[str, MessageDict]]: + def message_dicts(self) -> list[MessageDict]: """ Returns the chat as a minimal dictionary Returns: The chat as a list of messages with roles and content. """ - return [m.model_dump(include={"role", "content"}) for m in self.all] + return [t.cast(MessageDict, m.model_dump(include={"role", "content"})) for m in self.all] def meta(self, **kwargs: t.Any) -> "Chat": """ @@ -333,7 +331,8 @@ async def __call__(self, chats: list[Chat]) -> list[Chat]: ... -PostRunCallbacks = ThenChatCallback | AsyncThenChatCallback | MapChatCallback | AsyncMapChatCallback +ThenChatCallbacks = ThenChatCallback | AsyncThenChatCallback +MapChatCallbacks = MapChatCallback | AsyncMapChatCallback # Generators @@ -362,6 +361,13 @@ class BatchRunState: done: bool = False +@dataclass +class BatchRunPool: + generator: "Generator" + finished_states: list[BatchRunState] + pending_states: list[BatchRunState] + + class PendingChat: """ Represents a pending chat that can be modified and executed. @@ -385,9 +391,13 @@ def __init__( self.until_tools: list[Tool] = [] self.inject_tool_prompt: bool = True self.force_tool: bool = False - self.post_run_callbacks: list[PostRunCallbacks] = [] + self.then_chat_callbacks: list[ThenChatCallbacks] = [] + self.map_chat_callbacks: list[MapChatCallbacks] = [] # self.producer: MessageProducer | None = None + def __len__(self) -> int: + return len(self.chat) + def with_(self, params: t.Optional["GenerateParams"] = None, **kwargs: t.Any) -> "PendingChat": """ Assign specific generation parameter overloads for this chat. @@ -402,8 +412,6 @@ def with_(self, params: t.Optional["GenerateParams"] = None, **kwargs: t.Any) -> Returns: A new instance of PendingChat with the updated parameters. """ - from rigging.generator import GenerateParams - if params is None: params = GenerateParams(**kwargs) @@ -518,7 +526,7 @@ def process(chat: Chat) -> Chat | None: Returns: The current instance of the chat. """ - self.post_run_callbacks.append(callback) + self.then_chat_callbacks.append(callback) return self def map(self, callback: MapChatCallback | AsyncMapChatCallback) -> "PendingChat": @@ -546,7 +554,7 @@ def process(chats: list[Chat]) -> list[Chat]: Returns: The current instance of the chat. 
""" - self.post_run_callbacks.append(callback) + self.map_chat_callbacks.append(callback) return self # def from_(self, producer: MessageProducer) -> "PendingChat": @@ -822,26 +830,36 @@ def _process(self) -> t.Generator[list[Message], Message, list[Message]]: return new_messages def _post_run(self, chats: list[Chat]) -> list[Chat]: - if any(asyncio.iscoroutinefunction(callback) for callback in self.post_run_callbacks): - raise ValueError("Cannot use async then()/map() callbacks inside a non-async run call") - - for callback in self.post_run_callbacks: - if isinstance(callback, MapChatCallback): - chats = callback(chats) - elif isinstance(callback, ThenChatCallback): - chats = [callback(chat) or chat for chat in chats] + for map_callback in self.map_chat_callbacks: + if asyncio.iscoroutinefunction(map_callback): + raise ValueError( + f"Cannot use async map() callbacks inside a non-async run call: {map_callback.__name__}" + ) + chats = map_callback(chats) # type: ignore + + for then_callback in self.then_chat_callbacks: + if asyncio.iscoroutinefunction(then_callback): + raise ValueError( + f"Cannot use async then() callbacks inside a non-async run call: {then_callback.__name__}" + ) + chats = [then_callback(chat) or chat for chat in chats] # type: ignore + return chats async def _apost_run(self, chats: list[Chat]) -> list[Chat]: - if not all(asyncio.iscoroutinefunction(callback) for callback in self.post_run_callbacks): - raise ValueError("Cannot use non-async then()/map() callbacks inside a async run call") + for map_callback in self.map_chat_callbacks: + if not asyncio.iscoroutinefunction(map_callback): + raise ValueError( + f"Cannot use non-async map() callbacks inside an async run call: {map_callback.__call__.__name__}" + ) + chats = await map_callback(chats) - for callback in self.post_run_callbacks: - if isinstance(callback, AsyncMapChatCallback): - chats = await callback(chats) - elif isinstance(callback, AsyncThenChatCallback): - updated = await asyncio.gather(*[callback(chat) for chat in chats]) - chats = [updated[i] or chat for i, chat in enumerate(chats)] + for then_callback in self.then_chat_callbacks: + if not asyncio.iscoroutinefunction(then_callback): + raise ValueError( + f"Cannot use non-async then() callbacks inside an async run call: {then_callback.__call__.__name__}" + ) + chats = [await then_callback(chat) or chat for chat in chats] return chats @@ -861,8 +879,6 @@ def _pre_run(self) -> None: def _fit_params( self, count: int, params: t.Sequence[t.Optional["GenerateParams"] | None] | None = None ) -> list["GenerateParams"]: - from rigging.generator import GenerateParams - params = [None] * count if params is None else list(params) if len(params) != count: raise ValueError(f"The number of params must be {count}") @@ -994,9 +1010,15 @@ async def arun_many( def run_batch( self, - many: t.Sequence[t.Sequence[Message]] | t.Sequence[Message] | t.Sequence[MessageDict] | t.Sequence[str], + many: t.Sequence[t.Sequence[Message]] + | t.Sequence[Message] + | t.Sequence[MessageDict] + | t.Sequence[str] + | MessageDict + | str, params: t.Sequence[t.Optional["GenerateParams"]] | None = None, *, + size: int | None = None, skip_failed: bool = False, ) -> list[Chat]: """ @@ -1009,13 +1031,14 @@ def run_batch( Parameters: many: A sequence of sequences of messages to be generated. params: A sequence of parameters to be used for each set of messages. + size: The max chunk size of messages to execute generation at once. 
skip_failed: Enable to ignore any max rounds errors and return only successful chats. Returns: A list of generated Chats. """ - if isinstance(many, str | dict): - many = [many] + if isinstance(many, dict) or isinstance(many, str): # Some strange typechecking here + many = t.cast(t.Sequence[str] | t.Sequence[MessageDict], [many]) count = max(len(many), len(params) if params is not None else 0) many = self._fit_many(count, many) @@ -1026,30 +1049,33 @@ def run_batch( ] _ = [next(state.processor) for state in states] + size = size or len(states) + pending_states = states while pending_states: - inbounds = self.generator.generate_messages( - [s.inputs + s.messages for s in pending_states], - [s.params for s in pending_states], - prefix=self.chat.all, - ) + for chunk in [pending_states[i : i + size] for i in range(0, len(pending_states), size)]: + inbounds = self.generator.generate_messages( + [s.inputs + s.messages for s in chunk], + [s.params for s in chunk], + prefix=self.chat.all, + ) - for inbound, state in zip(inbounds, pending_states, strict=True): - try: - state.messages = state.processor.send(inbound) - except StopIteration as stop: - state.done = True - state.chat = Chat( - self.chat.all + state.inputs, - t.cast(list[Message], stop.value), - generator=self.generator, - metadata=self.metadata, - params=state.params, - ) - except ExhaustedMaxRoundsError: - if not skip_failed: - raise - state.done = True + for inbound, state in zip(inbounds, chunk, strict=True): + try: + state.messages = state.processor.send(inbound) + except StopIteration as stop: + state.done = True + state.chat = Chat( + self.chat.all + state.inputs, + t.cast(list[Message], stop.value), + generator=self.generator, + metadata=self.metadata, + params=state.params, + ) + except ExhaustedMaxRoundsError: + if not skip_failed: + raise + state.done = True pending_states = [s for s in pending_states if not s.done] @@ -1057,14 +1083,20 @@ def run_batch( async def arun_batch( self, - many: t.Sequence[t.Sequence[Message]] | t.Sequence[Message] | t.Sequence[MessageDict] | t.Sequence[str], + many: t.Sequence[t.Sequence[Message]] + | t.Sequence[Message] + | t.Sequence[MessageDict] + | t.Sequence[str] + | MessageDict + | str, params: t.Sequence[t.Optional["GenerateParams"]] | None = None, *, + size: int | None = None, skip_failed: bool = False, ) -> list[Chat]: """async variant of the [rigging.chat.PendingChat.run_batch][] method.""" - if isinstance(many, str | dict): - many = [many] + if isinstance(many, dict) or isinstance(many, str): # Some strange typechecking here + many = t.cast(t.Sequence[str] | t.Sequence[MessageDict], [many]) count = max(len(many), len(params) if params is not None else 0) many = self._fit_many(count, many) @@ -1075,30 +1107,33 @@ async def arun_batch( ] _ = [next(state.processor) for state in states] + size = size or len(states) + pending_states = states while pending_states: - inbounds = await self.generator.agenerate_messages( - [s.inputs + s.messages for s in pending_states], - [s.params for s in pending_states], - prefix=self.chat.all, - ) + for chunk in [pending_states[i : i + size] for i in range(0, len(pending_states), size)]: + inbounds = await self.generator.agenerate_messages( + [s.inputs + s.messages for s in chunk], + [s.params for s in chunk], + prefix=self.chat.all, + ) - for inbound, state in zip(inbounds, pending_states, strict=True): - try: - state.messages = state.processor.send(inbound) - except StopIteration as stop: - state.done = True - state.chat = Chat( - self.chat.all + 
state.inputs, - t.cast(list[Message], stop.value), - generator=self.generator, - metadata=self.metadata, - params=state.params, - ) - except ExhaustedMaxRoundsError: - if not skip_failed: - raise - state.done = True + for inbound, state in zip(inbounds, chunk, strict=True): + try: + state.messages = state.processor.send(inbound) + except StopIteration as stop: + state.done = True + state.chat = Chat( + self.chat.all + state.inputs, + t.cast(list[Message], stop.value), + generator=self.generator, + metadata=self.metadata, + params=state.params, + ) + except ExhaustedMaxRoundsError: + if not skip_failed: + raise + state.done = True pending_states = [s for s in pending_states if not s.done] diff --git a/rigging/completion.py b/rigging/completion.py index 1abae6e..70375ab 100644 --- a/rigging/completion.py +++ b/rigging/completion.py @@ -20,15 +20,13 @@ ) from rigging.error import ExhaustedMaxRoundsError +from rigging.generator import GenerateParams, Generator, get_generator from rigging.model import ( Model, ModelT, ) from rigging.parsing import parse_many -if t.TYPE_CHECKING: - from rigging.generator import GenerateParams, Generator - DEFAULT_MAX_ROUNDS = 5 # TODO: Chats and Completions share a lot of structure and code. @@ -51,44 +49,45 @@ class Completion(BaseModel): generated: str """The generated text.""" metadata: dict[str, t.Any] = Field(default_factory=dict) - """Additional metadata for the chat.""" + """Additional metadata for the completion.""" - pending: t.Optional["PendingCompletion"] = Field(None, exclude=True, repr=False) - """The pending completion associated with this completion.""" + generator: t.Optional["Generator"] = Field(None, exclude=True, repr=False) + """The generator associated with the completion.""" + params: t.Optional["GenerateParams"] = Field(None, exclude=True, repr=False) + """Any additional generation params used for this completion.""" - @computed_field(repr=False) + @computed_field(repr=False) # type: ignore[misc] + @property def generator_id(self) -> str | None: """The identifier of the generator used to create the completion""" - if self.pending is not None: - return self.pending.generator.to_identifier(self.pending.params) + if self.generator is not None: + return self.generator.to_identifier(self.params) return None def __init__( self, text: str, generated: str, - pending: t.Optional["PendingCompletion"] = None, + generator: t.Optional["Generator"] = None, **kwargs: t.Any, ): """ - Initialize a Chat object. + Initialize a Completion object. Args: text: The original text. generated: The generated text. - pending: The pending completion associated with this completion. + generator: The generator associated with this completion. **kwargs: Additional keyword arguments (typically used for serialization). """ - from rigging.generator import get_generator - - if "generator_id" in kwargs and pending is None: + if "generator_id" in kwargs and generator is None: + # TODO: Should we move params to self.params? generator = get_generator(kwargs.pop("generator_id")) - pending = generator.complete(text) super().__init__( text=text, generated=generated, - pending=pending, + generator=generator, **kwargs, ) @@ -105,7 +104,7 @@ def restart(self, *, generator: t.Optional["Generator"] = None, include_all: boo Attempt to convert back to a PendingCompletion for further generation. Args: - generator: The generator to use for the restarted chat. Otherwise + generator: The generator to use for the restarted completion. 
Otherwise the generator from the original PendingCompletion will be used. include_all: Whether to include the generation before the next round. Returns: The restarted completion. Raises: ValueError: If no generator is associated with the completion and none is provided. """ - text = self.all if include_all else self.text - if generator is not None: - return generator.complete(text) - elif self.pending is None: - raise ValueError("Cannot restart Completion that was not created with a PendingCompletion") - return PendingCompletion(self.pending.generator, text, self.pending.params) + text = self.all if include_all else self.generated + if generator is None: + generator = self.generator + if generator is None: + raise ValueError("Cannot restart a completion without an associated generator") + return generator.complete(text, self.params) - def fork(self, text: str) -> "PendingCompletion": + def fork(self, text: str, *, include_all: bool = False) -> "PendingCompletion": """ Forks the completion by calling [rigging.completion.Completion.restart][] and appending the specified text. Args: text: The text to append. Returns: A new instance of a pending completion with the specified text added. """ - return self.restart().add(text) + return self.restart(include_all=include_all).add(text) + + def continue_(self, text: str) -> "PendingCompletion": + """Alias for the [rigging.completion.Completion.fork][] with `include_all=True`.""" + return self.fork(text, include_all=True) + + def clone(self, *, only_messages: bool = False) -> "Completion": + """Creates a deep copy of the completion.""" + new = Completion(self.text, self.generated, self.generator) + if not only_messages: + new.metadata = deepcopy(self.metadata) + return new + + def meta(self, **kwargs: t.Any) -> "Completion": + """ + Updates the metadata of the completion with the provided key-value pairs. + + Args: + **kwargs: Key-value pairs representing the metadata to be updated. - def clone(self) -> "Completion": - """Creates a deep copy of the chat.""" - return Completion(self.text, self.generated, self.pending) + Returns: + The updated completion object. + """ + new = self.clone() + new.metadata.update(kwargs) + return new # Callbacks @@ -175,7 +195,7 @@ class MapCompletionCallback(t.Protocol): def __call__(self, completions: list[Completion]) -> list[Completion]: """ Passed a finalized completion to process. Can replace completions in the pipeline by returning - a new chat object. + a new completion object. """ ... @@ -189,9 +209,8 @@ async def __call__(self, completions: list[Completion]) -> list[Completion]: ... 
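The patch splits the single post_run_callbacks list into separate then/map lists so that sync/async mismatches fail fast with a ValueError naming the offending callback. A minimal sketch of the intended pairing, assuming a configured LiteLLM backend; the model id and callback names here are illustrative, not part of the patch:

import asyncio

from rigging.completion import Completion
from rigging.generator import get_generator

def tag(completion: Completion) -> Completion | None:
    # then() callback: return a replacement Completion, or None to keep the original.
    return completion.meta(tagged=True)

async def atag(completion: Completion) -> Completion | None:
    # Async twin of the same callback, for use with arun().
    return completion.meta(tagged=True)

completion = get_generator("gpt-3.5").complete("Say hello.").then(tag).run()
acompletion = asyncio.run(get_generator("gpt-3.5").complete("Say hello.").then(atag).arun())
# Mixing the flavors, e.g. .then(atag).run(), raises ValueError rather than silently skipping.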
-PostRunCallbacks = ( - ThenCompletionCallback | AsyncThenCompletionCallback | MapCompletionCallback | AsyncMapCompletionCallback -) +ThenCompletionCallbacks = ThenCompletionCallback | AsyncThenCompletionCallback +MapCompletionCallbacks = MapCompletionCallback | AsyncMapCompletionCallback @dataclass @@ -221,7 +240,11 @@ def __init__(self, generator: "Generator", text: str, params: t.Optional["Genera # (callback, all_text, max_rounds) self.until_callbacks: list[tuple[UntilCompletionCallback, bool, int]] = [] self.until_types: list[type[Model]] = [] - self.post_run_callbacks: list[PostRunCallbacks] = [] + self.then_callbacks: list[ThenCompletionCallbacks] = [] + self.map_callbacks: list[MapCompletionCallbacks] = [] + + def __len__(self) -> int: + return len(self.text) def with_(self, params: t.Optional["GenerateParams"] = None, **kwargs: t.Any) -> "PendingCompletion": """ @@ -237,14 +260,12 @@ def with_(self, params: t.Optional["GenerateParams"] = None, **kwargs: t.Any) -> Returns: The current (or cloned) instance of the completion. """ - from rigging.generator import GenerateParams - if params is None: params = GenerateParams(**kwargs) if self.params is not None: new = self.clone() - new.params = params + new.params = self.params.merge_with(params) return new self.params = params @@ -259,7 +280,7 @@ def then(self, callback: ThenCompletionCallback) -> "PendingCompletion": for the remainder of the callbacks + return value of `run()`. ``` - def process(chat: Completion) -> Completion | None: + def process(completion: Completion) -> Completion | None: ... pending.then(process).run() @@ -271,7 +292,7 @@ def process(chat: Completion) -> Completion | None: Returns: The current instance of the pending completion. """ - self.post_run_callbacks.append(callback) + self.then_callbacks.append(callback) return self def map(self, callback: MapCompletionCallback | AsyncMapCompletionCallback) -> "PendingCompletion": @@ -287,7 +308,7 @@ def map(self, callback: MapCompletionCallback | AsyncMapCompletionCallback) -> " run methods when executing the generation process. ``` - def process(chats: list[str]) -> list[str]: + def process(completions: list[Completion]) -> list[Completion]: ... pending.map(process).run() @@ -299,7 +320,7 @@ def process(chats: list[str]) -> list[str]: Returns: The current instance of the completion. 
""" - self.post_run_callbacks.append(callback) + self.map_callbacks.append(callback) return self def add(self, text: str) -> "PendingCompletion": @@ -448,34 +469,44 @@ def _until_parse_callback(self, text: str) -> bool: return False def _post_run(self, completions: list[Completion]) -> list[Completion]: - if any(asyncio.iscoroutinefunction(callback) for callback in self.post_run_callbacks): - raise ValueError("Cannot use async then()/map() callbacks inside a non-async run call") + for map_callback in self.map_callbacks: + if asyncio.iscoroutinefunction(map_callback): + raise ValueError( + f"Cannot use async map() callbacks inside a non-async run call: {map_callback.__name__}" + ) + completions = map_callback(completions) # type: ignore + + for then_callback in self.then_callbacks: + if asyncio.iscoroutinefunction(then_callback): + raise ValueError( + f"Cannot use async then() callbacks inside a non-async run call: {then_callback.__name__}" + ) + updated = [then_callback(completion) for completion in completions] + completions = [updated[i] or completion for i, completion in enumerate(completions)] # type: ignore - for callback in self.post_run_callbacks: - if isinstance(callback, MapCompletionCallback): - chats = callback(completions) - elif isinstance(callback, ThenCompletionCallback): - chats = [callback(completion) or completion for completion in completions] - return chats + return completions async def _apost_run(self, completions: list[Completion]) -> list[Completion]: - if not all(asyncio.iscoroutinefunction(callback) for callback in self.post_run_callbacks): - raise ValueError("Cannot use non-async then()/map() callbacks inside a async run call") - - for callback in self.post_run_callbacks: - if isinstance(callback, AsyncMapCompletionCallback): - chats = await callback(completions) - elif isinstance(callback, AsyncThenCompletionCallback): - updated = await asyncio.gather(*[callback(completion) for completion in completions]) - chats = [updated[i] or completion for i, completion in enumerate(completions)] + for map_callback in self.map_callbacks: + if not asyncio.iscoroutinefunction(map_callback): + raise ValueError( + f"Cannot use non-async map() callbacks inside an async run call: {map_callback.__call__.__name__}" + ) + completions = await map_callback(completions) + + for then_callback in self.then_callbacks: + if not asyncio.iscoroutinefunction(then_callback): + raise ValueError( + f"Cannot use non-async then() callbacks inside an async run call: {then_callback.__call__.__name__}" + ) + updated = [await then_callback(completion) for completion in completions] + completions = [updated[i] or completion for i, completion in enumerate(completions)] return completions def _fit_params( self, count: int, params: t.Sequence[t.Optional["GenerateParams"] | None] | None = None ) -> list["GenerateParams"]: - from rigging.generator import GenerateParams - params = [None] * count if params is None else list(params) if len(params) != count: raise ValueError(f"The number of params must be {count}") @@ -525,7 +556,7 @@ def run(self) -> Completion: return self.run_many(1)[0] async def arun(self) -> Completion: - """async variant of the [rigging.chat.PendingChat.run][] method.""" + """async variant of the [rigging.completion.PendingCompletion.run][] method.""" return (await self.arun_many(1))[0] __call__ = run @@ -565,7 +596,11 @@ def run_many( except StopIteration as stop: state.done = True state.completion = Completion( - self.text, t.cast(str, stop.value), pending=self, metadata=self.metadata + 
self.text, + t.cast(str, stop.value), + generator=self.generator, + params=state.params, + metadata=self.metadata, ) except ExhaustedMaxRoundsError: if not skip_failed: @@ -599,7 +634,11 @@ async def arun_many( except StopIteration as stop: state.done = True state.completion = Completion( - self.text, t.cast(str, stop.value), pending=self, metadata=self.metadata + self.text, + t.cast(str, stop.value), + generator=self.generator, + params=state.params, + metadata=self.metadata, ) except ExhaustedMaxRoundsError: if not skip_failed: @@ -652,7 +691,11 @@ def run_batch( except StopIteration as stop: state.done = True state.completion = Completion( - self.text, t.cast(str, stop.value), pending=self, metadata=self.metadata + self.text, + t.cast(str, stop.value), + generator=self.generator, + params=state.params, + metadata=self.metadata, ) except ExhaustedMaxRoundsError: if not skip_failed: @@ -670,7 +713,7 @@ async def arun_batch( *, skip_failed: bool = False, ) -> list[Completion]: - """async variant of the [rigging.chat.PendingChat.run_batch][] method.""" + """async variant of the [rigging.completion.PendingCompletion.run_batch][] method.""" params = self._fit_params(len(many), params) states: list[RunState] = [RunState(m, p, self._process()) for m, p in zip(many, params, strict=True)] _ = [next(state.processor) for state in states] @@ -689,7 +732,11 @@ async def arun_batch( except StopIteration as stop: state.done = True state.completion = Completion( - self.text, t.cast(str, stop.value), pending=self, metadata=self.metadata + self.text, + t.cast(str, stop.value), + generator=self.generator, + params=state.params, + metadata=self.metadata, ) except ExhaustedMaxRoundsError: if not skip_failed: diff --git a/rigging/generator.py b/rigging/generator.py index b9f9036..540db66 100644 --- a/rigging/generator.py +++ b/rigging/generator.py @@ -9,14 +9,16 @@ from loguru import logger from pydantic import BaseModel, ConfigDict, Field, field_validator -from rigging.chat import Chat, PendingChat -from rigging.completion import Completion, PendingCompletion from rigging.error import InvalidModelSpecifiedError from rigging.message import ( Message, MessageDict, ) +if t.TYPE_CHECKING: + from rigging.chat import PendingChat + from rigging.completion import PendingCompletion + # We should probably let people configure # this independently, but for now we'll # fix it to prevent confusion @@ -239,7 +241,7 @@ def chat( self, messages: t.Sequence[MessageDict], params: GenerateParams | None = None, - ) -> PendingChat: + ) -> "PendingChat": ... @t.overload @@ -247,14 +249,14 @@ def chat( self, messages: t.Sequence[Message] | MessageDict | Message | str | None = None, params: GenerateParams | None = None, - ) -> PendingChat: + ) -> "PendingChat": ... def chat( self, messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str | None = None, params: GenerateParams | None = None, - ) -> PendingChat: + ) -> "PendingChat": """ Build a pending chat with the given messages and optional params overloads. @@ -265,11 +267,13 @@ def chat( Returns: Pending chat to run. """ + from rigging.chat import PendingChat + return PendingChat(self, Message.fit_as_list(messages) if messages else [], params) # Helper alternative to complete(generator) -> generator.complete(...) 
- def complete(self, text: str, params: GenerateParams | None = None) -> PendingCompletion: + def complete(self, text: str, params: GenerateParams | None = None) -> "PendingCompletion": """ Build a pending string completion of the given text with optional param overloads. @@ -280,6 +284,8 @@ def complete(self, text: str, params: GenerateParams | None = None) -> PendingCo Returns: The completed text. """ + from rigging.completion import PendingCompletion + return PendingCompletion(self, text, params) @@ -288,7 +294,7 @@ def chat( generator: "Generator", messages: t.Sequence[MessageDict], params: GenerateParams | None = None, -) -> PendingChat: +) -> "PendingChat": ... @@ -297,7 +303,7 @@ def chat( generator: "Generator", messages: t.Sequence[Message] | MessageDict | Message | str | None = None, params: GenerateParams | None = None, -) -> PendingChat: +) -> "PendingChat": ... @@ -305,7 +311,7 @@ def chat( generator: "Generator", messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str | None = None, params: GenerateParams | None = None, -) -> PendingChat: +) -> "PendingChat": """ Creates a pending chat using the given generator, messages, and params. @@ -325,7 +331,7 @@ def complete( generator: Generator, text: str, params: GenerateParams | None = None, -) -> PendingCompletion: +) -> "PendingCompletion": return generator.complete(text, params) @@ -588,9 +594,3 @@ async def agenerate_texts( g_providers["litellm"] = LiteLLMGenerator - -# TODO: This fixes some almost-circular import issues and -# typed forwardrefs we use in the other module - -Chat.model_rebuild() -Completion.model_rebuild() diff --git a/rigging/tool.py b/rigging/tool.py index 99e381c..798905a 100644 --- a/rigging/tool.py +++ b/rigging/tool.py @@ -34,7 +34,8 @@ class ToolCallParameter(Model): attr_value: SUPPORTED_TOOL_ARGUMENT_TYPES | None = attr("value", default=None, exclude=True) text_value: SUPPORTED_TOOL_ARGUMENT_TYPES | None = Field(default=None, exclude=True) - @computed_field + @computed_field # type: ignore[misc] + @property def value(self) -> SUPPORTED_TOOL_ARGUMENT_TYPES: return self.attr_value or self.text_value or "" diff --git a/tests/test_messages.py b/tests/test_chat.py similarity index 68% rename from tests/test_messages.py rename to tests/test_chat.py index 9f24af5..8bd9ae2 100644 --- a/tests/test_messages.py +++ b/tests/test_chat.py @@ -36,6 +36,12 @@ def test_message_from_dict() -> None: assert msg.content == "You are an AI assistant." +def test_message_from_str() -> None: + msg = Message.fit("Please say hello.") + assert msg.role == "user" + assert msg.content == "Please say hello." + + def test_message_str_representation() -> None: msg = Message("assistant", "I am an AI assistant.") assert str(msg) == "[assistant]: I am an AI assistant." 
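The rigging/generator.py hunks above break the import cycle with rigging.chat and rigging.completion using the standard two-step idiom: type-only imports guarded by t.TYPE_CHECKING, plus deferred imports inside the method bodies. With rigging.chat and rigging.completion now importing Generator and GenerateParams eagerly, the Chat.model_rebuild()/Completion.model_rebuild() workaround at the bottom of the module also goes away. Distilled to a sketch (bodies trimmed to the relevant lines):

import typing as t

if t.TYPE_CHECKING:
    # Seen only by the type checker, so this import can never
    # trigger the circular import at runtime.
    from rigging.completion import PendingCompletion

class Generator:
    def complete(self, text: str, params: t.Any = None) -> "PendingCompletion":
        # Deferred runtime import: both modules are fully
        # initialized by the time this method is actually called.
        from rigging.completion import PendingCompletion

        return PendingCompletion(self, text, params)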
@@ -146,6 +152,42 @@ def test_message_reparse_modified_content() -> None: assert person.age == 25 +def test_chat_generator_id() -> None: + generator = get_generator("gpt-3.5") + chat = Chat([], generator=generator) + assert chat.generator_id == "litellm!gpt-3.5" + + other = Chat([]) + assert other.generator_id is None + + +def test_chat_metadata() -> None: + chat = Chat([]).meta(key="value") + assert chat.metadata == {"key": "value"} + + +def test_chat_restart() -> None: + chat = Chat( + [ + Message("user", "Hello"), + Message("assistant", "Hi there!"), + ], + [ + Message("user", "Other Stuff"), + ], + generator=get_generator("gpt-3.5"), + ) + + assert len(chat.restart()) == 2 + assert len(chat.restart(include_all=True)) == 3 + assert len(chat.continue_(Message("user", "User continue (should append)"))) == 3 + assert len(chat.continue_(Message("assistant", "Assistant continue"))) == 4 + + chat.generator = None + with pytest.raises(ValueError): + chat.restart() + + def test_chat_continue() -> None: chat = Chat( [ @@ -163,6 +205,52 @@ def test_chat_continue() -> None: assert continued.all[2].content == "How are you?" +def test_chat_to_message_dicts() -> None: + chat = Chat( + [ + Message("user", "Hello"), + Message("assistant", "Hi there!"), + ], + generator=get_generator("gpt-3.5"), + ) + + assert len(chat.message_dicts) == 2 + assert chat.message_dicts[0] == {"role": "user", "content": "Hello"} + assert chat.message_dicts[1] == {"role": "assistant", "content": "Hi there!"} + + +def test_chat_to_conversation() -> None: + chat = Chat( + [ + Message("user", "Hello"), + Message("assistant", "Hi there!"), + ], + generator=get_generator("gpt-3.5"), + ) + + assert "[user]: Hello" in chat.conversation + assert "[assistant]: Hi there!" in chat.conversation + + +def test_chat_properties() -> None: + user_1 = Message("user", "Hello") + assistant_1 = Message("assistant", "Hi there!") + user_2 = Message("user", "How are you?") + assistant_2 = Message("assistant", "I'm doing well, thank you!") + + chat = Chat( + [ + user_1, + ], + [assistant_1, user_2, assistant_2], + ) + + assert chat.prev == [user_1] + assert chat.next == [assistant_1, user_2, assistant_2] + assert chat.all == [user_1, assistant_1, user_2, assistant_2] + assert chat.last == assistant_2 + + def test_pending_chat_continue() -> None: pending = PendingChat(get_generator("gpt-3.5"), [], GenerateParams()) continued = pending.fork([Message("user", "Hello")]) @@ -205,6 +293,27 @@ def test_chat_continue_maintains_parsed_models() -> None: assert len(continued.all[2].parts) == 0 +def test_pending_chat_meta() -> None: + pending = PendingChat(get_generator("gpt-3.5"), [Message("user", "Hello")]) + with_meta = pending.meta(key="value") + assert with_meta == pending + assert with_meta.metadata == {"key": "value"} + + +def test_pending_chat_with() -> None: + pending = PendingChat(get_generator("gpt-3.5"), [Message("user", "Hello")]) + with_pending = pending.with_(GenerateParams(max_tokens=123)) + assert with_pending == pending + assert with_pending.params is not None + assert with_pending.params.max_tokens == 123 + + with_pending_2 = with_pending.with_(GenerateParams(top_p=0.5)) + assert with_pending_2 != with_pending + assert with_pending_2.params is not None + assert with_pending_2.params.max_tokens == 123 + assert with_pending_2.params.top_p == 0.5 + + def test_chat_strip() -> None: chat = Chat( [ diff --git a/tests/test_completion.py b/tests/test_completion.py new file mode 100644 index 0000000..ffe2969 --- /dev/null +++ 
b/tests/test_completion.py @@ -0,0 +1,89 @@ +import pytest + +from rigging.completion import Completion, PendingCompletion +from rigging.generator import GenerateParams, get_generator + + +def test_completion_generator_id() -> None: + generator = get_generator("gpt-3.5") + completion = Completion("foo", "bar", generator) + assert completion.generator_id == "litellm!gpt-3.5" + + completion.generator = None + assert completion.generator_id is None + + +def test_completion_properties() -> None: + generator = get_generator("gpt-3.5") + completion = Completion("foo", "bar", generator) + assert completion.text == "foo" + assert completion.generated == "bar" + assert completion.generator == generator + assert len(completion) == len("foo") + len("bar") + assert completion.all == "foobar" + + +def test_completion_restart() -> None: + generator = get_generator("gpt-3.5") + completion = Completion("foo", "bar", generator) + assert len(completion.restart()) == 3 + assert len(completion.restart(include_all=True)) == 6 + + assert len(completion.fork("baz")) == 6 + assert len(completion.continue_("baz")) == 9 + + completion.generator = None + with pytest.raises(ValueError): + completion.restart() + + +def test_completion_clone() -> None: + generator = get_generator("gpt-3.5") + original = Completion("foo", "bar", generator).meta(key="value") + clone = original.clone() + assert clone.text == original.text + assert clone.generated == original.generated + assert clone.metadata == original.metadata + + clone_2 = original.clone(only_messages=True) + assert clone.metadata != clone_2.metadata + + +def test_pending_completion_with() -> None: + pending = PendingCompletion(get_generator("gpt-3.5"), "foo") + with_pending = pending.with_(GenerateParams(max_tokens=123)) + assert with_pending == pending + assert with_pending.params is not None + assert with_pending.params.max_tokens == 123 + + with_pending_2 = with_pending.with_(top_p=0.5) + assert with_pending_2 != with_pending + assert with_pending_2.params is not None + assert with_pending_2.params.max_tokens == 123 + assert with_pending_2.params.top_p == 0.5 + + +def test_pending_completion_fork() -> None: + pending = PendingCompletion(get_generator("gpt-3.5"), "foo") + forked_1 = pending.fork("bar") + forked_2 = pending.fork("baz") + + assert pending != forked_1 != forked_2 + assert pending.text == "foo" + assert forked_1.text == "foobar" + assert forked_2.text == "foobaz" + + +def test_pending_completion_meta() -> None: + pending = PendingCompletion(get_generator("gpt-3.5"), "foo") + with_meta = pending.meta(key="value") + assert with_meta == pending + assert with_meta.metadata == {"key": "value"} + + +def test_pending_completion_apply() -> None: + pending = PendingCompletion(get_generator("gpt-3.5"), "Hello $name") + applied = pending.apply(name="World", noexist="123") + assert pending != applied + assert pending.text == "Hello $name" + assert applied.text == "Hello World" diff --git a/tests/test_generation.py b/tests/test_generation.py index 3bc417f..75e2f11 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -5,6 +5,7 @@ from rigging import Message from rigging.generator import GenerateParams, Generator from rigging.model import YesNoAnswer +from rigging.parsing import try_parse class EchoGenerator(Generator): @@ -21,9 +22,19 @@ def generate_messages( assert len(messages) == 1 return [Message(role="assistant", content=messages[-1][-1].content) for m in messages] + def generate_texts( + self, texts: t.Sequence[str], params: 
t.Sequence[GenerateParams], *, prefix: str | None = None + ) -> t.Sequence[str]: + if prefix is not None: + texts = [t + prefix for t in texts] + + assert len(texts) == 1 + return [texts[-1]] + class CallbackGenerator(Generator): - callback: t.Callable[["CallbackGenerator", t.Sequence[Message]], str] | None = None + message_callback: t.Callable[["CallbackGenerator", t.Sequence[Message]], str] | None = None + text_callback: t.Callable[["CallbackGenerator", str], str] | None = None def generate_messages( self, @@ -36,11 +47,21 @@ def generate_messages( messages = [list(prefix) + list(m) for m in messages] assert len(messages) == 1 - assert self.callback is not None - return [Message(role="assistant", content=self.callback(self, m)) for m in messages] + assert self.message_callback is not None + return [Message(role="assistant", content=self.message_callback(self, m)) for m in messages] + def generate_texts( + self, texts: t.Sequence[str], params: t.Sequence[GenerateParams], *, prefix: str | None = None + ) -> t.Sequence[str]: + if prefix is not None: + texts = [prefix + t for t in texts] -def test_until_parsed_as_with_reset() -> None: + assert len(texts) == 1 + assert self.text_callback is not None + return [self.text_callback(self, text) for text in texts] + + +def test_chat_until_parsed_as_with_reset() -> None: generator = CallbackGenerator(model="callback", params=GenerateParams()) def valid_cb(self: CallbackGenerator, messages: t.Sequence[Message]) -> str: @@ -49,17 +70,17 @@ def valid_cb(self: CallbackGenerator, messages: t.Sequence[Message]) -> str: return "yes" def invalid_cb(self: CallbackGenerator, messages: t.Sequence[Message]) -> str: - self.callback = valid_cb + self.message_callback = valid_cb return "dropped" - generator.callback = invalid_cb + generator.message_callback = invalid_cb chat = generator.chat([{"role": "user", "content": "original"}]).until_parsed_as(YesNoAnswer).run() assert len(chat) == 2 assert chat.last.try_parse(YesNoAnswer) is not None @pytest.mark.parametrize("drop_dialog", [True, False]) -def test_until_parsed_as_with_recovery(drop_dialog: bool) -> None: +def test_chat_until_parsed_as_with_recovery(drop_dialog: bool) -> None: generator = CallbackGenerator(model="callback", params=GenerateParams()) def valid_cb(self: CallbackGenerator, messages: t.Sequence[Message]) -> str: @@ -74,16 +95,16 @@ def invalid_cb_2(self: CallbackGenerator, messages: t.Sequence[Message]) -> str: assert messages[0].content == "original" assert messages[1].content == "invalid1" assert "" in messages[2].content - self.callback = valid_cb + self.message_callback = valid_cb return "invalid2" def invalid_cb_1(self: CallbackGenerator, messages: t.Sequence[Message]) -> str: assert len(messages) == 1 assert messages[0].content == "original" - self.callback = invalid_cb_2 + self.message_callback = invalid_cb_2 return "invalid1" - generator.callback = invalid_cb_1 + generator.message_callback = invalid_cb_1 chat = ( generator.chat([{"role": "user", "content": "original"}]) .until_parsed_as(YesNoAnswer, attempt_recovery=True, drop_dialog=drop_dialog) @@ -92,3 +113,19 @@ def invalid_cb_1(self: CallbackGenerator, messages: t.Sequence[Message]) -> str: assert len(chat) == (2 if drop_dialog else 6) assert chat.last.try_parse(YesNoAnswer) is not None + + +def test_completion_until_parsed_as_with_reset() -> None: + generator = CallbackGenerator(model="callback", params=GenerateParams()) + + def valid_cb(self: CallbackGenerator, text: str) -> str: + assert text == "original" + return "yes" + + def 
invalid_cb(self: CallbackGenerator, text: str) -> str: + self.text_callback = valid_cb + return "dropped" + + generator.text_callback = invalid_cb + completion = generator.complete("original").until_parsed_as(YesNoAnswer).run() + assert try_parse(completion.generated, YesNoAnswer) is not None diff --git a/tests/test_generator_creation.py b/tests/test_generator_ids.py similarity index 100% rename from tests/test_generator_creation.py rename to tests/test_generator_ids.py diff --git a/tests/test_xml_parsing.py b/tests/test_xml_parsing.py index 253d131..f62f319 100644 --- a/tests/test_xml_parsing.py +++ b/tests/test_xml_parsing.py @@ -188,6 +188,6 @@ def test_xml_parsing_with_validation(content: str, model: Model, expectation: t. ), ], ) -def text_xml_parsing_sets(content: str, count: int, model: Model) -> None: +def test_xml_parsing_sets(content: str, count: int, model: Model) -> None: models = model.from_text(content) # type: ignore [var-annotated] assert len(models) == count, "Failed to parse model set"
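Together with the rigging/chat.py changes above, batch runs can now be driven from bare strings and chunked: run_batch/arun_batch accept a single str or MessageDict, generate at most size states per underlying generate call, and with skip_failed=True drop conversations that exhaust their rounds instead of raising. A hedged usage sketch (model id illustrative; assumes a configured LiteLLM backend):

from rigging.generator import get_generator

pending = get_generator("gpt-3.5").chat()
chats = pending.run_batch(
    ["First prompt", "Second prompt", "Third prompt"],  # strings are fit as user messages
    size=2,            # chunk generation into at most two states per call
    skip_failed=True,  # drop chats that raise ExhaustedMaxRoundsError
)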