From d7deeea981dc9d423517ac906e69fdd98eb0d036 Mon Sep 17 00:00:00 2001
From: parkervg
Date: Thu, 14 Mar 2024 10:40:30 -0400
Subject: [PATCH 1/3] Adding caching fixes

Now, we can specify `caching=False` to disable caching. Additionally, the
values of any global variables referenced by a program are factored into
its cache key.
---
 blendsql/_program.py      | 62 +++++++++++++++++++++++++++++++--------
 blendsql/models/_model.py | 21 ++++++++-----
 2 files changed, 64 insertions(+), 19 deletions(-)

diff --git a/blendsql/_program.py b/blendsql/_program.py
index 289f1e7a..366e3b03 100644
--- a/blendsql/_program.py
+++ b/blendsql/_program.py
@@ -7,6 +7,19 @@
 from guidance import user, system, assistant
 from contextlib import nullcontext
 import inspect
+import ast
+import textwrap
+
+
+def get_contexts(model: Model):
+    usercontext = nullcontext()
+    systemcontext = nullcontext()
+    assistantcontext = nullcontext()
+    if isinstance(model, Chat):
+        usercontext = user()
+        systemcontext = system()
+        assistantcontext = assistant()
+    return (usercontext, systemcontext, assistantcontext)
 
 
 class Program:
@@ -27,22 +40,47 @@ def __new__(
             self.usercontext,
             self.systemcontext,
             self.assistantcontext,
-        ) = self._get_contexts(self.model)
+        ) = get_contexts(self.model)
         return self.__call__(self, **kwargs)
 
     def __call__(self, *args, **kwargs):
         pass
 
-    def __str__(self):
+    def to_string(self) -> str:
         return inspect.getsource(self.__call__)
 
-    @staticmethod
-    def _get_contexts(model: Model):
-        usercontext = nullcontext()
-        systemcontext = nullcontext()
-        assistantcontext = nullcontext()
-        if isinstance(model, Chat):
-            usercontext = user()
-            systemcontext = system()
-            assistantcontext = assistant()
-        return (usercontext, systemcontext, assistantcontext)
+
+def program_to_str(program: Program):
+    """Create a string representation of a program.
+    This is slightly tricky: in addition to getting the code content, we need to
+    1) identify all global variables referenced within the function, and then
+    2) evaluate each variable's value.
+    This is required because, if a program references some global constant `PROMPT`,
+    we don't want to fetch a previously cached result after the value of `PROMPT` changes.
+
+    To avoid extreme messiness, we don't traverse into globals pointing at functions.
+
+    Example:
+        >>> PROMPT = "Here is my question: {question}"
+        >>> class CorrectionProgram(Program):
+        >>>     def __call__(self, question: str, **kwargs):
+        >>>         return self.model + PROMPT.format(question=question)
+
+    Some helpful refs:
+        - https://github.com/universe-proton/universe-topology/issues/15
+    """
+    call_content = textwrap.dedent(inspect.getsource(program.__call__))
+    root = ast.parse(call_content)
+    root_names = {node.id for node in ast.walk(root) if isinstance(node, ast.Name)}
+    co_varnames = set(program.__call__.__code__.co_varnames)
+    names_to_resolve = sorted(root_names.difference(co_varnames))
+    resolved_names = ""
+    if len(names_to_resolve) > 0:
+        globals_as_dict = dict(inspect.getmembers(program.__call__))["__globals__"]
+        for name in names_to_resolve:
+            if name in globals_as_dict:
+                val = globals_as_dict[name]
+                # Ignore functions
+                if not callable(val):
+                    resolved_names += f"{val}\n"
+    return f"{call_content}{resolved_names}"
diff --git a/blendsql/models/_model.py b/blendsql/models/_model.py
index d56bf73b..abe2185b 100644
--- a/blendsql/models/_model.py
+++ b/blendsql/models/_model.py
@@ -13,7 +13,7 @@
 import hashlib
 from abc import abstractmethod
 
-from blendsql._program import Program
+from blendsql._program import Program, program_to_str
 
 
 class TokenTimer(threading.Thread):
@@ -42,6 +42,7 @@ class Model:
     requires_config: bool = attrib(default=False)
     refresh_interval_min: int = attrib(default=None)
     env: str = attrib(default=".")
+    caching: bool = attrib(default=True)
 
     model: guidance.models.Model = attrib(init=False)
     prompts: list = attrib(init=False)
@@ -95,10 +96,11 @@ def predict(self, program: Program, **kwargs) -> dict:
             >>> llm.predict(program, **kwargs)
             {"result": '"This is Model generated output"'}
         """
-        # First, check our cache
-        key = self._create_key(program, **kwargs)
-        if key in self.cache:
-            return self.cache.get(key)
+        if self.caching:
+            # First, check our cache
+            key = self._create_key(program, **kwargs)
+            if key in self.cache:
+                return self.cache.get(key)
         # Modify fields used for tracking Model usage
         self.num_llm_calls += 1
         model = program(model=self.model, **kwargs)
@@ -109,7 +111,8 @@ def predict(self, program: Program, **kwargs) -> dict:
             prompt = re.sub(r"\<.*?\>", "", prompt)
             self.num_prompt_tokens += len(self.tokenizer.encode(prompt))
             self.prompts.append(prompt)
-        self.cache[key] = model._variables
+        if self.caching:
+            self.cache[key] = model._variables
         return model._variables
 
     def _create_key(self, program: Program, **kwargs) -> str:
@@ -131,7 +134,11 @@ def _create_key(self, program: Program, **kwargs) -> str:
                 ]
             )
         )
-        combined = "{}{}".format(str(program), options_str).encode()
+        combined = "{}||{}||{}".format(
+            f"{self.model_name_or_path}||{type(self)}",
+            program_to_str(program),
+            options_str,
+        ).encode()
         hasher.update(combined)
         return hasher.hexdigest()
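For reference, a minimal sketch of what the new `program_to_str` captures once this patch is applied. `PROMPT` and `CorrectionProgram` mirror the docstring example above and are purely illustrative; this assumes `blendsql` with this patch is importable:

    from blendsql._program import Program, program_to_str

    PROMPT = "Here is my question: {question}"

    class CorrectionProgram(Program):
        def __call__(self, question: str, **kwargs):
            return self.model + PROMPT.format(question=question)

    # The returned string embeds both the source of __call__ and the current
    # value of PROMPT, so editing PROMPT later yields a different cache key.
    print(program_to_str(CorrectionProgram))
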
From 9d15c5770bdc4475a11d5c42dd494d1b7f8560e5 Mon Sep 17 00:00:00 2001
From: parkervg
Date: Thu, 14 Mar 2024 10:59:25 -0400
Subject: [PATCH 2/3] Adding tests for d7deeea

---
 tests/test_model.py | 139 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 139 insertions(+)
 create mode 100644 tests/test_model.py

diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 00000000..7b3433be
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,139 @@
+import uuid
+from dataclasses import dataclass
+from blendsql.models import Model
+from blendsql._program import Program
+
+TEST_QUESTION = "The quick brown fox jumps over the lazy dog"
+
+MODEL_A = "a"
+MODEL_B = "b"
+
+
+@dataclass
+class DummyModelOutput:
+    _variables: dict
+
+
+class DummyModel(Model):
+    def __init__(self, model_name_or_path: str, **kwargs):
+        super().__init__(
+            model_name_or_path=model_name_or_path,
+            requires_config=False,
+            tokenizer=None,
+            **kwargs,
+        )
+
+    def _load_model(self):
+        return self.model_name_or_path
+
+
+class DummyProgram(Program):
+    def __new__(
+        self,
+        **kwargs,
+    ):
+        return self.__call__(self, **kwargs)
+
+    def __call__(self, question: str, **kwargs):
+        return DummyModelOutput({"uuid": str(uuid.uuid4())})
+
+
+class DifferentDummyProgram(Program):
+    def __new__(
+        self,
+        **kwargs,
+    ):
+        return self.__call__(self, **kwargs)
+
+    def __call__(self, question: str, unused: str = None, **kwargs):
+        return DummyModelOutput({"uuid": str(uuid.uuid4())})
+
+
+class DummyProgramWithGlobal(Program):
+    def __new__(
+        self,
+        **kwargs,
+    ):
+        return self.__call__(self, **kwargs)
+
+    def __call__(self, question: str, **kwargs):
+        print(TEST_GLOBAL)
+        return DummyModelOutput({"uuid": str(uuid.uuid4())})
+
+
+def test_simple_cache():
+    a = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
+        "uuid"
+    ]
+
+    b = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
+        "uuid"
+    ]
+
+    assert a == b
+
+
+def test_different_models():
+    a = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
+        "uuid"
+    ]
+
+    b = DummyModel(MODEL_B).predict(program=DummyProgram, question=TEST_QUESTION)[
+        "uuid"
+    ]
+
+    assert a != b
+
+
+def test_different_arguments():
+    a = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
+        "uuid"
+    ]
+
+    b = DummyModel(MODEL_A).predict(
+        program=DummyProgram, question="This is a different question"
+    )["uuid"]
+
+    assert a != b
+
+
+def test_different_programs():
+    a = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
+        "uuid"
+    ]
+
+    b = DummyModel(MODEL_A).predict(
+        program=DifferentDummyProgram, question=TEST_QUESTION
+    )["uuid"]
+
+    assert a != b
+
+
+def test_same_global_vars():
+    global TEST_GLOBAL
+    TEST_GLOBAL = "This is the same value"
+    a = DummyModel(MODEL_A).predict(
+        program=DummyProgramWithGlobal, question=TEST_QUESTION
+    )["uuid"]
+
+    TEST_GLOBAL = "This is the same value"
+    b = DummyModel(MODEL_A).predict(
+        program=DummyProgramWithGlobal, question=TEST_QUESTION
+    )["uuid"]
+
+    assert a == b
+
+
+def test_different_global_vars():
+    global TEST_GLOBAL
+    TEST_GLOBAL = "This is one value"
+    a = DummyModel(MODEL_A).predict(
+        program=DummyProgramWithGlobal, question=TEST_QUESTION
+    )["uuid"]
+
+    TEST_GLOBAL = "This is a different value"
+    b = DummyModel(MODEL_A).predict(
+        program=DummyProgramWithGlobal, question=TEST_QUESTION
+    )["uuid"]
+
+    assert a != b
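As a rough companion to these tests, the key derivation can also be inspected directly. A sketch reusing the `DummyModel`/`DummyProgram`/`DifferentDummyProgram` classes from the test file above; `_create_key` is the private helper patched in `_model.py`, so this is illustrative rather than public API:

    model = DummyModel(MODEL_A)
    k1 = model._create_key(DummyProgram, question=TEST_QUESTION)
    k2 = DummyModel(MODEL_A)._create_key(DummyProgram, question=TEST_QUESTION)
    # Same model name, program source, and kwargs -> same key, even across instances
    assert k1 == k2

    # Different kwargs or a different __call__ body -> different key
    assert k1 != model._create_key(DummyProgram, question="A different question")
    assert k1 != model._create_key(DifferentDummyProgram, question=TEST_QUESTION)
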
From 592c1d20ec454bbee7e43c277a9cc8bb99cccabe Mon Sep 17 00:00:00 2001
From: parkervg
Date: Thu, 14 Mar 2024 11:06:37 -0400
Subject: [PATCH 3/3] Removing unnecessary `program.to_string()` function

---
 blendsql/_program.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/blendsql/_program.py b/blendsql/_program.py
index 366e3b03..8373256a 100644
--- a/blendsql/_program.py
+++ b/blendsql/_program.py
@@ -46,9 +46,6 @@ def __new__(
     def __call__(self, *args, **kwargs):
         pass
 
-    def to_string(self) -> str:
-        return inspect.getsource(self.__call__)
-
 
 def program_to_str(program: Program):
     """Create a string representation of a program.
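Finally, a usage sketch of the `caching` flag introduced in the first patch, again reusing the test file's dummy classes (with a real model, `caching=False` forces a fresh LLM call on every `predict`):

    model = DummyModel(MODEL_A, caching=False)
    a = model.predict(program=DummyProgram, question=TEST_QUESTION)["uuid"]
    b = model.predict(program=DummyProgram, question=TEST_QUESTION)["uuid"]
    # Nothing is read from or written to the cache, so each call re-runs the program
    assert a != b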