
feat: request contextualisation - core functionality #65

Open · wants to merge 57 commits into base: main

Commits (57)
21506e6
context logic subpackage; type-hint context extraction
Jun 21, 2024
a87e8e2
reworked type hint info extraction; extended functionality to also re…
ds-jakub-cierocki Jun 24, 2024
3ad4ecd
hidden args handling enabled
ds-jakub-cierocki Jun 24, 2024
b0cc0ae
improved type hints parsing and compatibility using package
ds-jakub-cierocki Jun 28, 2024
4ff5f62
dedicated exceptions for context-related operations
ds-jakub-cierocki Jun 28, 2024
c479c50
useful classmethods for context-related operations
ds-jakub-cierocki Jun 28, 2024
e3bb127
make whole context utils module protected; added IQL parsing helper; …
ds-jakub-cierocki Jun 28, 2024
de72c7c
parsing type hints _extract_params_and_context() no longer excludes B…
ds-jakub-cierocki Jun 28, 2024
d3958c0
adjusted the existing code to be aware of contexts (prompts yet untouc…
ds-jakub-cierocki Jun 28, 2024
be338bf
adjusted _type_validators.validate_arg_type() to handle typing.Union[]
ds-jakub-cierocki Jul 2, 2024
78f1535
context._utils._does_arg_allow_context() fix
ds-jakub-cierocki Jul 2, 2024
308e2e1
context record is now based on pydantic.BaseModel rather than datacla…
ds-jakub-cierocki Jul 2, 2024
73741d9
type hint lifting
ds-jakub-cierocki Jul 2, 2024
902f5ff
IQL generating LLM prompt passes BaseCallerContext() as filter argume…
ds-jakub-cierocki Jul 2, 2024
6309070
comments cleanup
ds-jakub-cierocki Jul 2, 2024
d523bf7
type hint fixes
ds-jakub-cierocki Jul 3, 2024
efe212f
Merge branch 'main' (which includes a large refactor by Michal) into …
ds-jakub-cierocki Jul 3, 2024
9ba89e5
post-merge fixes + minor refactor
ds-jakub-cierocki Jul 3, 2024
5fd802f
added missing docstrings; fixed type hints; fixed issues detected by …
ds-jakub-cierocki Jul 4, 2024
09bac55
reworked parse_param_type() function to increase performance, general…
ds-jakub-cierocki Jul 4, 2024
d42a369
fix: removed duplicated line from the prompt template
ds-jakub-cierocki Jul 4, 2024
c0b0522
adjusted existing unit tests to work with new contextualization logic
ds-jakub-cierocki Jul 4, 2024
9b2e131
linter-recommended fixes
ds-jakub-cierocki Jul 4, 2024
2d0ef4b
contextualization mechanism - dedicated unit tests
ds-jakub-cierocki Jul 5, 2024
6466f61
cleaned up overengineered code remaining from the previous iteration…
ds-jakub-cierocki Jul 5, 2024
637f7fa
replaced pydantic.BaseModel by dataclasses.dataclass, pydantic no lon…
ds-jakub-cierocki Jul 8, 2024
f867e25
BaseCallerContext: dataclass w.o. fields -> interface (abstract class…
ds-jakub-cierocki Jul 8, 2024
3423033
LLM now pastes Context() instead of BaseCallerContext() to indicate t…
ds-jakub-cierocki Jul 8, 2024
0d8cd1e
docstring typo fixes; more precise return type hint
ds-jakub-cierocki Jul 9, 2024
c97ba15
renamed Context() -> AskerContext(); added more detailed exa…
ds-jakub-cierocki Jul 9, 2024
1294a9c
type hint parsing changes: SomeCustomContext -> AskerContext; Union[a…
ds-jakub-cierocki Jul 9, 2024
999759b
refactor: collection.results.[ViewExecutionResult, ExecutionResult]."…
ds-jakub-cierocki Jul 12, 2024
2e1005a
param type parsing: correctly handling builtins types with args (e.g.…
ds-jakub-cierocki Jul 12, 2024
820066d
type hint fix: explicitly marked BaseCallerContext.alias as typing.Cla…
ds-jakub-cierocki Jul 12, 2024
25fbfa6
docs + benchmarks adjusted to meet new naming [ExecutionResult, ViewE…
ds-jakub-cierocki Jul 15, 2024
a154577
redesigned context-not-available error to follow the same principles …
ds-jakub-cierocki Jul 15, 2024
623effd
EXPERIMENTAL: reworked context injection such it is handled immediate…
ds-jakub-cierocki Jul 15, 2024
afacf5b
additional unit tests for the new contextualization mechanism
ds-jakub-cierocki Jul 19, 2024
dd8b339
context benchmark script and data
ds-jakub-cierocki Jul 22, 2024
6bb0816
refactored main prompt (too long lines), missing end-of-line characters
ds-jakub-cierocki Jul 22, 2024
f388f92
better error handling
ds-jakub-cierocki Jul 22, 2024
fbecc51
context benchmark dataset fix
ds-jakub-cierocki Jul 23, 2024
5d4ff64
added polars-based accuracy summary to the benchmark
ds-jakub-cierocki Jul 23, 2024
e7e8826
adjusted prompt to reduce hallucinations: nested filter/context calls …
ds-jakub-cierocki Jul 23, 2024
f8bf64e
merged main (inc. new benchmarks + large refactor) -> jc/issue-54-req…
ds-jakub-cierocki Aug 7, 2024
c1c871b
merge main
micpst Sep 23, 2024
8eefd9b
fix linters
micpst Sep 23, 2024
c28091f
fix tests
micpst Sep 23, 2024
69a8d58
fix tests
micpst Sep 23, 2024
d6c8fc6
fix tests
micpst Sep 23, 2024
d7026d4
rm old benchmarks
micpst Sep 23, 2024
e8271ac
some renames and stuff
micpst Sep 23, 2024
bdcc7b3
fix benchmarks
micpst Sep 23, 2024
71f53be
merge main
micpst Sep 25, 2024
c82e579
rm chroma file
micpst Sep 25, 2024
f5a40cb
add contexts to benchmarks + fix types
micpst Sep 30, 2024
fab9d3f
small refactor
micpst Oct 1, 2024
3 changes: 2 additions & 1 deletion benchmarks/sql/bench.py
@@ -28,6 +28,7 @@
)
from bench.pipelines import CollectionEvaluationPipeline, IQLViewEvaluationPipeline, SQLViewEvaluationPipeline
from bench.utils import save
+from hydra.core.hydra_config import HydraConfig
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig

@@ -120,7 +121,7 @@ async def bench(config: DictConfig) -> None:

log.info("Evaluation finished. Saving results...")

-    output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
+    output_dir = Path(HydraConfig.get().runtime.output_dir)
metrics_file = output_dir / "metrics.json"
results_file = output_dir / "results.json"

10 changes: 10 additions & 0 deletions benchmarks/sql/bench/contexts/__init__.py
@@ -0,0 +1,10 @@
from typing import Dict, Type

from dbally.context import Context

from .superhero import SuperheroContext, UserContext

CONTEXTS_REGISTRY: Dict[str, Type[Context]] = {
UserContext.__name__: UserContext,
SuperheroContext.__name__: SuperheroContext,
}
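The registry above lets benchmark configs reference contexts by class name. A minimal, self-contained sketch of the same lookup pattern, using a stand-in `Context` base class rather than the actual `dbally.context.Context`:

```python
from dataclasses import dataclass
from typing import Dict, Type


@dataclass
class Context:
    """Stand-in for dbally.context.Context."""


@dataclass
class UserContext(Context):
    name: str = "John Doe"


# Keyed by __name__ so config strings and class names cannot drift apart.
CONTEXTS_REGISTRY: Dict[str, Type[Context]] = {
    UserContext.__name__: UserContext,
}

# A config lists context names; the registry turns them into instances.
contexts = [CONTEXTS_REGISTRY[name]() for name in ["UserContext"]]
print(contexts[0].name)  # -> John Doe
```

Keying on `__name__` means renaming a context class automatically changes the string a config must use, so stale names fail fast with a `KeyError` instead of silently instantiating the wrong context.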
21 changes: 21 additions & 0 deletions benchmarks/sql/bench/contexts/superhero.py
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from dbally.context import Context


@dataclass
class UserContext(Context):
"""
Current user data.
"""

name: str = "John Doe"


@dataclass
class SuperheroContext(Context):
"""
Current user favourite superhero data.
"""

name: str = "Batman"
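These dataclasses get matched to view-filter parameters through their type hints, which is the core of the PR's contextualisation mechanism. A rough, self-contained sketch of how such type-hint-based injection can work; the names below are simplified stand-ins, not db-ally's actual implementation:

```python
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class Context:
    """Stand-in base class; the real code uses dbally.context.Context."""


@dataclass
class UserContext(Context):
    name: str = "John Doe"


def filter_by_user(user: UserContext) -> str:
    """A hypothetical view filter whose parameter is annotated with a context type."""
    return f"name = '{user.name}'"


def inject_contexts(func: Callable, contexts: List[Context]) -> Dict[str, Any]:
    """Bind each parameter annotated with a Context subclass to a matching instance."""
    kwargs = {}
    for name, param in inspect.signature(func).parameters.items():
        ann = param.annotation
        if isinstance(ann, type) and issubclass(ann, Context):
            for ctx in contexts:
                if isinstance(ctx, ann):
                    kwargs[name] = ctx
    return kwargs


kwargs = inject_contexts(filter_by_user, [UserContext(name="Alice")])
print(filter_by_user(**kwargs))  # -> name = 'Alice'
```

The point of the pattern is that callers pass a flat list of contexts and each filter declares, via its annotations alone, which one it needs.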
7 changes: 4 additions & 3 deletions benchmarks/sql/bench/metrics/base.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type

+from omegaconf import DictConfig
from typing_extensions import Self

from ..pipelines import EvaluationResult
@@ -11,7 +12,7 @@ class Metric(ABC):
Base class for metrics.
"""

-    def __init__(self, config: Optional[Dict] = None) -> None:
+    def __init__(self, config: Optional[DictConfig] = None) -> None:
"""
Initializes the metric.

@@ -38,7 +39,7 @@ class MetricSet:
Represents a set of metrics.
"""

-    def __init__(self, *metrics: List[Type[Metric]]) -> None:
+    def __init__(self, *metrics: Type[Metric]) -> None:
"""
Initializes the metric set.

@@ -48,7 +49,7 @@ def __init__(self, *metrics: List[Type[Metric]]) -> None:
self._metrics = metrics
self.metrics: List[Metric] = []

-    def __call__(self, config: Dict) -> Self:
+    def __call__(self, config: DictConfig) -> Self:
"""
Initializes the metrics.

50 changes: 47 additions & 3 deletions benchmarks/sql/bench/pipelines/base.py
@@ -1,14 +1,24 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Any, Dict, Optional, Union
+from functools import cached_property
+from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union

from omegaconf import DictConfig
from sqlalchemy import Engine, create_engine

from dbally.context import Context
from dbally.iql._exceptions import IQLError
from dbally.iql._query import IQLQuery
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.base import LLM
from dbally.llms.clients.exceptions import LLMError
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM
from dbally.views.base import BaseView

from ..contexts import CONTEXTS_REGISTRY

ViewT = TypeVar("ViewT", bound=BaseView)


@dataclass
@@ -23,7 +33,7 @@ class IQL:
generated: bool = True

@classmethod
-    def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL":
+    def from_query(cls, query: Optional[Union[IQLQuery, BaseException]]) -> "IQL":
"""
Creates an IQL object from the query.

@@ -81,7 +91,12 @@ class EvaluationPipeline(ABC):
Collection evaluation pipeline.
"""

-    def get_llm(self, config: Dict) -> LLM:
+    def __init__(self, config: DictConfig) -> None:
+        super().__init__()
+        self.config = config
+
+    @staticmethod
+    def _get_llm(config: DictConfig) -> LLM:
"""
Returns the LLM based on the configuration.

@@ -95,6 +110,13 @@ def get_llm(self, config: Dict) -> LLM:
return LocalLLM(config.model_name.split("/", 1)[1])
return LiteLLM(config.model_name)

@cached_property
def dbs(self) -> Dict[str, Engine]:
"""
Returns the database engines based on the configuration.
"""
return {db: create_engine(f"sqlite:///data/{db}.db") for db in self.config.setup.views}

@abstractmethod
async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
"""
@@ -106,3 +128,25 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
Returns:
The evaluation result.
"""


class ViewEvaluationMixin(Generic[ViewT]):
"""
View evaluation mixin.
"""

@cached_property
def contexts(self) -> List[Context]:
"""
Returns the contexts based on the configuration.
"""
return [
CONTEXTS_REGISTRY[context]() for contexts in self.config.setup.contexts.values() for context in contexts
]

@cached_property
@abstractmethod
def views(self) -> Dict[str, Type[ViewT]]:
"""
Returns the view classes mapping based on the configuration
"""
74 changes: 43 additions & 31 deletions benchmarks/sql/bench/pipelines/collection.py
@@ -1,57 +1,68 @@
-from typing import Any, Dict
-from sqlalchemy import create_engine
+from functools import cached_property
+from typing import Any, Dict, Type, Union

import dbally
from dbally.collection.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.llms.base import LLM
from dbally.view_selection.llm_view_selector import LLMViewSelector
from dbally.views.exceptions import ViewExecutionError
from dbally.views.freeform.text2sql.view import BaseText2SQLView
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

from ..views import VIEWS_REGISTRY
-from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult
+from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult, ViewEvaluationMixin


-class CollectionEvaluationPipeline(EvaluationPipeline):
+class CollectionEvaluationPipeline(
+    EvaluationPipeline, ViewEvaluationMixin[Union[SqlAlchemyBaseView, BaseText2SQLView]]
+):
"""
Collection evaluation pipeline.
"""

-    def __init__(self, config: Dict) -> None:
-        """
-        Constructs the pipeline for evaluating collection predictions.
-
-        Args:
-            config: The configuration for the pipeline.
-        """
-        self.collection = self.get_collection(config.setup)
+    @cached_property
+    def selector(self) -> LLM:
+        """
+        Returns the selector LLM.
+        """
+        return self._get_llm(self.config.setup.selector_llm)

-    def get_collection(self, config: Dict) -> Collection:
-        """
-        Sets up the collection based on the configuration.
-
-        Args:
-            config: The collection configuration.
-
-        Returns:
-            The collection.
-        """
-        generator_llm = self.get_llm(config.generator_llm)
-        selector_llm = self.get_llm(config.selector_llm)
-        view_selector = LLMViewSelector(selector_llm)
+    @cached_property
+    def generator(self) -> LLM:
+        """
+        Returns the generator LLM.
+        """
+        return self._get_llm(self.config.setup.generator_llm)
+
+    @cached_property
+    def views(self) -> Dict[str, Type[Union[SqlAlchemyBaseView, BaseText2SQLView]]]:
+        """
+        Returns the view classes mapping based on the configuration.
+        """
+        return {
+            db: cls
+            for db, views in self.config.setup.views.items()
+            for view in views
+            if issubclass(cls := VIEWS_REGISTRY[view], (SqlAlchemyBaseView, BaseText2SQLView))
+        }
+
+    @cached_property
+    def collection(self) -> Collection:
+        """
+        Returns the collection used for evaluation.
+        """
+        view_selector = LLMViewSelector(self.selector)

         collection = dbally.create_collection(
-            name=config.name,
-            llm=generator_llm,
+            name=self.config.setup.name,
+            llm=self.generator,
             view_selector=view_selector,
         )
         collection.n_retries = 0

-        for db_name, view_names in config.views.items():
-            db = create_engine(f"sqlite:///data/{db_name}.db")
-            for view_name in view_names:
-                view_cls = VIEWS_REGISTRY[view_name]
-                collection.add(view_cls, lambda: view_cls(db))  # pylint: disable=cell-var-from-loop
+        for db, view in self.views.items():
+            collection.add(view, lambda: view(self.dbs[db]))  # pylint: disable=cell-var-from-loop

return collection
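The `pylint: disable=cell-var-from-loop` comment in the loop above flags Python's late-binding closures: a lambda created in a loop reads the loop variable when it is called, not when it is defined. A standalone illustration of the pitfall and the usual default-argument fix:

```python
# Late binding: every lambda reads v after the loop has finished.
late = [lambda: v for v in ("a", "b", "c")]
print([f() for f in late])  # -> ['c', 'c', 'c']

# Fix: bind the current value as a default argument at definition time.
bound = [lambda v=v: v for v in ("a", "b", "c")]
print([f() for f in bound])  # -> ['a', 'b', 'c']
```

The pipeline code suppresses the warning rather than rebinding, which is only safe if each lambda is consumed before the loop variables change; the default-argument idiom avoids having to reason about that.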

@@ -68,6 +79,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
try:
result = await self.collection.ask(
question=data["question"],
+                contexts=self.contexts,
dry_run=True,
return_natural_response=False,
)
@@ -85,10 +97,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
prediction = ExecutionResult(
view_name=result.view_name,
iql=IQLResult(
-                filters=IQL(source=result.context["iql"]["filters"]),
-                aggregation=IQL(source=result.context["iql"]["aggregation"]),
+                filters=IQL(source=result.metadata["iql"]["filters"]),
+                aggregation=IQL(source=result.metadata["iql"]["aggregation"]),
             ),
-            sql=result.context["sql"],
+            sql=result.metadata["sql"],
)

reference = ExecutionResult(