Skip to content

Commit

Permalink
feat: freeform text2sql with static configuration (#36)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Ludwik Trammer <[email protected]>
Co-authored-by: Michał Pstrąg <[email protected]>
  • Loading branch information
3 people authored May 28, 2024
1 parent edd6de9 commit 8f7a166
Show file tree
Hide file tree
Showing 19 changed files with 395 additions and 604 deletions.
84 changes: 9 additions & 75 deletions docs/how-to/update_similarity_indexes.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,16 @@ The Similarity Index is a feature provided by db-ally that takes user input and

While Similarity Indexes can be used directly, they are usually used with [Views](../concepts/views.md), annotating arguments to filter methods. This technique lets db-ally automatically match user-provided arguments to the most similar value in the data source. You can see an example of using similarity indexes with views on the [Quickstart Part 2: Semantic Similarity](../quickstart/quickstart2.md) page.

Similarity Indexes are designed to index all possible values (e.g., on disk or in a different data store). Consequently, when the data source undergoes changes, the Similarity Index must update to reflect these alterations. This guide will explain different ways to update Similarity Indexes.
Similarity Indexes are designed to index all possible values (e.g., on disk or in a different data store). Consequently, when the data source undergoes changes, the Similarity Index must update to reflect these alterations. This guide will explain how to update Similarity Indexes in your code.

You can update the Similarity Index through Python code or via the db-ally CLI. The following sections explain how to update these indexes using both methods:
* [Update a Single Similarity Index](#update-a-single-similarity-index)
* [Update Similarity Indexes from all Views in a Collection](#update-similarity-indexes-from-all-views-in-a-collection)

* [Update Similarity Indexes via the CLI](#update-similarity-indexes-via-the-cli)
* [Update Similarity Indexes via Python Code](#update-similarity-indexes-via-python-code)
* [Update on a Single Similarity Index](#update-on-a-single-similarity-index)
* [Update Similarity Indexes from all Views in a Collection](#update-similarity-indexes-from-all-views-in-a-collection)
* [Detect Similarity Indexes in Views](#detect-similarity-indexes-in-views)

## Update Similarity Indexes via the CLI

To update Similarity Indexes via the CLI, you can use the `dbally update-index` command. This command requires a path to what you wish to update. The path should follow this format: "path.to.module:ViewName.method_name.argument_name" where each part after the colon is optional. The more specific your target is, the fewer Similarity Indexes will be updated.

For example, to update all Similarity Indexes in a module `my_module.views`, use this command:

```bash
dbally update-index my_module.views
```

To update all Similarity Indexes in a specific View, add the name of the View following the module path:

```bash
dbally update-index my_module.views:MyView
```

To update all Similarity Indexes within a specific method of a View, add the method's name after the View name:

```bash
dbally update-index my_module.views:MyView.method_name
```

Lastly, to update all Similarity Indexes in a particular argument of a method, add the argument name after the method name:

```bash
dbally update-index my_module.views:MyView.method_name.argument_name
```

## Update Similarity Indexes via Python Code
### Update on a Single Similarity Index
## Update a Single Similarity Index
To manually update a Similarity Index, call the `update` method on the Similarity Index object. The `update` method will re-fetch all possible values from the data source and re-index them. Below is an example of how to manually update a Similarity Index:

```python
from db_ally import SimilarityIndex
from dbally import SimilarityIndex

# Create a similarity index
similarity_index = SimilarityIndex(fetcher=fetcher, store=store)
Expand All @@ -56,14 +22,14 @@ similarity_index = SimilarityIndex(fetcher=fetcher, store=store)
await similarity_index.update()
```

### Update Similarity Indexes from all Views in a Collection
## Update Similarity Indexes from all Views in a Collection
If you have a [collection](../concepts/collections.md) and want to update Similarity Indexes in all views, you can use the `update_similarity_indexes` method. This method will update all Similarity Indexes in all views registered with the collection:

```python
from db_ally import create_collection
from db_ally.llms.litellm import LiteLLM
from dbally import create_collection
from dbally.llms.litellm import LiteLLM

my_collection = create_collection("collection_name", llm=LiteLLM())
my_collection = create_collection("my_collection", llm=LiteLLM())

# ... add views to the collection

Expand All @@ -72,35 +38,3 @@ await my_collection.update_similarity_indexes()

!!! info
Alternatively, for more advanced use cases, you can use Collection's [`get_similarity_indexes`][dbally.Collection.get_similarity_indexes] method to get a list of all Similarity Indexes (allongside the places where they are used) and update them individually.

### Detect Similarity Indexes in Views
If you are using Similarity Indexes to annotate arguments in views, you can use the [`SimilarityIndexDetector`][dbally.similarity.detector.SimilarityIndexDetector] to locate all Similarity Indexes in a view and update them.

For example, to update all Similarity Indexes in a view named `MyView` in a module labeled `my_module.views`, use the following code:

```python
from db_ally import SimilarityIndexDetector

detector = SimilarityIndexDetector.from_path("my_module.views:MyView")
[await index.update() for index in detector.list_indexes()]
```

The `from_path` method constructs a `SimilarityIndexDetector` object from a view path string in the same format as the CLI command. The `list_indexes` method returns a list of Similarity Indexes detected in the view.

For instance, to detect all Similarity Indexes in a module, provide only the path:

```python
detector = SimilarityIndexDetector.from_path("my_module.views")
```

Conversely, to detect all Similarity Indexes in a specific method of a view, provide the method name:

```python
detector = SimilarityIndexDetector.from_path("my_module.views:MyView.method_name")
```

Lastly, to detect all Similarity Indexes in a particular argument of a method, provide the argument name:

```python
detector = SimilarityIndexDetector.from_path("my_module.views:MyView.method_name.argument_name")
```
7 changes: 0 additions & 7 deletions docs/reference/similarity/detector.md

This file was deleted.

75 changes: 75 additions & 0 deletions examples/freeform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import asyncio
from typing import List

import sqlalchemy

import dbally
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llms import LiteLLM
from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig


class MyText2SqlView(BaseText2SQLView):
"""
A Text2SQL view for the example.
"""

def get_tables(self) -> List[TableConfig]:
"""
Get the tables used by the view.
Returns:
A list of tables.
"""
return [
TableConfig(
name="customers",
columns=[
ColumnConfig("id", "SERIAL PRIMARY KEY"),
ColumnConfig("name", "VARCHAR(255)"),
ColumnConfig("city", "VARCHAR(255)"),
ColumnConfig("country", "VARCHAR(255)"),
ColumnConfig("age", "INTEGER"),
],
),
TableConfig(
name="products",
columns=[
ColumnConfig("id", "SERIAL PRIMARY KEY"),
ColumnConfig("name", "VARCHAR(255)"),
ColumnConfig("category", "VARCHAR(255)"),
ColumnConfig("price", "REAL"),
],
),
TableConfig(
name="purchases",
columns=[
ColumnConfig("customer_id", "INTEGER"),
ColumnConfig("product_id", "INTEGER"),
ColumnConfig("quantity", "INTEGER"),
ColumnConfig("date", "TEXT"),
],
),
]


async def main():
"""Main function to run the example."""
engine = sqlalchemy.create_engine("sqlite:///:memory:")

# Create tables from config
with engine.connect() as connection:
for table_config in MyText2SqlView(engine).get_tables():
connection.execute(sqlalchemy.text(table_config.ddl))

llm = LiteLLM()
collection = dbally.create_collection("text2sql", llm=llm, event_handlers=[CLIEventHandler()])
collection.add(MyText2SqlView, lambda: MyText2SqlView(engine))

await collection.ask("What are the names of products bought by customers from London?")
await collection.ask("Which customers bought products from the category 'electronics'?")
await collection.ask("What is the total quantity of products bought by customers from the UK?")


if __name__ == "__main__":
asyncio.run(main())
1 change: 0 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ nav:
- reference/similarity/similarity_fetcher/index.md
- reference/similarity/similarity_fetcher/sqlalchemy.md
- reference/similarity/similarity_fetcher/sqlalchemy_simple.md
- reference/similarity/detector.md
- Embeddings:
- reference/embeddings/index.md
- reference/embeddings/litellm.md
Expand Down
40 changes: 17 additions & 23 deletions src/dbally/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import inspect
import textwrap
import time
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar

from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.event_tracker import EventTracker
Expand All @@ -14,8 +15,7 @@
from dbally.similarity.index import AbstractSimilarityIndex
from dbally.utils.errors import NoViewFoundError
from dbally.view_selection.base import ViewSelector
from dbally.views.base import BaseView
from dbally.views.structured import BaseStructuredView
from dbally.views.base import BaseView, IndexLocation


class IndexUpdateError(Exception):
Expand Down Expand Up @@ -248,26 +248,22 @@ async def ask(

return result

def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[Tuple[str, str, str]]]:
def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]:
"""
List all similarity indexes from all structured views in the collection.
List all similarity indexes from all views in the collection.
Returns:
Dictionary with similarity indexes as keys and values containing lists of places where they are used
(represented by a tupple containing view name, method name and argument name)
Mapping of similarity indexes to their locations, following view format.
For:
- freeform views, the format is (view_name, table_name, column_name)
- structured views, the format is (view_name, filter_name, argument_name)
"""
indexes: Dict[AbstractSimilarityIndex, List[Tuple[str, str, str]]] = {}
indexes = defaultdict(list)
for view_name in self._views:
view = self.get(view_name)

if not isinstance(view, BaseStructuredView):
continue

filters = view.list_filters()
for filter_ in filters:
for param in filter_.parameters:
if param.similarity_index:
indexes.setdefault(param.similarity_index, []).append((view_name, filter_.name, param.name))
view_indexes = view.list_similarity_indexes()
for index, location in view_indexes.items():
indexes[index].extend(location)
return indexes

async def update_similarity_indexes(self) -> None:
Expand All @@ -280,14 +276,12 @@ async def update_similarity_indexes(self) -> None:
the dictionary were updated successfully.
"""
indexes = self.get_similarity_indexes()
update_corutines = [index.update() for index in indexes]
results = await asyncio.gather(*update_corutines, return_exceptions=True)
update_coroutines = [index.update() for index in indexes]
results = await asyncio.gather(*update_coroutines, return_exceptions=True)
failed_indexes = {
index: exception for index, exception in zip(indexes, results) if isinstance(exception, Exception)
}
if failed_indexes:
failed_locations = [loc for index in failed_indexes for loc in indexes[index]]
description = ", ".join(
f"{view_name}.{method_name}.{param_name}" for view_name, method_name, param_name in failed_locations
)
raise IndexUpdateError(f"Failed to update similarity indexes for {description}", failed_indexes)
descriptions = ", ".join(".".join(name for name in location) for location in failed_locations)
raise IndexUpdateError(f"Failed to update similarity indexes for {descriptions}", failed_indexes)
2 changes: 1 addition & 1 deletion src/dbally/prompts/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class PromptTemplate:
Class for prompt templates
Attributes:
response_format: Optional argument used in the OpenAI API - used to force json output
response_format: Optional argument for OpenAI Turbo models - may be used to force json output
llm_response_parser: Function parsing the LLM response into IQL
"""

Expand Down
Loading

0 comments on commit 8f7a166

Please sign in to comment.