Skip to content

Commit

Permalink
DATA-2053/removing cache, support os types
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza authored and jcjc712 committed Mar 22, 2024
1 parent aa161a3 commit 4dd3a02
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 7 deletions.
27 changes: 27 additions & 0 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from typing import Any, Dict, List, Tuple

import sqlparse
from sql_metadata import Parser
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult
from langchain.schema.messages import BaseMessage
from langchain_community.callbacks import get_openai_callback

from dataherald.config import Component, System
from dataherald.db_scanner.models.types import TableDescription
from dataherald.model.chat_model import ChatModel
from dataherald.repositories.sql_generations import (
SQLGenerationRepository,
Expand Down Expand Up @@ -146,6 +148,31 @@ def truncate_observations(self, obervarion: str, max_length: int = 2000) -> str:
else obervarion
)

def filter_tables_based_on_os(self, db_scan: List[TableDescription], question: str):
target_os_types = question.split("[OS]")[1].split("[/OS]")[0].strip().split(",")
filtered_db_scan = []
for table in db_scan:
if "os_versions" in table.metadata.get("akamai", {}):
os_versions = table.metadata["akamai"]["os_versions"]
if any(os_version in os_versions for os_version in target_os_types):
filtered_db_scan.append(table)
else:
filtered_db_scan.append(table)
return filtered_db_scan

def filter_fewshot_sample_based_on_os(
self, db_scan: List[TableDescription], fewshot_samples: List[dict]
):
filtered_fewshot_samples = []
for sample in fewshot_samples:
target_table_names = Parser(sample["sql"]).tables
if all(
target_table_name in [table.table_name for table in db_scan]
for target_table_name in target_table_names
):
filtered_fewshot_samples.append(sample)
return filtered_fewshot_samples

