Skip to content

Commit

Permalink
adding the cve source to metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Mar 12, 2024
1 parent 9d84cdf commit c322705
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 13 deletions.
22 changes: 15 additions & 7 deletions dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def generate_response_with_timeout(self, sql_generator, user_prompt, db_connecti
user_prompt=user_prompt, database_connection=db_connection
)

def update_sql_generation(
self, initial_sql_generation: SQLGeneration, sql_generation: SQLGeneration
) -> SQLGeneration:
initial_sql_generation.sql = sql_generation.sql
initial_sql_generation.tokens_used = sql_generation.tokens_used
initial_sql_generation.completed_at = datetime.now()
initial_sql_generation.status = sql_generation.status
initial_sql_generation.error = sql_generation.error
initial_sql_generation.metadata.update(sql_generation.metadata)
return self.sql_generation_repository.update(initial_sql_generation)

def create(
self, prompt_id: str, sql_generation_request: SQLGenerationRequest
) -> SQLGeneration:
Expand All @@ -56,7 +67,9 @@ def create(
llm_config=sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
metadata=sql_generation_request.metadata,
metadata=sql_generation_request.metadata
if sql_generation_request.metadata
else {},
)
self.sql_generation_repository.insert(initial_sql_generation)
prompt_repository = PromptRepository(self.storage)
Expand Down Expand Up @@ -152,12 +165,7 @@ def create(
)
initial_sql_generation.evaluate = sql_generation_request.evaluate
initial_sql_generation.confidence_score = confidence_score
initial_sql_generation.sql = sql_generation.sql
initial_sql_generation.tokens_used = sql_generation.tokens_used
initial_sql_generation.completed_at = datetime.now()
initial_sql_generation.status = sql_generation.status
initial_sql_generation.error = sql_generation.error
return self.sql_generation_repository.update(initial_sql_generation)
return self.update_sql_generation(initial_sql_generation, sql_generation)

def get(self, query) -> list[SQLGeneration]:
return self.sql_generation_repository.find_by(query)
Expand Down
1 change: 1 addition & 0 deletions dataherald/sql_database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def parser_to_filter_commands(cls, command: str) -> str:
"TRUNCATE",
"MERGE",
"EXECUTE",
"CREATE",
]
parsed_command = sqlparse.parse(command)

Expand Down
12 changes: 9 additions & 3 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None:
vulnerabilities = VulnerabilityRepository(storage)
cves = self.extract_cve_ids(user_prompt.text)
extra_info = ""
source = ""
if len(cves) > 0:
for cve in cves:
vulnerability = vulnerabilities.find_by({"cve_id": cve})[0]
Expand All @@ -533,7 +534,9 @@ def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None:
extra_info += (
f"{cve} was published on {vulnerability.published_date}"
)
return extra_info
if vulnerability.source:
source = vulnerability.source
return extra_info, source

@override
def generate_response( # noqa: C901
Expand Down Expand Up @@ -612,8 +615,9 @@ def generate_response( # noqa: C901
)
agent_executor.return_intermediate_steps = True
agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE
if self.augment_prompt(user_prompt, storage):
user_prompt.text += " \n" + self.augment_prompt(user_prompt, storage)
cve_augmented, cve_source = self.augment_prompt(user_prompt, storage)
if cve_augmented:
user_prompt.text += " \n" + cve_augmented
with get_openai_callback() as cb:
try:
logger.info(f"Prompt: {user_prompt.text}")
Expand Down Expand Up @@ -647,6 +651,8 @@ def generate_response( # noqa: C901
response.sql = replace_unprocessable_characters(sql_query)
response.tokens_used = cb.total_tokens
response.completed_at = datetime.datetime.now()
if cve_source:
response.metadata.update({"cve_source": cve_source})
return self.create_sql_query_status(
self.database,
response.sql,
Expand Down
13 changes: 10 additions & 3 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None:
vulnerabilities = VulnerabilityRepository(storage)
cves = self.extract_cve_ids(user_prompt.text)
extra_info = ""
source = ""
if len(cves) > 0:
for cve in cves:
vulnerability = vulnerabilities.find_by({"cve_id": cve})[0]
Expand All @@ -682,7 +683,9 @@ def augment_prompt(self, user_prompt: Prompt, storage: DB) -> None:
extra_info += (
f"{cve} was published on {vulnerability.published_date}"
)
return extra_info
if vulnerability.source:
source = vulnerability.source
return extra_info, source

@override
def generate_response( # noqa: C901
Expand All @@ -697,6 +700,7 @@ def generate_response( # noqa: C901
prompt_id=user_prompt.id,
llm_config=self.llm_config,
created_at=datetime.datetime.now(),
metadata={},
)
self.llm = self.model.get_model(
database_connection=database_connection,
Expand Down Expand Up @@ -744,8 +748,9 @@ def generate_response( # noqa: C901
)
agent_executor.return_intermediate_steps = True
agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE
if self.augment_prompt(user_prompt, storage):
user_prompt.text += " \n" + self.augment_prompt(user_prompt, storage)
cve_augmented, cve_source = self.augment_prompt(user_prompt, storage)
if cve_augmented:
user_prompt.text += " \n" + cve_augmented
with get_openai_callback() as cb:
try:
logger.info(f"Prompt: {user_prompt.text}")
Expand Down Expand Up @@ -778,6 +783,8 @@ def generate_response( # noqa: C901
response.sql = replace_unprocessable_characters(sql_query)
response.tokens_used = cb.total_tokens
response.completed_at = datetime.datetime.now()
if cve_source:
response.metadata.update({"cve_source": cve_source})
return self.create_sql_query_status(
self.database,
response.sql,
Expand Down
1 change: 1 addition & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,4 @@ class Vulnerability(BaseModel):
date_updated: str
description: str
affected_versions: str
source: str

0 comments on commit c322705

Please sign in to comment.