Skip to content

Commit

Permalink
Merge 'origin/main' into DH5298/osquery
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Mar 8, 2024
2 parents f13f0a6 + 6f39892 commit 9d84cdf
Show file tree
Hide file tree
Showing 18 changed files with 380 additions and 66 deletions.
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ GOLDEN_SQL_COLLECTION = 'my-golden-records'
#Pinecone info. These fields are required if the vector store used is Pinecone
PINECONE_API_KEY =
PINECONE_ENVIRONMENT =
#AstraDB info. These fields are required if the vector store used is AstraDB
ASTRA_DB_API_ENDPOINT =
ASTRA_DB_APPLICATION_TOKEN =


# Module implementations to be used names for each required component. You can use the default ones or create your own
API_SERVER = "dataherald.api.fastapi.FastAPI"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ curl -X 'POST' \
}'
```

### Run scripts
### How to migrate data between versions
Our engine is under ongoing development and in order to support the latest features, we provide scripts to migrate the data from the previous version to the latest version. You can find all of the scripts in the `dataherald.scripts` module. To run the migration script, execute the following command:

```
Expand Down
1 change: 0 additions & 1 deletion dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,6 @@ def add_golden_sqls(
{"items": [row.dict() for row in golden_sqls]},
"golden_sql_not_created",
)

return [GoldenSQLResponse(**golden_sql.dict()) for golden_sql in golden_sqls]

@override
Expand Down
1 change: 0 additions & 1 deletion dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
metadata=record.metadata,
)
stored_golden_sqls.append(golden_sqls_repository.insert(golden_sql))

self.vector_store.add_records(stored_golden_sqls, self.golden_sql_collection)
return stored_golden_sqls

Expand Down
13 changes: 7 additions & 6 deletions dataherald/sql_database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,26 @@

from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.utils.encrypt import FernetEncrypt
from dataherald.utils.error_codes import CustomError
from dataherald.utils.s3 import S3

logger = logging.getLogger(__name__)


# Define a custom exception class
class SQLInjectionError(Exception):
class SQLInjectionError(CustomError):
pass


class InvalidDBConnectionError(Exception):
class InvalidDBConnectionError(CustomError):
pass


class EmptyDBError(Exception):
class EmptyDBError(CustomError):
pass


class SSHInvalidDatabaseConnectionError(Exception):
class SSHInvalidDatabaseConnectionError(CustomError):
pass


Expand Down Expand Up @@ -89,7 +90,7 @@ def get_sql_engine(
return engine
except Exception as e:
raise SSHInvalidDatabaseConnectionError(
f"Invalid SSH connection, {e}"
"Invalid SSH connection", description=str(e)
) from e
try:
db_uri = unquote(fernet_encrypt.decrypt(database_info.connection_uri))
Expand All @@ -107,7 +108,7 @@ def get_sql_engine(
DBConnections.add(database_info.id, engine)
except Exception as e:
raise InvalidDBConnectionError( # noqa: B904
f"Unable to connect to db: {database_info.alias}, {e}"
f"Unable to connect to db: {database_info.alias}", description=str(e)
)
return engine

Expand Down
19 changes: 15 additions & 4 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
from dataherald.types import FineTuningStatus, Prompt, SQLGeneration
from dataherald.utils.agent_prompts import (
ERROR_PARSING_MESSAGE,
FINETUNING_AGENT_PREFIX,
FINETUNING_AGENT_PREFIX_FINETUNING_ONLY,
FINETUNING_AGENT_SUFFIX,
Expand Down Expand Up @@ -201,11 +202,11 @@ def _run(
for table in self.db_scan:
col_rep = ""
for column in table.columns:
if column.description is not None:
if column.description:
col_rep += f"{column.name}: {column.description}, "
else:
col_rep += f"{column.name}, "
if table.description is not None:
if table.description:
table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}"
else:
table_rep = f"Table {table.table_name} contain columns: [{col_rep}]"
Expand Down Expand Up @@ -368,8 +369,18 @@ def _run(
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += table.table_schema + "\n"
descriptions = []
if table.description is not None:
tables_schema += "Table description: " + table.description + "\n"
descriptions.append(
f"Table `{table.table_name}`: {table.description}\n"
)
for column in table.columns:
if column.description is not None:
descriptions.append(
f"Column `{column.name}`: {column.description}\n"
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
Expand Down Expand Up @@ -600,7 +611,7 @@ def generate_response( # noqa: C901
max_execution_time=int(os.environ.get("DH_ENGINE_TIMEOUT", 150)),
)
agent_executor.return_intermediate_steps = True
agent_executor.handle_parsing_errors = True
agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE
if self.augment_prompt(user_prompt, storage):
user_prompt.text += " \n" + self.augment_prompt(user_prompt, storage)
with get_openai_callback() as cb:
Expand Down
3 changes: 2 additions & 1 deletion dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from dataherald.types import Prompt, SQLGeneration
from dataherald.utils.agent_prompts import (
AGENT_PREFIX,
ERROR_PARSING_MESSAGE,
FORMAT_INSTRUCTIONS,
PLAN_BASE,
PLAN_WITH_FEWSHOT_EXAMPLES,
Expand Down Expand Up @@ -742,7 +743,7 @@ def generate_response( # noqa: C901
max_execution_time=int(os.environ.get("DH_ENGINE_TIMEOUT", 150)),
)
agent_executor.return_intermediate_steps = True
agent_executor.handle_parsing_errors = True
agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE
if self.augment_prompt(user_prompt, storage):
user_prompt.text += " \n" + self.augment_prompt(user_prompt, storage)
with get_openai_callback() as cb:
Expand Down
Loading

0 comments on commit 9d84cdf

Please sign in to comment.