
Commit

Merge pull request #9 from parkervg/model-caching-fix
Caching Fix
parkervg authored Mar 14, 2024
2 parents 1a5fa3c + 592c1d2 commit a67bf7d
Showing 3 changed files with 202 additions and 21 deletions.
63 changes: 49 additions & 14 deletions blendsql/_program.py
@@ -7,6 +7,19 @@
from guidance import user, system, assistant
from contextlib import nullcontext
import inspect
import ast
import textwrap


def get_contexts(model: Model):
    usercontext = nullcontext()
    systemcontext = nullcontext()
    assistantcontext = nullcontext()
    if isinstance(model, Chat):
        usercontext = user()
        systemcontext = system()
        assistantcontext = assistant()
    return (usercontext, systemcontext, assistantcontext)


class Program:
@@ -27,22 +40,44 @@ def __new__(
            self.usercontext,
            self.systemcontext,
            self.assistantcontext,
        ) = self._get_contexts(self.model)
        ) = get_contexts(self.model)
        return self.__call__(self, **kwargs)

    def __call__(self, *args, **kwargs):
        pass

    def __str__(self):
        return inspect.getsource(self.__call__)

    @staticmethod
    def _get_contexts(model: Model):
        usercontext = nullcontext()
        systemcontext = nullcontext()
        assistantcontext = nullcontext()
        if isinstance(model, Chat):
            usercontext = user()
            systemcontext = system()
            assistantcontext = assistant()
        return (usercontext, systemcontext, assistantcontext)

def program_to_str(program: Program):
    """Create a string representation of a program.
    It is slightly tricky, since in addition to getting the code content, we need to
        1) identify all global variables referenced within a function, and then
        2) evaluate the variable value
    This is required, since if we have some global constant `PROMPT` called,
    we don't want to fetch from a previously created cache if the value of `PROMPT` changes.
    To avoid extreme messiness, we don't traverse into globals pointing at functions.
    Example:
        >>> PROMPT = "Here is my question: {question}"
        >>> class CorrectionProgram(Program):
        >>>     def __call__(self, question: str, **kwargs):
        >>>         return self.model + PROMPT.format(question)
    Some helpful refs:
        - https://github.com/universe-proton/universe-topology/issues/15
    """
    call_content = textwrap.dedent(inspect.getsource(program.__call__))
    root = ast.parse(call_content)
    root_names = {node.id for node in ast.walk(root) if isinstance(node, ast.Name)}
    co_varnames = set(program.__call__.__code__.co_varnames)
    names_to_resolve = sorted(root_names.difference(co_varnames))
    resolved_names = ""
    if len(names_to_resolve) > 0:
        globals_as_dict = dict(inspect.getmembers(program.__call__))["__globals__"]
        for name in names_to_resolve:
            if name in globals_as_dict:
                val = globals_as_dict[name]
                # Ignore functions
                if not callable(val):
                    resolved_names += f"{val}\n"
    return f"{call_content}{resolved_names}"
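
A minimal standalone sketch of the free-name resolution that program_to_str performs (not part of this commit; the PROMPT constant and render function below are hypothetical): names that appear in a function body but are not among the function's own variables get looked up in its globals, so their current values become part of the string that is later hashed into the cache key.

import ast
import inspect
import textwrap

PROMPT = "Here is my question: {question}"  # hypothetical global referenced below


def render(question: str):
    return PROMPT.format(question=question)


# Same idea as program_to_str: parse the source, collect every ast.Name,
# subtract the function's own variables, then resolve the remainder via __globals__.
source = textwrap.dedent(inspect.getsource(render))
names = {node.id for node in ast.walk(ast.parse(source)) if isinstance(node, ast.Name)}
free = sorted(names - set(render.__code__.co_varnames))
resolved = "".join(
    f"{render.__globals__[n]}\n"
    for n in free
    if n in render.__globals__ and not callable(render.__globals__[n])
)
print(source + resolved)  # changes whenever PROMPT changes, yielding a new cache key
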
21 changes: 14 additions & 7 deletions blendsql/models/_model.py
@@ -13,7 +13,7 @@
import hashlib
from abc import abstractmethod

from blendsql._program import Program
from blendsql._program import Program, program_to_str


class TokenTimer(threading.Thread):
@@ -42,6 +42,7 @@ class Model:
    requires_config: bool = attrib(default=False)
    refresh_interval_min: int = attrib(default=None)
    env: str = attrib(default=".")
    caching: bool = attrib(default=True)

    model: guidance.models.Model = attrib(init=False)
    prompts: list = attrib(init=False)
@@ -95,10 +96,11 @@ def predict(self, program: Program, **kwargs) -> dict:
        >>> llm.predict(program, **kwargs)
        {"result": '"This is Model generated output"'}
        """
        # First, check our cache
        key = self._create_key(program, **kwargs)
        if key in self.cache:
            return self.cache.get(key)
        if self.caching:
            # First, check our cache
            key = self._create_key(program, **kwargs)
            if key in self.cache:
                return self.cache.get(key)
        # Modify fields used for tracking Model usage
        self.num_llm_calls += 1
        model = program(model=self.model, **kwargs)
@@ -109,7 +111,8 @@ def predict(self, program: Program, **kwargs) -> dict:
        prompt = re.sub(r"\<.*?\>", "", prompt)
        self.num_prompt_tokens += len(self.tokenizer.encode(prompt))
        self.prompts.append(prompt)
        self.cache[key] = model._variables
        if self.caching:
            self.cache[key] = model._variables
        return model._variables

    def _create_key(self, program: Program, **kwargs) -> str:
@@ -131,7 +134,11 @@ def _create_key(self, program: Program, **kwargs) -> str:
                ]
            )
        )
        combined = "{}{}".format(str(program), options_str).encode()
        combined = "{}||{}||{}".format(
            f"{self.model_name_or_path}||{type(self)}",
            program_to_str(program),
            options_str,
        ).encode()
        hasher.update(combined)
        return hasher.hexdigest()

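A short sketch of the shape of the new composite cache key (illustration only: the hash algorithm and the kwargs filtering are not visible in this hunk, so hashlib.sha256 and the callable filter below are assumptions). Model identity, program source plus resolved globals, and the sorted call options are concatenated and hashed, so changing any one of them misses the old cache entry.

import hashlib


def sketch_key(model_name: str, model_cls: type, program_str: str, **kwargs) -> str:
    # Rough stand-in for Model._create_key, not the exact implementation.
    options_str = str(sorted((k, v) for k, v in kwargs.items() if not callable(v)))
    combined = "{}||{}||{}".format(
        f"{model_name}||{model_cls}", program_str, options_str
    ).encode()
    hasher = hashlib.sha256()
    hasher.update(combined)
    return hasher.hexdigest()


# Any change to the model identity, the program text, or the kwargs yields a new key.
assert sketch_key("a", object, "source", question="q") != sketch_key(
    "a", object, "source", question="r"
)
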
139 changes: 139 additions & 0 deletions tests/test_model.py
@@ -0,0 +1,139 @@
import uuid
from dataclasses import dataclass
from blendsql.models import Model
from blendsql._program import Program

TEST_QUESTION = "The quick brown fox jumps over the lazy dog"

MODEL_A = "a"
MODEL_B = "b"


@dataclass
class DummyModelOutput:
    _variables: dict


class DummyModel(Model):
    def __init__(self, model_name_or_path: str, **kwargs):
        super().__init__(
            model_name_or_path=model_name_or_path,
            requires_config=False,
            tokenizer=None,
            **kwargs,
        )

    def _load_model(self):
        return self.model_name_or_path


class DummyProgram(Program):
    def __new__(
        self,
        **kwargs,
    ):
        return self.__call__(self, **kwargs)

    def __call__(self, question: str, **kwargs):
        return DummyModelOutput({"uuid": str(uuid.uuid4())})


class DifferentDummyProgram(Program):
    def __new__(
        self,
        **kwargs,
    ):
        return self.__call__(self, **kwargs)

    def __call__(self, question: str, unused: str = None, **kwargs):
        return DummyModelOutput({"uuid": str(uuid.uuid4())})


class DummyProgramWithGlobal(Program):
    def __new__(
        self,
        **kwargs,
    ):
        return self.__call__(self, **kwargs)

    def __call__(self, question: str, **kwargs):
        print(TEST_GLOBAL)
        return DummyModelOutput({"uuid": str(uuid.uuid4())})


def test_simple_cache():
    a = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
        "uuid"
    ]

    b = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
        "uuid"
    ]

    assert a == b


def test_different_models():
    a = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
        "uuid"
    ]

    b = DummyModel(MODEL_B).predict(program=DummyProgram, question=TEST_QUESTION)[
        "uuid"
    ]

    assert a != b


def test_different_arguments():
    a = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
        "uuid"
    ]

    b = DummyModel(MODEL_A).predict(
        program=DummyProgram, question="This is a different question"
    )["uuid"]

    assert a != b


def test_different_programs():
    a = DummyModel(MODEL_A).predict(program=DummyProgram, question=TEST_QUESTION)[
        "uuid"
    ]

    b = DummyModel(MODEL_A).predict(
        program=DifferentDummyProgram, question=TEST_QUESTION
    )["uuid"]

    assert a != b


def test_same_global_vars():
    global TEST_GLOBAL
    TEST_GLOBAL = "This is the same value"
    a = DummyModel(MODEL_A).predict(
        program=DummyProgramWithGlobal, question=TEST_QUESTION
    )["uuid"]

    TEST_GLOBAL = "This is the same value"
    b = DummyModel(MODEL_A).predict(
        program=DummyProgramWithGlobal, question=TEST_QUESTION
    )["uuid"]

    assert a == b


def test_different_global_vars():
    global TEST_GLOBAL
    TEST_GLOBAL = "This is one value"
    a = DummyModel(MODEL_A).predict(
        program=DummyProgramWithGlobal, question=TEST_QUESTION
    )["uuid"]

    TEST_GLOBAL = "This is a different value"
    b = DummyModel(MODEL_A).predict(
        program=DummyProgramWithGlobal, question=TEST_QUESTION
    )["uuid"]

    assert a != b
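
A hypothetical extra test (not in this commit) sketching the new caching flag: DummyModel forwards **kwargs to Model.__init__, which now defines caching, so passing caching=False makes predict() skip both the cache lookup and the cache write, and two identical calls re-run the program.


def test_caching_disabled():
    a = DummyModel(MODEL_A, caching=False).predict(
        program=DummyProgram, question=TEST_QUESTION
    )["uuid"]

    b = DummyModel(MODEL_A, caching=False).predict(
        program=DummyProgram, question=TEST_QUESTION
    )["uuid"]

    assert a != b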