@abstractmethod
def generate_response(
self,
Expand Down
30 changes: 27 additions & 3 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def create_sql_agent(
**(agent_executor_kwargs or {}),
)

def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None:
def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None: # noqa: C901
vulnerabilities = VulnerabilityRepository(storage)
cves = self.extract_cve_ids(user_prompt.text)
extra_info = ""
Expand All @@ -548,12 +548,14 @@ def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None:
extra_info += (
f"{cve} was published on {vulnerability.published_date}"
)
if vulnerability.hotfix_ids:
extra_info += f"{cve} is fixed in the following patches which can be found in patches.hotfix_id: {', '.join(vulnerability.hotfix_ids)}" # noqa: E501
if vulnerability.source:
source = vulnerability.source
return extra_info, source

@override
def generate_response( # noqa: C901
def generate_response( # noqa: C901, PLR0915
self,
user_prompt: Prompt,
database_connection: DatabaseConnection,
Expand Down Expand Up @@ -596,8 +598,15 @@ def generate_response( # noqa: C901
if not db_scan:
raise ValueError("No scanned tables found for database")
few_shot_examples, instructions = context_store.retrieve_context_for_question(
user_prompt, number_of_samples=1
user_prompt, number_of_samples=5
)
if "[OS]" in user_prompt.text.upper() and "[/OS]" in user_prompt.text.upper():
db_scan = self.filter_tables_based_on_os(db_scan, user_prompt.text)
user_prompt.text = user_prompt.text.split("[/OS]")[1]
if few_shot_examples:
few_shot_examples = self.filter_fewshot_sample_based_on_os(
db_scan, few_shot_examples
)
finetunings_repository = FinetuningsRepository(storage)
finetuning = finetunings_repository.find_by_id(self.finetuning_id)
openai_fine_tuning = OpenAIFineTuning(storage, finetuning)
Expand All @@ -608,6 +617,21 @@ def generate_response( # noqa: C901
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
)
self.database = SQLDatabase.get_sql_engine(database_connection)
"""
if few_shot_examples is not None:
for example in few_shot_examples:
question = str(example["prompt_text"]).split("Question: ")[0].strip()
query = example["sql"].split("SQL: ")[0].strip()
if question == user_prompt.text.strip():
return SQLGeneration(
prompt_id=user_prompt.id,
tokens_used=0,
completed_at=datetime.datetime.now(),
sql=query,
status="VALID",
metadata={},
)
"""
toolkit = SQLDatabaseToolkit(
db=self.database,
instructions=instructions,
Expand Down
29 changes: 27 additions & 2 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def create_sql_agent(
**(agent_executor_kwargs or {}),
)

def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None:
def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None: # noqa: C901
vulnerabilities = VulnerabilityRepository(storage)
cves = self.extract_cve_ids(user_prompt.text)
extra_info = ""
Expand All @@ -697,12 +697,14 @@ def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None:
extra_info += (
f"{cve} was published on {vulnerability.published_date}"
)
if vulnerability.hotfix_ids:
extra_info += f"{cve} is fixed in the following patches which can be found in patches.hotfix_id: {', '.join(vulnerability.hotfix_ids)}" # noqa: E501
if vulnerability.source:
source = vulnerability.source
return extra_info, source

@override
def generate_response( # noqa: C901
def generate_response( # noqa: C901, PLR0915
self,
user_prompt: Prompt,
database_connection: DatabaseConnection,
Expand Down Expand Up @@ -741,8 +743,31 @@ def generate_response( # noqa: C901
else:
new_fewshot_examples = None
number_of_samples = 0
if "[OS]" in user_prompt.text.upper() and "[/OS]" in user_prompt.text.upper():
db_scan = self.filter_tables_based_on_os(db_scan, user_prompt.text)
user_prompt.text = user_prompt.text.split("[/OS]")[1]
if new_fewshot_examples is not None:
new_fewshot_examples = self.filter_fewshot_sample_based_on_os(
db_scan, new_fewshot_examples
)
number_of_samples = len(new_fewshot_examples)
logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
self.database = SQLDatabase.get_sql_engine(database_connection)
"""
if new_fewshot_examples is not None:
for example in new_fewshot_examples:
question = str(example["prompt_text"]).split("Question: ")[0].strip()
query = example["sql"].split("SQL: ")[0].strip()
if question == user_prompt.text.strip():
return SQLGeneration(
prompt_id=user_prompt.id,
tokens_used=0,
completed_at=datetime.datetime.now(),
sql=query,
status="VALID",
metadata={},
)
"""
toolkit = SQLDatabaseToolkit(
db=self.database,
context=context,
Expand Down
1 change: 1 addition & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,5 @@ class Vulnerability(BaseModel):
date_updated: str
description: str
affected_versions: str
hotfix_ids: list[str] | None = None
source: str
6 changes: 4 additions & 2 deletions dataherald/utils/agent_prompts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
AGENT_PREFIX = """You are an agent designed to interact with a SQL database to find a correct SQL query for the given question.
AGENT_PREFIX = """You are an agent designed to interact with a OSQuery database to find a correct SQL query for the given question.
Database stores information about the OS and its configurations.
Given an input question, generate a syntactically correct {dialect} query, execute the query to make sure it is correct, and return the SQL query between ```sql and ``` tags.
You have access to tools for interacting with the database. You can use tools using Action: <tool_name> and Action Input: <tool_input> format.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
Expand Down Expand Up @@ -114,7 +115,8 @@
Thought: I should use the GenerateSql tool to generate a SQL query for the given question.
{agent_scratchpad}"""

FINETUNING_AGENT_PREFIX = """You are an agent designed to interact with a SQL database to find a correct SQL query for the given question.
FINETUNING_AGENT_PREFIX = """You are an agent designed to interact with a OSQuery database to find a correct SQL query for the given question.
Database stores information about the OS and its configurations.
Given an input question, return a syntactically correct {dialect} query, always execute the query to make sure it is correct, and return the SQL query in ```sql and ``` format.
Using `current_date()` or `current_datetime()` in SQL queries is banned, use SystemTime tool to get the exact time of the query execution.
Expand Down

0 comments on commit 4dd3a02

Please sign in to comment.