From 253ecbe2d425f1db2c0de44802b627585139f201 Mon Sep 17 00:00:00 2001 From: ishwor Giri Date: Wed, 7 Aug 2024 00:51:26 +0200 Subject: [PATCH] fix input --- src/routes/v1/upload.py | 14 ++++++++------ src/worker/types.py | 8 ++++++++ src/worker/worker.py | 19 +++++++++++-------- 3 files changed, 27 insertions(+), 14 deletions(-) create mode 100644 src/worker/types.py diff --git a/src/routes/v1/upload.py b/src/routes/v1/upload.py index 02aedd9..4777913 100644 --- a/src/routes/v1/upload.py +++ b/src/routes/v1/upload.py @@ -13,6 +13,7 @@ from db.my_db import get_db from repository.finding import get_finding_repository from repository.task import TaskRepository, get_task_repository +from worker.types import GenerateReportInput from worker.worker import worker router = APIRouter(prefix="/upload") @@ -57,15 +58,16 @@ async def upload( find.recommendation_task_id = recommendation_task.id findings.append(find) finding_repository.create_findings(findings) + worker_input = GenerateReportInput( + recommendation_task_id=recommendation_task.id, + generate_long_solution=data.preferences.long_description or True, + generate_search_terms=data.preferences.search_terms or True, + generate_aggregate_solutions=data.preferences.aggregated_solutions or True, + ) celery_result = worker.send_task( "worker.generate_report", - args=[ - recommendation_task.id, - data.preferences.long_description, - data.preferences.search_terms, - data.preferences.aggregated_solutions, - ], + args=[worker_input.model_dump()], ) # update the task with the celery task id diff --git a/src/worker/types.py b/src/worker/types.py new file mode 100644 index 0000000..13504cb --- /dev/null +++ b/src/worker/types.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class GenerateReportInput(BaseModel): + recommendation_task_id: int + generate_long_solution: bool = True + generate_search_terms: bool = True + generate_aggregate_solutions: bool = True diff --git a/src/worker/worker.py b/src/worker/worker.py index 9994406..e309338 100644 --- a/src/worker/worker.py +++ b/src/worker/worker.py @@ -19,6 +19,7 @@ from ai.LLM.Strategies.OLLAMAService import OLLAMAService from data.VulnerabilityReport import create_from_flama_json from ai.Grouping.FindingGrouper import FindingGrouper +from worker.types import GenerateReportInput logger = logging.getLogger(__name__) @@ -41,16 +42,18 @@ def error(self, exc, task_id, args, kwargs, einfo): @worker.task(name="worker.generate_report", on_failure=error) def generate_report( - recommendation_task_id: int, - generate_long_solution: bool = True, - generate_search_terms: bool = True, - generate_aggregate_solutions: bool = True, + input: dict, ): + try: + input = GenerateReportInput.model_validate(input) + except Exception as e: + logger.error(f"Error validating input: {e}") + return - # importing here so importing worker does not import all the dependencies - # from ai.LLM.LLMServiceStrategy import LLMServiceStrategy - # from ai.LLM.Strategies.OLLAMAService import OLLAMAService - # from data.VulnerabilityReport import create_from_flama_json + recommendation_task_id = input.recommendation_task_id + generate_long_solution = input.generate_long_solution + generate_search_terms = input.generate_search_terms + generate_aggregate_solutions = input.generate_aggregate_solutions ollama_strategy = OLLAMAService() llm_service = LLMServiceStrategy(ollama_strategy)