diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..2e12d55 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,36 @@ +name: API Docs + +on: + push: + branches: + - main + pull_request: + branches: + - main +jobs: + generate-docs: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set up Python 3.11 + uses: actions/setup-python@v2 + with: + python-version: 3.11 + + - name: Install dependencies + run: | + pip install pipenv + pipenv install + + - name: Generate docs + run: | + pipenv run python src/extract-docs.py + + - name: Save docs as artifact + uses: actions/upload-artifact@v2 + with: + name: docs + path: .docs + if-no-files-found: error diff --git a/README.md b/README.md index 056c9ba..daec723 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ After starting the application, you can access the API at `http://localhost:8000 ### Docker building the images + ```bash docker compose build ``` @@ -64,23 +65,23 @@ To run the code within Docker, use the following command in the root directory o docker compose up ``` - - Add the `-d` flag to run the containers in the background: `docker compose up -d`. - - - ### Available Routes Currently, these routes are generated by fastapi. ``` +HEAD, GET /openapi.json +HEAD, GET /docs +HEAD, GET /docs/oauth2-redirect +HEAD, GET /redoc GET /api/v1/tasks/ DELETE /api/v1/tasks/{task_id} DELETE /api/v1/tasks/ GET /api/v1/tasks/{task_id}/status POST /api/v1/recommendations/ +POST /api/v1/recommendations/aggregated POST /api/v1/upload/ GET / ``` diff --git a/src/ai/Grouping/FindingGrouper.py b/src/ai/Grouping/FindingGrouper.py index 8a73c33..7e7ac09 100644 --- a/src/ai/Grouping/FindingGrouper.py +++ b/src/ai/Grouping/FindingGrouper.py @@ -9,7 +9,9 @@ class FindingGrouper: - def __init__(self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService): + def __init__( + self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService + ): self.vulnerability_report = vulnerability_report self.llm_service = llm_service self.batcher = FindingBatcher(llm_service) @@ -20,5 +22,7 @@ def generate_aggregated_solutions(self): for batch in tqdm(self.batches, desc="Generating Aggregated Solutions"): result_list = self.llm_service.generate_aggregated_solution(batch) for result in result_list: - self.aggregated_solutions.append(AggregatedSolution(result[1], result[0], result[2])) # Solution, Findings, Metadata + self.aggregated_solutions.append( + AggregatedSolution().from_result(result[1], result[0], result[2]) + ) # Solution, Findings, Metadata self.vulnerability_report.set_aggregated_solutions(self.aggregated_solutions) diff --git a/src/app.py b/src/app.py index 79b9da3..c472fac 100644 --- a/src/app.py +++ b/src/app.py @@ -7,6 +7,7 @@ import ai.LLM.Strategies.OLLAMAService from config import config + import routes import routes.v1.recommendations import routes.v1.task diff --git a/src/config.py b/src/config.py index 80e5868..0c4129b 100644 --- a/src/config.py +++ b/src/config.py @@ -1,58 +1,33 @@ from pydantic import ( - BaseModel, Field, - RedisDsn, ValidationInfo, - model_validator, - root_validator, ) from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import Field, RedisDsn, field_validator +from pydantic import Field, field_validator from typing import Optional -class SubModel(BaseModel): - foo: str = "bar" - apple: int = 1 - - class Config(BaseSettings): model_config = 
SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") - ollama_url: str = Field( - json_schema_extra="OLLAMA_URL", default="http://localhost:11434" - ) - ollama_model: str = Field(json_schema_extra="OLLAMA_MODEL", default="phi3:mini") + ollama_url: str = Field(default="http://localhost:11434") + ollama_model: str = Field(default="phi3:mini") - ai_strategy: Optional[str] = Field( - json_schema_extra="AI_STRATEGY", default="OLLAMA" - ) - anthropic_api_key: Optional[str] = Field( - json_schema_extra="ANTHROPIC_API_KEY", default=None - ) - openai_api_key: Optional[str] = Field( - json_schema_extra="OPENAI_API_KEY", default=None - ) + ai_strategy: Optional[str] = Field(default="OLLAMA") + anthropic_api_key: Optional[str] = Field(default=None) + openai_api_key: Optional[str] = Field(default=None) - postgres_server: str = Field( - json_schema_extra="POSTGRES_SERVER", default="localhost" - ) - postgres_port: int = Field(json_schema_extra="POSTGRES_PORT", default=5432) - postgres_db: str = Field(json_schema_extra="POSTGRES_DB", default="app") - postgres_user: str = Field(json_schema_extra="POSTGRES_USER", default="postgres") - postgres_password: str = Field( - json_schema_extra="POSTGRES_PASSWORD", default="postgres" - ) - queue_processing_limit: int = Field( - json_schema_extra="QUEUE_PROCESSING_LIMIT", default=10 - ) - redis_endpoint: str = Field( - json_schema_extra="REDIS_ENDPOINT", default="redis://localhost:6379/0" - ) - environment: str = Field(env="ENVIRONMENT", default="development") - db_debug: bool = Field(env="DB_DEBUG", default=False) + postgres_server: str = Field(default="localhost") + postgres_port: int = Field(default=5432) + postgres_db: str = Field(default="app") + postgres_user: str = Field(default="postgres") + postgres_password: str = Field(default="postgres") + queue_processing_limit: int = Field(default=10) + redis_endpoint: str = Field(default="redis://localhost:6379/0") + environment: str = Field(default="development") + db_debug: bool = Field(default=False) @field_validator( "ai_strategy", @@ -66,7 +41,6 @@ def check_ai_strategy(cls, ai_strategy, values): @field_validator("openai_api_key") def check_api_key(cls, api_key, info: ValidationInfo): - print(info) if info.data["ai_strategy"] == "OPENAI" and not api_key: raise ValueError("OPENAI_API_KEY is required when ai_strategy is OPENAI") return api_key diff --git a/src/data/AggregatedSolution.py b/src/data/AggregatedSolution.py index 52e675c..64473f2 100644 --- a/src/data/AggregatedSolution.py +++ b/src/data/AggregatedSolution.py @@ -1,18 +1,19 @@ from typing import List from data.Finding import Finding -from db.base import BaseModel +from pydantic import BaseModel -class AggregatedSolution: +class AggregatedSolution(BaseModel): findings: List[Finding] = None solution: str = "" metadata: dict = {} - def __init__(self, findings: List[Finding], solution: str, metadata=None): + def from_result(self, findings: List[Finding], solution: str, metadata=None): self.findings = findings self.solution = solution self.metadata = metadata + return self def __str__(self): return self.solution @@ -21,8 +22,8 @@ def to_dict(self): return { "findings": [finding.to_dict() for finding in self.findings], "solution": self.solution, - "metadata": self.metadata + "metadata": self.metadata, } def to_html(self): - return f"

{self.solution}

" \ No newline at end of file + return f"

{self.solution}

" diff --git a/src/data/Finding.py b/src/data/Finding.py index c41fea3..5a4a73c 100644 --- a/src/data/Finding.py +++ b/src/data/Finding.py @@ -1,9 +1,19 @@ from typing import List, Set, Optional, Any, get_args from enum import Enum, auto +import uuid + from pydantic import BaseModel, Field, PrivateAttr from data.Solution import Solution -from data.Categories import Category, TechnologyStack, SecurityAspect, SeverityLevel, RemediationType, \ - AffectedComponent, Compliance, Environment +from data.Categories import ( + Category, + TechnologyStack, + SecurityAspect, + SeverityLevel, + RemediationType, + AffectedComponent, + Compliance, + Environment, +) import json import logging @@ -12,6 +22,7 @@ class Finding(BaseModel): + id: str = Field(default_factory=lambda: f"{str(uuid.uuid4())}") title: List[str] = Field(default_factory=list) source: Set[str] = Field(default_factory=set) descriptions: List[str] = Field(default_factory=list) @@ -21,7 +32,7 @@ class Finding(BaseModel): severity: Optional[int] = None priority: Optional[int] = None location_list: List[str] = Field(default_factory=list) - category: Category = None + category: Optional[Category] = None unsupervised_cluster: Optional[int] = None solution: Optional["Solution"] = None _llm_service: Optional[Any] = PrivateAttr(default=None) @@ -35,7 +46,9 @@ def combine_descriptions(self) -> "Finding": logger.error("LLM Service not set, cannot combine descriptions.") return self - self.description = self.llm_service.combine_descriptions(self.descriptions, self.cve_ids, self.cwe_ids) + self.description = self.llm_service.combine_descriptions( + self.descriptions, self.cve_ids, self.cwe_ids + ) return self def add_category(self) -> "Finding": @@ -47,34 +60,45 @@ def add_category(self) -> "Finding": # Classify technology stack technology_stack_options = list(TechnologyStack) - self.category.technology_stack = self.llm_service.classify_kind(self, "technology_stack", - technology_stack_options) + self.category.technology_stack = self.llm_service.classify_kind( + self, "technology_stack", technology_stack_options + ) # Classify security aspect security_aspect_options = list(SecurityAspect) - self.category.security_aspect = self.llm_service.classify_kind(self, "security_aspect", security_aspect_options) + self.category.security_aspect = self.llm_service.classify_kind( + self, "security_aspect", security_aspect_options + ) # Classify severity level severity_level_options = list(SeverityLevel) - self.category.severity_level = self.llm_service.classify_kind(self, "severity_level", severity_level_options) + self.category.severity_level = self.llm_service.classify_kind( + self, "severity_level", severity_level_options + ) # Classify remediation type remediation_type_options = list(RemediationType) - self.category.remediation_type = self.llm_service.classify_kind(self, "remediation_type", - remediation_type_options) + self.category.remediation_type = self.llm_service.classify_kind( + self, "remediation_type", remediation_type_options + ) # Classify affected component affected_component_options = list(AffectedComponent) - self.category.affected_component = self.llm_service.classify_kind(self, "affected_component", - affected_component_options) + self.category.affected_component = self.llm_service.classify_kind( + self, "affected_component", affected_component_options + ) # Classify compliance compliance_options = list(Compliance) - self.category.compliance = self.llm_service.classify_kind(self, "compliance", compliance_options) + self.category.compliance = 
self.llm_service.classify_kind( + self, "compliance", compliance_options + ) # Classify environment environment_options = list(Environment) - self.category.environment = self.llm_service.classify_kind(self, "environment", environment_options) + self.category.environment = self.llm_service.classify_kind( + self, "environment", environment_options + ) return self @@ -213,9 +237,7 @@ def to_html(self, table=False): result += "NameValue" result += f"Title{', '.join(self.title)}" result += f"Source{', '.join(self.source)}" - result += ( - f"Description{self.description}" - ) + result += f"Description{self.description}" if len(self.location_list) > 0: result += f"Location List{' & '.join(map(str, self.location_list))}" result += f"CWE IDs{', '.join(self.cwe_ids)}" @@ -223,7 +245,11 @@ def to_html(self, table=False): result += f"Severity{self.severity}" result += f"Priority{self.priority}" if self.category is not None: - result += 'Category' + str(self.category).replace("\n", "
") + '' + result += ( + "Category" + + str(self.category).replace("\n", "
") + + "" + ) if self.unsupervised_cluster is not None: result += f"Unsupervised Cluster{self.unsupervised_cluster}" result += "" diff --git a/src/data/VulnerabilityReport.py b/src/data/VulnerabilityReport.py index 3ec9b87..21a7346 100644 --- a/src/data/VulnerabilityReport.py +++ b/src/data/VulnerabilityReport.py @@ -7,7 +7,6 @@ from data.AggregatedSolution import AggregatedSolution from data.Finding import Finding from ai.LLM.LLMServiceStrategy import LLMServiceStrategy -from ai.Clustering.AgglomerativeClusterer import AgglomerativeClusterer import logging @@ -51,9 +50,15 @@ def add_unsupervised_category(self, use_solution=True): This function adds a category to each finding in the report. :return: The VulnerabilityReport object and the AgglomerativeClustering object (in case you wanna have a look at the graph) """ + from ai.Clustering.AgglomerativeClusterer import AgglomerativeClusterer + clustering = AgglomerativeClusterer(self) - if use_solution and not all([finding.solution is not None for finding in self.findings]): - logger.warning("Not all findings have a solution, falling back to using the description.") + if use_solution and not all( + [finding.solution is not None for finding in self.findings] + ): + logger.warning( + "Not all findings have a solution, falling back to using the description." + ) use_solution = False clustering.add_unsupervised_category(use_solution=use_solution) @@ -98,21 +103,27 @@ def sort(self, by: str = "severity", reverse: bool = True): def to_dict(self): findings = [f.to_dict() for f in self.findings] if len(self.get_aggregated_solutions()) > 0: - aggregated_solutions = [f.to_dict() for f in self.get_aggregated_solutions()] + aggregated_solutions = [ + f.to_dict() for f in self.get_aggregated_solutions() + ] return {"findings": findings, "aggregated_solutions": aggregated_solutions} return {"findings": findings} def __str__(self): findings_str = "\n".join([str(f) for f in self.findings]) if len(self.get_aggregated_solutions()) > 0: - aggregated_solutions_str = "\n".join([str(f) for f in self.get_aggregated_solutions()]) + aggregated_solutions_str = "\n".join( + [str(f) for f in self.get_aggregated_solutions()] + ) return findings_str + "\n\n" + aggregated_solutions_str return findings_str def to_html(self, table=False): my_str = "
".join([f.to_html(table) for f in self.findings]) if len(self.get_aggregated_solutions()) > 0: - my_str += "

" + "
".join([f.to_html() for f in self.get_aggregated_solutions()]) + my_str += "

" + "
".join( + [f.to_html() for f in self.get_aggregated_solutions()] + ) return my_str def export_to_json(self, filename="VulnerabilityReport.json"): @@ -142,7 +153,11 @@ def import_from_json(filename="VulnerabilityReport.json"): def create_from_flama_json( - json_data, n=-1, llm_service: "LLMServiceStrategy" = None, shuffle_data=False, combine_descriptions=True + json_data, + n=-1, + llm_service: "LLMServiceStrategy" = None, + shuffle_data=False, + combine_descriptions=True, ) -> VulnerabilityReport: """ This function creates a VulnerabilityReport object from a JSON object. @@ -167,7 +182,9 @@ def create_from_flama_json( # shuffle json_data to get a random sample shuffle(json_data) - for d in conditional_tqdm(json_data[:n], combine_descriptions, desc="Combining descriptions"): + for d in conditional_tqdm( + json_data[:n], combine_descriptions, desc="Combining descriptions" + ): title = [x["element"] for x in d["title_list"]] source = set([x["source"] for x in d["title_list"]]) description = [x["element"] for x in d.get("description_list", [])] diff --git a/src/data/apischema.py b/src/data/apischema.py index 90acc85..bfaaf25 100644 --- a/src/data/apischema.py +++ b/src/data/apischema.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, validator +from pydantic import BaseModel, Field +from data.AggregatedSolution import AggregatedSolution from data.Categories import Category from data.Finding import Finding from data.pagination import Pagination, PaginationInput @@ -13,16 +14,26 @@ class SeverityFilter(BaseModel): minValue: int maxValue: int + + class FindingInputFilter(BaseModel): - severity: Optional[SeverityFilter] = None # ['low', 'high'] - priority: Optional[SeverityFilter] = None # ['low', 'high'] + severity: Optional[SeverityFilter] = None # ['low', 'high'] + priority: Optional[SeverityFilter] = None # ['low', 'high'] cve_ids: Optional[List[str]] = None cwe_ids: Optional[List[str]] = None source: Optional[List[str]] = None - + + +class UploadPreference(BaseModel): + long_description: Optional[bool] = Field(default=True) + search_terms: Optional[bool] = Field(default=True) + aggregated_solutions: Optional[bool] = Field(default=True) + + class StartRecommendationTaskRequest(BaseModel): - user_id: Optional[int] = None - strategy: Optional[Literal["OLLAMA", "ANTHROPIC", "OPENAI"]] = "OLLAMA" + preferences: Optional[UploadPreference] = UploadPreference( + long_description=True, search_terms=True, aggregated_solutions=True + ) data: InputData force_update: Optional[bool] = False filter: Optional[FindingInputFilter] = None @@ -36,7 +47,7 @@ class GetRecommendationFilter(BaseModel): task_id: Optional[int] = None date: Optional[str] = None location: Optional[str] = None - severity: Optional[SeverityFilter] = None # ['low', 'high'] + severity: Optional[SeverityFilter] = None # ['low', 'high'] cve_id: Optional[str] = None source: Optional[str] = None @@ -47,14 +58,7 @@ class GetRecommendationRequest(BaseModel): pagination: Optional[PaginationInput] = PaginationInput(offset=0, limit=10) -class SolutionItem(BaseModel): - short_description: str - long_description: str - metadata: dict - search_terms: str - - -class GetRecommendationResponseItem(Finding): # TODO adapt needed fields +class GetRecommendationResponseItem(Finding): pass class GetRecommendationResponseItems(BaseModel): @@ -70,11 +74,17 @@ class GetRecommendationTaskStatusResponse(BaseModel): status: TaskStatus -class GetSummarizedRecommendationRequest(BaseModel): - user_id: Optional[int] 
- pagination: PaginationInput = PaginationInput(offset=0, limit=10) +class GetAggregatedRecommendationFilter(BaseModel): + task_id: Optional[int] = None + + +class GetAggregatedRecommendationRequest(BaseModel): + filter: Optional[GetAggregatedRecommendationFilter] = None + + +class GetAggregatedRecommendationResponseItem(AggregatedSolution): + pass -# class GetSummarizedRecommendationResponse(BaseModel): -# recommendation: list[Recommendation] -# pagination: Pagination +class GetAggregatedRecommendationResponse(BaseModel): + items: List[GetAggregatedRecommendationResponseItem] diff --git a/src/data/helper.py b/src/data/helper.py index b7e4c35..bc2cba0 100644 --- a/src/data/helper.py +++ b/src/data/helper.py @@ -3,41 +3,41 @@ from data.apischema import FindingInputFilter from data.types import Content, InputData -# ##TODO:maybe using pydantic should be enough -# def validate_json(data: any) -> bool: -# try: -# json_data = data -# try: -# validate(instance=json_data, schema=schema) -# print("JSON data adheres to the schema.") -# except jsonschema.exceptions.ValidationError as e: -# print("JSON data does not adhere to the schema.") -# print(e) -# except ValueError as e: -# return False - -# return True - def get_content_list(json_data: InputData) -> list[Content]: return json_data.message.content -def filter_findings(findings: List[Content], filter: FindingInputFilter) -> List[Content]: + +def filter_findings( + findings: List[Content], filter: FindingInputFilter +) -> List[Content]: def matches(content: Content) -> bool: - if filter.source and not any(title.source in filter.source for title in content.title_list): + if filter.source and not any( + title.source in filter.source for title in content.title_list + ): return False - if filter.severity and not (filter.severity.minValue <= content.severity <= filter.severity.maxValue): + if filter.severity and not ( + filter.severity.minValue <= content.severity <= filter.severity.maxValue + ): return False - if filter.priority and not (filter.priority.minValue <= content.priority <= filter.priority.maxValue): + if filter.priority and not ( + filter.priority.minValue <= content.priority <= filter.priority.maxValue + ): return False - if filter.cve_ids and not any(cve.element in filter.cve_ids for cve in content.cve_id_list): + if filter.cve_ids and not any( + cve.element in filter.cve_ids for cve in content.cve_id_list + ): return False - if filter.cwe_ids and not any(element in filter.cwe_ids for cwe in content.cwe_id_list or [] for element in cwe.element): + if filter.cwe_ids and not any( + element in filter.cwe_ids + for cwe in content.cwe_id_list or [] + for element in cwe.element + ): return False return True - return [content for content in findings if matches(content)] \ No newline at end of file + return [content for content in findings if matches(content)] diff --git a/src/data/types.py b/src/data/types.py index f0e3f75..62582c3 100644 --- a/src/data/types.py +++ b/src/data/types.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Union,Optional +from typing import List, Dict, Any, Union, Optional from typing import Any, Optional, Tuple from typing_extensions import Annotated @@ -6,8 +6,7 @@ from pydantic import BaseModel - - +# deprecated schema = { "type": "object", "properties": { @@ -25,17 +24,47 @@ "doc_type": {"type": "string"}, "criticality_tag": {"type": ["array", "object"]}, "knowledge_type": {"type": "string"}, - "requirement_list": {"type": "array", "items": {"type": "string"}}, - "title_list": {"type": "array", "items": 
{"type": "object"}}, - "location_list": {"type": "array", "items": {"type": "object"}}, - "description_list": {"type": "array", "items": {"type": "object"}}, - "internal_rating_list": {"type": "array", "items": {"type": "object"}}, - "internal_ratingsource_list": {"type": "array", "items": {"type": "object"}}, - "cvss_rating_list": {"type": "array", "items": {"type": "object"}}, + "requirement_list": { + "type": "array", + "items": {"type": "string"}, + }, + "title_list": { + "type": "array", + "items": {"type": "object"}, + }, + "location_list": { + "type": "array", + "items": {"type": "object"}, + }, + "description_list": { + "type": "array", + "items": {"type": "object"}, + }, + "internal_rating_list": { + "type": "array", + "items": {"type": "object"}, + }, + "internal_ratingsource_list": { + "type": "array", + "items": {"type": "object"}, + }, + "cvss_rating_list": { + "type": "array", + "items": {"type": "object"}, + }, "rule_list": {"type": "array", "items": {"type": "object"}}, - "cwe_id_list": {"type": "array", "items": {"type": "object"}}, - "cve_id_list": {"type": "array", "items": {"type": "object"}}, - "activity_list": {"type": "array", "items": {"type": "object"}}, + "cwe_id_list": { + "type": "array", + "items": {"type": "object"}, + }, + "cve_id_list": { + "type": "array", + "items": {"type": "object"}, + }, + "activity_list": { + "type": "array", + "items": {"type": "object"}, + }, "first_found": {"type": "string"}, "last_found": {"type": "string"}, "report_amount": {"type": "integer"}, @@ -46,25 +75,45 @@ "priority_explanation": {"type": "string"}, "sum_id": {"type": "string"}, "prio_id": {"type": "string"}, - "element_tag": {"type": "string"} + "element_tag": {"type": "string"}, }, - "required": ["doc_type", "criticality_tag", "knowledge_type", "requirement_list", "title_list", - "location_list", "description_list", "internal_rating_list", - "internal_ratingsource_list", "cvss_rating_list", "rule_list", - "cwe_id_list", "cve_id_list", "activity_list", "first_found", "last_found", - "report_amount", "content_hash", "severity", "severity_explanation", - "priority", "priority_explanation", "sum_id", "prio_id", "element_tag"] - } - } + "required": [ + "doc_type", + "criticality_tag", + "knowledge_type", + "requirement_list", + "title_list", + "location_list", + "description_list", + "internal_rating_list", + "internal_ratingsource_list", + "cvss_rating_list", + "rule_list", + "cwe_id_list", + "cve_id_list", + "activity_list", + "first_found", + "last_found", + "report_amount", + "content_hash", + "severity", + "severity_explanation", + "priority", + "priority_explanation", + "sum_id", + "prio_id", + "element_tag", + ], + }, + }, }, - "required": ["version", "utc_age", "content"] - } + "required": ["version", "utc_age", "content"], + }, }, - "required": ["status", "message"] + "required": ["status", "message"], } - class Tag(BaseModel): action: str user_mail: str @@ -73,8 +122,6 @@ class Tag(BaseModel): valid_until: str - - class Location(BaseModel): location: str amount: int @@ -84,53 +131,46 @@ class Location(BaseModel): tags: List[Tag] - class Title(BaseModel): element: str source: str - class Description(BaseModel): element: str source: str - class Rating(BaseModel): element: str source: str - class CvssRating(BaseModel): element: str source: str - class Rule(BaseModel): element: str source: str - class CveId(BaseModel): element: str source: str + class CweId(BaseModel): element: List[str] source: str - class Activity(BaseModel): element: str source: str - class 
Content(BaseModel): doc_type: str criticality_tag: Union[List[Union[str, int]], Dict[str, Any]] @@ -138,13 +178,13 @@ class Content(BaseModel): requirement_list: List[str] title_list: List[Title] location_list: List[Location] - description_list: Optional[List[Description]] = [] - internal_rating_list: Optional[ List[Rating]] = [] + description_list: Optional[List[Description]] = [] + internal_rating_list: Optional[List[Rating]] = [] internal_ratingsource_list: List[Rating] - cvss_rating_list: Optional[ List[CvssRating]] = [] + cvss_rating_list: Optional[List[CvssRating]] = [] rule_list: List[Rule] - cve_id_list: Optional[ List[CveId]] = [] - cwe_id_list: Optional[ List[CweId]] = [] + cve_id_list: Optional[List[CveId]] = [] + cwe_id_list: Optional[List[CweId]] = [] activity_list: List[Activity] first_found: str last_found: str @@ -159,24 +199,21 @@ class Content(BaseModel): element_tag: str - class Message(BaseModel): version: str utc_age: str content: List[Content] - class InputData(BaseModel): status: str message: Message - class Finding(BaseModel): content: Content solutions: List[Dict[str, Any]] = [] - + class Recommendation(BaseModel): recommendation: str diff --git a/src/db/models.py b/src/db/models.py index 62aec28..f604a2e 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -1,7 +1,7 @@ from enum import Enum as PyEnum from typing import List, Optional -from sqlalchemy import JSON, Enum, ForeignKey, Integer, String +from sqlalchemy import JSON, Enum, ForeignKey, Integer, String, Table from sqlalchemy.orm import Mapped, relationship from data.types import Content @@ -9,13 +9,29 @@ from .base import BaseModel, Column +findings_aggregated_association_table = Table( + "findings_aggregated_association", + BaseModel.metadata, + Column( + "finding_id", + ForeignKey("findings.id", ondelete="CASCADE"), + primary_key=True, + ), + Column( + "aggregated_recommendation_id", + ForeignKey("aggregated_recommendations.id", ondelete="CASCADE"), + primary_key=True, + ), +) + + class Recommendation(BaseModel): __tablename__ = "recommendations" description_short = Column(String, nullable=True) description_long = Column(String, nullable=True) search_terms = Column(String, nullable=True) meta = Column(JSON, default={}, nullable=True) - category = Column(String, nullable=True) + category = Column(JSON, nullable=True) finding_id: int = Column(Integer, ForeignKey("findings.id"), nullable=True) recommendation_task_id = Column( Integer, @@ -27,7 +43,25 @@ class Recommendation(BaseModel): ) def __repr__(self): - return f"" + return f"" + + +class AggregatedRecommendation(BaseModel): + __tablename__ = "aggregated_recommendations" + solution = Column(String, nullable=True) + meta = Column(JSON, default={}, nullable=True) + recommendation_task_id = Column( + Integer, + ForeignKey("recommendation_task.id", ondelete="CASCADE"), + nullable=False, + ) + findings: Mapped[List["Finding"]] = relationship( + secondary=findings_aggregated_association_table, + back_populates="aggregated_recommendations", + ) + + def __repr__(self): + return f"" class Finding(BaseModel): @@ -37,7 +71,7 @@ class Finding(BaseModel): title_list = Column(JSON, default=None, nullable=True) description_list = Column(JSON, default=[], nullable=True) locations_list = Column(JSON, default=[], nullable=True) - category = Column(String, nullable=True) + category = Column(JSON, nullable=True) cwe_id_list = Column(JSON, default=[], nullable=True) cve_id_list = Column(JSON, default=[], nullable=True) priority = Column(Integer, default=None, 
nullable=True) @@ -62,6 +96,10 @@ class Finding(BaseModel): "Recommendation", back_populates="finding" ) + aggregated_recommendations: Mapped[List["AggregatedRecommendation"]] = relationship( + secondary=findings_aggregated_association_table, back_populates="findings" + ) + def from_data(self, data: Content): self.cve_id_list = ( [x.dict() for x in data.cve_id_list] if data.cve_id_list else [] diff --git a/src/db/my_db.py b/src/db/my_db.py index 6b1e554..fb524da 100644 --- a/src/db/my_db.py +++ b/src/db/my_db.py @@ -1,4 +1,4 @@ -from sqlalchemy import create_engine +from sqlalchemy import create_engine, NullPool import os from sqlalchemy.orm import sessionmaker @@ -7,7 +7,9 @@ from config import config engine = create_engine( - config.get_db_url(), echo=os.getenv("DB_DEBUG", "false") == "true" + config.get_db_url(), + echo=os.getenv("DB_DEBUG", "false") == "true", + poolclass=NullPool, ) diff --git a/src/dto/finding.py b/src/dto/finding.py index 6e9486b..d03f701 100644 --- a/src/dto/finding.py +++ b/src/dto/finding.py @@ -7,12 +7,15 @@ def db_finding_to_response_item( find: DBFinding, ) -> GetRecommendationResponseItem: + category = None + try: + recommendation = find.recommendations[0] + category = Category.model_validate_json(recommendation.category) + except Exception as e: + category = Category() + return GetRecommendationResponseItem( - category=Category( - affected_component=( - AffectedComponent(find.category) if find.category else None - ) - ), + category=category, solution=Solution( short_description=( find.recommendations[0].description_short diff --git a/src/extract-docs.py b/src/extract-docs.py index 04df878..6fb6dfb 100644 --- a/src/extract-docs.py +++ b/src/extract-docs.py @@ -1,10 +1,20 @@ import app as app from fastapi.openapi.docs import get_swagger_ui_html import os +import argparse if __name__ == "__main__": - output_path = os.path.join(os.getcwd(), ".docs") - + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output", + help="Output path for documentation", + required=False, + default=os.path.join(os.getcwd(), ".docs"), + ) + args = parser.parse_args() + output_path = args.output + print(f"Generating documentation in {output_path}") if not os.path.exists(output_path): os.makedirs(output_path) diff --git a/src/migrations/versions/06253ed3ac28_change_category.py b/src/migrations/versions/06253ed3ac28_change_category.py new file mode 100644 index 0000000..1389402 --- /dev/null +++ b/src/migrations/versions/06253ed3ac28_change_category.py @@ -0,0 +1,50 @@ +"""change_category + +Revision ID: 06253ed3ac28 +Revises: 9e31625c7978 +Create Date: 2024-07-26 11:07:19.507484 + +""" + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +import db + + +# revision identifiers, used by Alembic. +revision = "06253ed3ac28" +down_revision = "9e31625c7978" +branch_labels = None +depends_on = None + + +def upgrade(): + + op.execute( + "ALTER TABLE recommendations ALTER COLUMN category TYPE json using category::json;" + ) + + op.execute( + "alter table findings alter column category type json using category::json;" + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "recommendations", + "category", + existing_type=sa.JSON(), + type_=sa.VARCHAR(), + existing_nullable=True, + ) + op.alter_column( + "findings", + "category", + existing_type=sa.JSON(), + type_=sa.VARCHAR(), + existing_nullable=True, + ) + # ### end Alembic commands ### diff --git a/src/migrations/versions/9e31625c7978_aggregated.py b/src/migrations/versions/9e31625c7978_aggregated.py new file mode 100644 index 0000000..13fe48e --- /dev/null +++ b/src/migrations/versions/9e31625c7978_aggregated.py @@ -0,0 +1,47 @@ +"""aggregated + +Revision ID: 9e31625c7978 +Revises: 8e2452c67d2d +Create Date: 2024-07-26 08:56:17.779435 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +import db + + +# revision identifiers, used by Alembic. +revision = '9e31625c7978' +down_revision = '8e2452c67d2d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('aggregated_recommendations', + sa.Column('solution', sa.String(), nullable=True), + sa.Column('meta', sa.JSON(), nullable=True), + sa.Column('recommendation_task_id', sa.Integer(), nullable=False), + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True), + sa.ForeignKeyConstraint(['recommendation_task_id'], ['recommendation_task.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('findings_aggregated_association', + sa.Column('finding_id', sa.Integer(), nullable=False), + sa.Column('aggregated_recommendation_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['aggregated_recommendation_id'], ['aggregated_recommendations.id'], ), + sa.ForeignKeyConstraint(['finding_id'], ['findings.id'], ), + sa.PrimaryKeyConstraint('finding_id', 'aggregated_recommendation_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('findings_aggregated_association') + op.drop_table('aggregated_recommendations') + # ### end Alembic commands ### \ No newline at end of file diff --git a/src/migrations/versions/cae76af4e118_on_delete_aggregated.py b/src/migrations/versions/cae76af4e118_on_delete_aggregated.py new file mode 100644 index 0000000..6479987 --- /dev/null +++ b/src/migrations/versions/cae76af4e118_on_delete_aggregated.py @@ -0,0 +1,36 @@ +"""on_delete_aggregated + +Revision ID: cae76af4e118 +Revises: 06253ed3ac28 +Create Date: 2024-07-26 19:17:05.829930 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +import db + + +# revision identifiers, used by Alembic. +revision = 'cae76af4e118' +down_revision = '06253ed3ac28' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint('findings_aggregated_association_finding_id_fkey', 'findings_aggregated_association', type_='foreignkey') + op.drop_constraint('findings_aggregated_associati_aggregated_recommendation_id_fkey', 'findings_aggregated_association', type_='foreignkey') + op.create_foreign_key(None, 'findings_aggregated_association', 'findings', ['finding_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key(None, 'findings_aggregated_association', 'aggregated_recommendations', ['aggregated_recommendation_id'], ['id'], ondelete='CASCADE') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'findings_aggregated_association', type_='foreignkey') + op.drop_constraint(None, 'findings_aggregated_association', type_='foreignkey') + op.create_foreign_key('findings_aggregated_associati_aggregated_recommendation_id_fkey', 'findings_aggregated_association', 'aggregated_recommendations', ['aggregated_recommendation_id'], ['id']) + op.create_foreign_key('findings_aggregated_association_finding_id_fkey', 'findings_aggregated_association', 'findings', ['finding_id'], ['id']) + # ### end Alembic commands ### \ No newline at end of file diff --git a/src/repository/finding.py b/src/repository/finding.py index 88ac7f1..db2cb27 100644 --- a/src/repository/finding.py +++ b/src/repository/finding.py @@ -1,4 +1,5 @@ from fastapi import Depends + from sqlalchemy import Date, Integer, cast, func from sqlalchemy.orm import Session @@ -6,6 +7,7 @@ from data.apischema import SeverityFilter from data.pagination import PaginationInput from db.my_db import get_db +from repository.types import GetFindingsByFilterInput class FindingRepository: @@ -18,6 +20,19 @@ def get_findings(self, num_findings: int) -> list[db_models.Finding]: findings = self.session.query(db_models.Finding).limit(num_findings).all() return findings + def get_all_findings_by_task_id_for_processing(self, task_id: int, limit: int = -1): + query = ( + self.session.query(db_models.Finding) + .join(db_models.RecommendationTask) + .filter(db_models.RecommendationTask.id == task_id) + ) + + # Remove Limit For Now + if limit > 0: + query = query.limit(limit) + + return query.all() + def get_findings_by_task_id( self, task_id: int, pagination: PaginationInput ) -> list[db_models.Finding]: @@ -34,25 +49,38 @@ def get_findings_by_task_id( ) return findings - - def get_findings_by_task_id_and_severity( - self, task_id: int, severity: SeverityFilter, pagination: PaginationInput + + def get_findings_by_task_id_and_filter( + self, input: GetFindingsByFilterInput ) -> list[db_models.Finding]: - findings = ( + task_id = input.task_id + pagination = input.pagination + severity = input.severityFilter + query = ( self.session.query(db_models.Finding) .join(db_models.RecommendationTask) .where( db_models.RecommendationTask.status == db_models.TaskStatus.COMPLETED, (db_models.RecommendationTask.id == task_id), - db_models.Finding.severity >= severity.minValue, - db_models.Finding.severity <= severity.maxValue, + ( + db_models.Finding.severity >= severity.minValue + if severity and severity.minValue + else True + ), + ( + db_models.Finding.severity <= severity.maxValue + if severity and severity.maxValue + else True + ), ) - .offset(pagination.offset) - .limit(pagination.limit) - .all() ) - return findings + total = query.count() + findings = query.all() + if pagination: + query = query.offset(pagination.offset).limit(pagination.limit) + + return findings, total def 
get_findings_count_by_task_id(self, task_id: int) -> int: count = ( @@ -74,5 +102,5 @@ def create_findings( self.session.commit() -def get_finding_repository(session: Depends = Depends(get_db)): +def get_finding_repository(session: Session = Depends(get_db)): return FindingRepository(session) diff --git a/src/repository/recommendation.py b/src/repository/recommendation.py index 3c3e36a..5f2147e 100644 --- a/src/repository/recommendation.py +++ b/src/repository/recommendation.py @@ -1,3 +1,4 @@ +from data.Finding import Finding import db.models as db_models from sqlalchemy.orm import Session, sessionmaker @@ -6,6 +7,8 @@ from db.my_db import get_db from fastapi import Depends +from repository.types import CreateAggregatedRecommendationInput + class RecommendationRepository: session: Session @@ -33,6 +36,83 @@ def get_recommendation_by_id( ) return recommendation + def create_recommendations( + self, + finding_with_solution: list[tuple[str, Finding]], + recommendation_task_id: int, + ): + session = self.session + for finding_id, f in finding_with_solution: + finding = ( + session.query(db_models.Finding) + .filter(db_models.Finding.id == finding_id) + .first() + ) + if finding is None: + print(f"Finding with id {finding_id} not found") + continue + recommendation = db_models.Recommendation( + description_short=( + f.solution.short_description + if f.solution.short_description + else "No short description available" + ), + description_long=( + f.solution.long_description + if f.solution.long_description + else "No long description available" + ), + meta=f.solution.metadata if f.solution.metadata else {}, + search_terms=f.solution.search_terms if f.solution.search_terms else [], + finding_id=finding_id, + recommendation_task_id=recommendation_task_id, + category=(f.category.model_dump_json() if f.category else None), + ) + session.add(recommendation) + ## update recommendation task status + + session.commit() + + def create_aggregated_solutions( + self, + input: CreateAggregatedRecommendationInput, + ): + + for solution in input.aggregated_solutions: + aggregated_rec = db_models.AggregatedRecommendation( + solution=solution.solution.solution, + meta=solution.solution.metadata, + recommendation_task_id=input.recommendation_task_id, + ) + self.session.add(aggregated_rec) + self.session.commit() + self.session.refresh(aggregated_rec) + for findings_id in solution.findings_db_ids: + + res = self.session.execute( + db_models.findings_aggregated_association_table.insert().values( + finding_id=findings_id, + aggregated_recommendation_id=aggregated_rec.id, + ) + ) + res.close() + + self.session.commit() + + def get_aggregated_solutions( + self, recommendation_task_id: int + ) -> list[db_models.AggregatedRecommendation]: + + aggregated_solutions = ( + self.session.query(db_models.AggregatedRecommendation) + .filter( + db_models.AggregatedRecommendation.recommendation_task_id + == recommendation_task_id + ) + .all() + ) + return aggregated_solutions + def get_recommendation_repository(session: Session = Depends(get_db)): return RecommendationRepository(session) diff --git a/src/repository/task.py b/src/repository/task.py index b9ef75f..85603a0 100644 --- a/src/repository/task.py +++ b/src/repository/task.py @@ -29,15 +29,23 @@ def update_task(self, task: db_models.RecommendationTask, celery_task_id: str): self.session.refresh(task) return task + def update_task_completed(self, task_id: int): + status = db_models.TaskStatus.COMPLETED + self.session.query(db_models.RecommendationTask).filter( + 
db_models.RecommendationTask.id == task_id + ).update({db_models.RecommendationTask.status: status}) + + self.session.commit() + self.session.flush() + def get_tasks( self, ) -> list[db_models.RecommendationTask]: - + print("get_tasks") tasks = self.session.query(db_models.RecommendationTask).all() return tasks def get_task_by_id(self, task_id: int) -> db_models.RecommendationTask | None: - task = ( self.session.query(db_models.RecommendationTask) .where(db_models.RecommendationTask.id == task_id) diff --git a/src/repository/types.py b/src/repository/types.py new file mode 100644 index 0000000..b7ef472 --- /dev/null +++ b/src/repository/types.py @@ -0,0 +1,22 @@ +from typing import Optional +from pydantic import BaseModel + +from data.AggregatedSolution import AggregatedSolution +from data.apischema import SeverityFilter +from data.pagination import PaginationInput + + +class AggregatedSolutionInput(BaseModel): + solution: AggregatedSolution + findings_db_ids: list[int] + + +class CreateAggregatedRecommendationInput(BaseModel): + aggregated_solutions: list[AggregatedSolutionInput] + recommendation_task_id: int + + +class GetFindingsByFilterInput(BaseModel): + task_id: int + severityFilter: Optional[SeverityFilter] = None + pagination: Optional[PaginationInput] = None diff --git a/src/routes/v1/recommendations.py b/src/routes/v1/recommendations.py index e05db88..cf39f92 100644 --- a/src/routes/v1/recommendations.py +++ b/src/routes/v1/recommendations.py @@ -1,5 +1,5 @@ import datetime -from typing import Annotated +from typing import Annotated, Optional from fastapi import Body, Depends, HTTPException, Response from fastapi.routing import APIRouter @@ -8,10 +8,14 @@ import data.apischema as apischema import db.models as db_models +from data.AggregatedSolution import AggregatedSolution from db.my_db import get_db from dto.finding import db_finding_to_response_item from repository.finding import get_finding_repository +from repository.recommendation import (RecommendationRepository, + get_recommendation_repository) from repository.task import TaskRepository, get_task_repository +from repository.types import GetFindingsByFilterInput router = APIRouter( prefix="/recommendations", @@ -63,26 +67,65 @@ def recommendations( status_code=400, detail="Recommendation task failed", ) - if severityFilter: - findings = finding_repository.get_findings_by_task_id_and_severity(task.id, severityFilter, request.pagination) - else: - findings = finding_repository.get_findings_by_task_id(task.id, request.pagination) + findings, total = finding_repository.get_findings_by_task_id_and_filter( + GetFindingsByFilterInput( + task_id=task.id, + severityFilter=severityFilter if severityFilter else None, + pagination=request.pagination, + ) + ) - total_count = finding_repository.get_findings_count_by_task_id(task.id) - response = apischema.GetRecommendationResponse( - items=apischema.GetRecommendationResponseItems( - findings= [db_finding_to_response_item(find) for find in findings], - aggregated_solutions= []), + items=[db_finding_to_response_item(find) for find in findings], pagination=apischema.Pagination( offset=request.pagination.offset, limit=request.pagination.limit, - total=total_count, + total=total, count=len(findings), ), ) + return response - if not response or len(response.items.findings) == 0: - return Response(status_code=204, headers={"Retry-After": "120"}) - return response +@router.post("/aggregated") +def aggregated_solutions( + request: Annotated[ + 
Optional[apischema.GetAggregatedRecommendationRequest], Body(...) + ], + task_repository: TaskRepository = Depends(get_task_repository), + recommendation_repository: RecommendationRepository = ( + Depends(get_recommendation_repository) + ), +) -> apischema.GetAggregatedRecommendationResponse: + + today = datetime.datetime.now().date() + task = None + if request.filter and request.filter.task_id: + task = task_repository.get_task_by_id(request.filter.task_id) + task = task_repository.get_task_by_date(today) + + if not task: + raise HTTPException( + status_code=404, + detail=( + f"Task with id {request.filter.task_id} not found" + if request.filter and request.filter.task_id + else "Task for today not found" + ), + ) + if task.status != db_models.TaskStatus.COMPLETED: + raise HTTPException( + status_code=400, + detail="Recommendation status:" + task.status.value, + ) + agg_recs = recommendation_repository.get_aggregated_solutions(task.id) + return apischema.GetAggregatedRecommendationResponse( + items=[ + apischema.GetAggregatedRecommendationResponseItem( + solution=rec.solution, + findings=[db_finding_to_response_item(x) for x in rec.findings], + metadata=rec.meta, + ) + for rec in agg_recs + ] + ) diff --git a/src/routes/v1/upload.py b/src/routes/v1/upload.py index 0aa02dd..02aedd9 100644 --- a/src/routes/v1/upload.py +++ b/src/routes/v1/upload.py @@ -33,7 +33,7 @@ async def upload( content_list = get_content_list(data.data) if data.filter: content_list = filter_findings(content_list, data.filter) - + today = datetime.datetime.now().date() existing_task = task_repository.get_task_by_date(today) if existing_task and not data.force_update: @@ -59,7 +59,13 @@ async def upload( finding_repository.create_findings(findings) celery_result = worker.send_task( - "worker.generate_report", args=[recommendation_task.id] + "worker.generate_report", + args=[ + recommendation_task.id, + data.preferences.long_description, + data.preferences.search_terms, + data.preferences.aggregated_solutions, + ], ) # update the task with the celery task id diff --git a/src/worker/worker.py b/src/worker/worker.py index f2f4178..9994406 100644 --- a/src/worker/worker.py +++ b/src/worker/worker.py @@ -1,16 +1,31 @@ import logging from celery import Celery +from celery.signals import worker_init import db.models as db_models -from db.my_db import Session +from db.my_db import engine, sessionmaker +from repository.finding import FindingRepository +from repository.recommendation import RecommendationRepository +from repository.task import TaskRepository +from repository.types import ( + AggregatedSolutionInput, + CreateAggregatedRecommendationInput, +) + + +from ai.LLM.LLMServiceStrategy import LLMServiceStrategy +from ai.LLM.Strategies.OLLAMAService import OLLAMAService +from data.VulnerabilityReport import create_from_flama_json +from ai.Grouping.FindingGrouper import FindingGrouper + logger = logging.getLogger(__name__) from config import config - +Session = sessionmaker(engine) redis_url = config.redis_endpoint print(f"Redis URL: {redis_url}") @@ -25,11 +40,17 @@ def error(self, exc, task_id, args, kwargs, einfo): @worker.task(name="worker.generate_report", on_failure=error) -def generate_report(recommendation_task_id: int, stragegy: str = "OLLAMA"): +def generate_report( + recommendation_task_id: int, + generate_long_solution: bool = True, + generate_search_terms: bool = True, + generate_aggregate_solutions: bool = True, +): + # importing here so importing worker does not import all the dependencies - from 
ai.LLM.LLMServiceStrategy import LLMServiceStrategy - from ai.LLM.Strategies.OLLAMAService import OLLAMAService - from data.VulnerabilityReport import create_from_flama_json + # from ai.LLM.LLMServiceStrategy import LLMServiceStrategy + # from ai.LLM.Strategies.OLLAMAService import OLLAMAService + # from data.VulnerabilityReport import create_from_flama_json ollama_strategy = OLLAMAService() llm_service = LLMServiceStrategy(ollama_strategy) @@ -37,71 +58,91 @@ def generate_report(recommendation_task_id: int, stragegy: str = "OLLAMA"): logger.info(f"Processing recommendation task with id {recommendation_task_id}") logger.info(f"Processing recommendation task with limit {limit}") logger.info( - f"Processing recommendation task with model_name {ollama_strategy.model_name}" + f"long: {generate_long_solution}, search: {generate_search_terms}, aggregate: {generate_aggregate_solutions}" ) - with Session() as session: - query = ( - session.query(db_models.Finding) - .join(db_models.RecommendationTask) - .filter(db_models.RecommendationTask.id == recommendation_task_id) - ) - if limit > 0: - query = query.limit(limit) - - findings_from_db = query.all() - logger.info(f"Found {len(findings_from_db)} findings for recommendation task") - if not findings_from_db: - logger.warn( - f"No findings found for recommendation task {recommendation_task_id}" + + recommendationTask = None + try: + with Session() as session: + + taskRepo = TaskRepository(session) + recommendationTask = taskRepo.get_task_by_id(recommendation_task_id) + + if not recommendationTask: + logger.error( + f"Recommendation task with id {recommendation_task_id} not found" + ) + return + find_repo = FindingRepository(session) + findings_from_db = find_repo.get_all_findings_by_task_id_for_processing( + recommendation_task_id, limit + ) + + logger.info( + f"Found {len(findings_from_db)} findings for recommendation task" ) - return + if not findings_from_db: + logger.warn( + f"No findings found for recommendation task {recommendation_task_id}" + ) + return + + findings = [f.raw_data for f in findings_from_db] + finding_ids = [f.id for f in findings_from_db] + + except Exception as e: + logger.error(f"Error processing recommendation task: {e}") + return - findings = [f.raw_data for f in findings_from_db] - finding_ids = [f.id for f in findings_from_db] vulnerability_report = create_from_flama_json( - findings, n=limit, llm_service=llm_service + findings, + n=limit, + llm_service=llm_service, + shuffle_data=False, # set it true may cause zip to work incorrectly ) + + # map findings to db ids, used for aggregated solutions + finding_id_map = { + finding.id: db_id + for db_id, finding in zip(finding_ids, vulnerability_report.findings) + } + vulnerability_report.add_category() - vulnerability_report.add_solution() + vulnerability_report.add_solution( + search_term=generate_search_terms, long=generate_long_solution + ) - with Session() as session: - for finding_id, f in zip(finding_ids, vulnerability_report.findings): - finding = ( - session.query(db_models.Finding) - .filter(db_models.Finding.id == finding_id) - .first() - ) - if finding is None: - print(f"Finding with id {finding_id} not found") - continue - recommendation = db_models.Recommendation( - description_short=( - f.solution.short_description - if f.solution.short_description - else "No short description" - ), - description_long=( - f.solution.long_description - if f.solution.long_description - else "No long description" - ), - meta=f.solution.metadata if f.solution.metadata else {}, - 
search_terms=f.solution.search_terms if f.solution.search_terms else [], - finding_id=finding_id, - recommendation_task_id=recommendation_task_id, - # TODO: fix category changes - category=( - f.category.affected_component.value - if f.category and f.category.affected_component - else None - ), + if generate_aggregate_solutions: + # after all the findings are processed, we need to group them + + findingGrouper = FindingGrouper(vulnerability_report, ollama_strategy) + findingGrouper.generate_aggregated_solutions() + # map findings to db ids for aggregated solutions + aggregated_solutions_with_db_ids = [ + AggregatedSolutionInput( + findings_db_ids=[ + finding_id_map[finding.id] for finding in solution.findings + ], + solution=solution, ) - session.add(recommendation) - ## updat recommendation task status - recommendation_task = ( - session.query(db_models.RecommendationTask) - .filter(db_models.RecommendationTask.id == recommendation_task_id) - .first() + for solution in vulnerability_report.aggregated_solutions + ] + with Session() as session: + recommendationRepo = RecommendationRepository(session) + recommendationRepo.create_recommendations( + list(zip(finding_ids, vulnerability_report.findings)), + recommendation_task_id, + ) + + if generate_aggregate_solutions: + # save aggregated solutions + recommendationRepo = RecommendationRepository(session) + recommendationRepo.create_aggregated_solutions( + input=CreateAggregatedRecommendationInput( + aggregated_solutions=aggregated_solutions_with_db_ids, + recommendation_task_id=recommendation_task_id, + ) ) - recommendation_task.status = db_models.TaskStatus.COMPLETED - session.commit() + # finally update the task status + recommendationRepo = TaskRepository(session) + recommendationRepo.update_task_completed(recommendation_task_id)
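For reference, once a recommendation task for the current day has completed, the new aggregated-solutions route added above can be exercised with a plain HTTP call. This is a minimal sketch, assuming the API is reachable at `http://localhost:8000` as in the README; an empty JSON body is valid because `GetAggregatedRecommendationRequest.filter` is optional, in which case the handler resolves today's task.

```bash
# Fetch aggregated solutions for today's completed recommendation task.
# The request model also defines an optional {"filter": {"task_id": ...}} field
# (see apischema.GetAggregatedRecommendationRequest).
curl -X POST http://localhost:8000/api/v1/recommendations/aggregated \
  -H "Content-Type: application/json" \
  -d '{}'
```

Each returned item carries the aggregated `solution` text, the associated `findings`, and its `metadata`. The upload route's request body likewise gains an optional `preferences` object (`long_description`, `search_terms`, `aggregated_solutions`, all defaulting to `true`), which is forwarded as extra arguments to the Celery `worker.generate_report` task.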