Skip to content

Commit

Permalink
fix input
Browse files Browse the repository at this point in the history
  • Loading branch information
ishworgiri1999 committed Aug 6, 2024
1 parent ab5e9d2 commit 253ecbe
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 14 deletions.
14 changes: 8 additions & 6 deletions src/routes/v1/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from db.my_db import get_db
from repository.finding import get_finding_repository
from repository.task import TaskRepository, get_task_repository
from worker.types import GenerateReportInput
from worker.worker import worker

router = APIRouter(prefix="/upload")
Expand Down Expand Up @@ -57,15 +58,16 @@ async def upload(
find.recommendation_task_id = recommendation_task.id
findings.append(find)
finding_repository.create_findings(findings)
worker_input = GenerateReportInput(
recommendation_task_id=recommendation_task.id,
generate_long_solution=data.preferences.long_description or True,
generate_search_terms=data.preferences.search_terms or True,
generate_aggregate_solutions=data.preferences.aggregated_solutions or True,
)

celery_result = worker.send_task(
"worker.generate_report",
args=[
recommendation_task.id,
data.preferences.long_description,
data.preferences.search_terms,
data.preferences.aggregated_solutions,
],
args=[worker_input.model_dump()],
)

# update the task with the celery task id
Expand Down
8 changes: 8 additions & 0 deletions src/worker/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel


class GenerateReportInput(BaseModel):
recommendation_task_id: int
generate_long_solution: bool = True
generate_search_terms: bool = True
generate_aggregate_solutions: bool = True
19 changes: 11 additions & 8 deletions src/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ai.LLM.Strategies.OLLAMAService import OLLAMAService
from data.VulnerabilityReport import create_from_flama_json
from ai.Grouping.FindingGrouper import FindingGrouper
from worker.types import GenerateReportInput


logger = logging.getLogger(__name__)
Expand All @@ -41,16 +42,18 @@ def error(self, exc, task_id, args, kwargs, einfo):

@worker.task(name="worker.generate_report", on_failure=error)
def generate_report(
recommendation_task_id: int,
generate_long_solution: bool = True,
generate_search_terms: bool = True,
generate_aggregate_solutions: bool = True,
input: dict,
):
try:
input = GenerateReportInput.model_validate(input)
except Exception as e:
logger.error(f"Error validating input: {e}")
return

# importing here so importing worker does not import all the dependencies
# from ai.LLM.LLMServiceStrategy import LLMServiceStrategy
# from ai.LLM.Strategies.OLLAMAService import OLLAMAService
# from data.VulnerabilityReport import create_from_flama_json
recommendation_task_id = input.recommendation_task_id
generate_long_solution = input.generate_long_solution
generate_search_terms = input.generate_search_terms
generate_aggregate_solutions = input.generate_aggregate_solutions

ollama_strategy = OLLAMAService()
llm_service = LLMServiceStrategy(ollama_strategy)
Expand Down

0 comments on commit 253ecbe

Please sign in to comment.