-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: freeform text2sql with static configuration (#36)
--------- Co-authored-by: Ludwik Trammer <[email protected]> Co-authored-by: Michał Pstrąg <[email protected]>
- Loading branch information
1 parent
edd6de9
commit 8f7a166
Showing
19 changed files
with
395 additions
and
604 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import asyncio | ||
from typing import List | ||
|
||
import sqlalchemy | ||
|
||
import dbally | ||
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler | ||
from dbally.llms import LiteLLM | ||
from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig | ||
|
||
|
||
class MyText2SqlView(BaseText2SQLView): | ||
""" | ||
A Text2SQL view for the example. | ||
""" | ||
|
||
def get_tables(self) -> List[TableConfig]: | ||
""" | ||
Get the tables used by the view. | ||
Returns: | ||
A list of tables. | ||
""" | ||
return [ | ||
TableConfig( | ||
name="customers", | ||
columns=[ | ||
ColumnConfig("id", "SERIAL PRIMARY KEY"), | ||
ColumnConfig("name", "VARCHAR(255)"), | ||
ColumnConfig("city", "VARCHAR(255)"), | ||
ColumnConfig("country", "VARCHAR(255)"), | ||
ColumnConfig("age", "INTEGER"), | ||
], | ||
), | ||
TableConfig( | ||
name="products", | ||
columns=[ | ||
ColumnConfig("id", "SERIAL PRIMARY KEY"), | ||
ColumnConfig("name", "VARCHAR(255)"), | ||
ColumnConfig("category", "VARCHAR(255)"), | ||
ColumnConfig("price", "REAL"), | ||
], | ||
), | ||
TableConfig( | ||
name="purchases", | ||
columns=[ | ||
ColumnConfig("customer_id", "INTEGER"), | ||
ColumnConfig("product_id", "INTEGER"), | ||
ColumnConfig("quantity", "INTEGER"), | ||
ColumnConfig("date", "TEXT"), | ||
], | ||
), | ||
] | ||
|
||
|
||
async def main(): | ||
"""Main function to run the example.""" | ||
engine = sqlalchemy.create_engine("sqlite:///:memory:") | ||
|
||
# Create tables from config | ||
with engine.connect() as connection: | ||
for table_config in MyText2SqlView(engine).get_tables(): | ||
connection.execute(sqlalchemy.text(table_config.ddl)) | ||
|
||
llm = LiteLLM() | ||
collection = dbally.create_collection("text2sql", llm=llm, event_handlers=[CLIEventHandler()]) | ||
collection.add(MyText2SqlView, lambda: MyText2SqlView(engine)) | ||
|
||
await collection.ask("What are the names of products bought by customers from London?") | ||
await collection.ask("Which customers bought products from the category 'electronics'?") | ||
await collection.ask("What is the total quantity of products bought by customers from the UK?") | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.