diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 0000000..2e12d55
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,36 @@
+name: API Docs
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+jobs:
+ generate-docs:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v3
+
+ - name: Set up Python 3.11
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+
+ - name: Install dependencies
+ run: |
+ pip install pipenv
+ pipenv install
+
+ - name: Generate docs
+ run: |
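+          # extract-docs.py writes to ./.docs by default; pass -o/--output to override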
+ pipenv run python src/extract-docs.py
+
+ - name: Save docs as artifact
+        uses: actions/upload-artifact@v3
+ with:
+ name: docs
+ path: .docs
+ if-no-files-found: error
diff --git a/README.md b/README.md
index 056c9ba..daec723 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ After starting the application, you can access the API at `http://localhost:8000
### Docker
building the images
+
```bash
docker compose build
```
@@ -64,23 +65,23 @@ To run the code within Docker, use the following command in the root directory o
docker compose up
```
-
-
Add the `-d` flag to run the containers in the background: `docker compose up -d`.
-
-
-
### Available Routes
Currently, these routes are generated by fastapi.
```
+HEAD, GET /openapi.json
+HEAD, GET /docs
+HEAD, GET /docs/oauth2-redirect
+HEAD, GET /redoc
GET /api/v1/tasks/
DELETE /api/v1/tasks/{task_id}
DELETE /api/v1/tasks/
GET /api/v1/tasks/{task_id}/status
POST /api/v1/recommendations/
+POST /api/v1/recommendations/aggregated
POST /api/v1/upload/
GET /
```
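+
+For example, the new aggregated route can be queried with `curl` (the
+`task_id` value below is illustrative):
+
+```bash
+curl -X POST http://localhost:8000/api/v1/recommendations/aggregated \
+  -H "Content-Type: application/json" \
+  -d '{"filter": {"task_id": 1}}'
+```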
diff --git a/src/ai/Grouping/FindingGrouper.py b/src/ai/Grouping/FindingGrouper.py
index 8a73c33..7e7ac09 100644
--- a/src/ai/Grouping/FindingGrouper.py
+++ b/src/ai/Grouping/FindingGrouper.py
@@ -9,7 +9,9 @@
class FindingGrouper:
- def __init__(self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService):
+ def __init__(
+ self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService
+ ):
self.vulnerability_report = vulnerability_report
self.llm_service = llm_service
self.batcher = FindingBatcher(llm_service)
@@ -20,5 +22,7 @@ def generate_aggregated_solutions(self):
for batch in tqdm(self.batches, desc="Generating Aggregated Solutions"):
result_list = self.llm_service.generate_aggregated_solution(batch)
for result in result_list:
- self.aggregated_solutions.append(AggregatedSolution(result[1], result[0], result[2])) # Solution, Findings, Metadata
+ self.aggregated_solutions.append(
+ AggregatedSolution().from_result(result[1], result[0], result[2])
+                )  # result tuple is (solution, findings, metadata)
self.vulnerability_report.set_aggregated_solutions(self.aggregated_solutions)
diff --git a/src/app.py b/src/app.py
index 79b9da3..c472fac 100644
--- a/src/app.py
+++ b/src/app.py
@@ -7,6 +7,7 @@
import ai.LLM.Strategies.OLLAMAService
from config import config
+
import routes
import routes.v1.recommendations
import routes.v1.task
diff --git a/src/config.py b/src/config.py
index 80e5868..0c4129b 100644
--- a/src/config.py
+++ b/src/config.py
@@ -1,58 +1,33 @@
from pydantic import (
- BaseModel,
Field,
- RedisDsn,
ValidationInfo,
- model_validator,
- root_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
-from pydantic import Field, RedisDsn, field_validator
+from pydantic import Field, field_validator
from typing import Optional
-class SubModel(BaseModel):
- foo: str = "bar"
- apple: int = 1
-
-
class Config(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
- ollama_url: str = Field(
- json_schema_extra="OLLAMA_URL", default="http://localhost:11434"
- )
- ollama_model: str = Field(json_schema_extra="OLLAMA_MODEL", default="phi3:mini")
+ ollama_url: str = Field(default="http://localhost:11434")
+ ollama_model: str = Field(default="phi3:mini")
- ai_strategy: Optional[str] = Field(
- json_schema_extra="AI_STRATEGY", default="OLLAMA"
- )
- anthropic_api_key: Optional[str] = Field(
- json_schema_extra="ANTHROPIC_API_KEY", default=None
- )
- openai_api_key: Optional[str] = Field(
- json_schema_extra="OPENAI_API_KEY", default=None
- )
+ ai_strategy: Optional[str] = Field(default="OLLAMA")
+ anthropic_api_key: Optional[str] = Field(default=None)
+ openai_api_key: Optional[str] = Field(default=None)
- postgres_server: str = Field(
- json_schema_extra="POSTGRES_SERVER", default="localhost"
- )
- postgres_port: int = Field(json_schema_extra="POSTGRES_PORT", default=5432)
- postgres_db: str = Field(json_schema_extra="POSTGRES_DB", default="app")
- postgres_user: str = Field(json_schema_extra="POSTGRES_USER", default="postgres")
- postgres_password: str = Field(
- json_schema_extra="POSTGRES_PASSWORD", default="postgres"
- )
- queue_processing_limit: int = Field(
- json_schema_extra="QUEUE_PROCESSING_LIMIT", default=10
- )
- redis_endpoint: str = Field(
- json_schema_extra="REDIS_ENDPOINT", default="redis://localhost:6379/0"
- )
- environment: str = Field(env="ENVIRONMENT", default="development")
- db_debug: bool = Field(env="DB_DEBUG", default=False)
+ postgres_server: str = Field(default="localhost")
+ postgres_port: int = Field(default=5432)
+ postgres_db: str = Field(default="app")
+ postgres_user: str = Field(default="postgres")
+ postgres_password: str = Field(default="postgres")
+ queue_processing_limit: int = Field(default=10)
+ redis_endpoint: str = Field(default="redis://localhost:6379/0")
+ environment: str = Field(default="development")
+ db_debug: bool = Field(default=False)
@field_validator(
"ai_strategy",
@@ -66,7 +41,6 @@ def check_ai_strategy(cls, ai_strategy, values):
@field_validator("openai_api_key")
def check_api_key(cls, api_key, info: ValidationInfo):
- print(info)
if info.data["ai_strategy"] == "OPENAI" and not api_key:
raise ValueError("OPENAI_API_KEY is required when ai_strategy is OPENAI")
return api_key
diff --git a/src/data/AggregatedSolution.py b/src/data/AggregatedSolution.py
index 52e675c..64473f2 100644
--- a/src/data/AggregatedSolution.py
+++ b/src/data/AggregatedSolution.py
@@ -1,18 +1,19 @@
from typing import List
from data.Finding import Finding
-from db.base import BaseModel
+from pydantic import BaseModel
-class AggregatedSolution:
+class AggregatedSolution(BaseModel):
findings: List[Finding] = None
solution: str = ""
metadata: dict = {}
- def __init__(self, findings: List[Finding], solution: str, metadata=None):
+ def from_result(self, findings: List[Finding], solution: str, metadata=None):
self.findings = findings
self.solution = solution
self.metadata = metadata
+ return self
def __str__(self):
return self.solution
@@ -21,8 +22,8 @@ def to_dict(self):
return {
"findings": [finding.to_dict() for finding in self.findings],
"solution": self.solution,
- "metadata": self.metadata
+ "metadata": self.metadata,
}
def to_html(self):
- return f"
{self.solution}
"
\ No newline at end of file
+ return f"{self.solution}
"
diff --git a/src/data/Finding.py b/src/data/Finding.py
index c41fea3..5a4a73c 100644
--- a/src/data/Finding.py
+++ b/src/data/Finding.py
@@ -1,9 +1,19 @@
from typing import List, Set, Optional, Any, get_args
from enum import Enum, auto
+import uuid
+
from pydantic import BaseModel, Field, PrivateAttr
from data.Solution import Solution
-from data.Categories import Category, TechnologyStack, SecurityAspect, SeverityLevel, RemediationType, \
- AffectedComponent, Compliance, Environment
+from data.Categories import (
+ Category,
+ TechnologyStack,
+ SecurityAspect,
+ SeverityLevel,
+ RemediationType,
+ AffectedComponent,
+ Compliance,
+ Environment,
+)
import json
import logging
@@ -12,6 +22,7 @@
class Finding(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
title: List[str] = Field(default_factory=list)
source: Set[str] = Field(default_factory=set)
descriptions: List[str] = Field(default_factory=list)
@@ -21,7 +32,7 @@ class Finding(BaseModel):
severity: Optional[int] = None
priority: Optional[int] = None
location_list: List[str] = Field(default_factory=list)
- category: Category = None
+ category: Optional[Category] = None
unsupervised_cluster: Optional[int] = None
solution: Optional["Solution"] = None
_llm_service: Optional[Any] = PrivateAttr(default=None)
@@ -35,7 +46,9 @@ def combine_descriptions(self) -> "Finding":
logger.error("LLM Service not set, cannot combine descriptions.")
return self
- self.description = self.llm_service.combine_descriptions(self.descriptions, self.cve_ids, self.cwe_ids)
+ self.description = self.llm_service.combine_descriptions(
+ self.descriptions, self.cve_ids, self.cwe_ids
+ )
return self
def add_category(self) -> "Finding":
@@ -47,34 +60,45 @@ def add_category(self) -> "Finding":
# Classify technology stack
technology_stack_options = list(TechnologyStack)
- self.category.technology_stack = self.llm_service.classify_kind(self, "technology_stack",
- technology_stack_options)
+ self.category.technology_stack = self.llm_service.classify_kind(
+ self, "technology_stack", technology_stack_options
+ )
# Classify security aspect
security_aspect_options = list(SecurityAspect)
- self.category.security_aspect = self.llm_service.classify_kind(self, "security_aspect", security_aspect_options)
+ self.category.security_aspect = self.llm_service.classify_kind(
+ self, "security_aspect", security_aspect_options
+ )
# Classify severity level
severity_level_options = list(SeverityLevel)
- self.category.severity_level = self.llm_service.classify_kind(self, "severity_level", severity_level_options)
+ self.category.severity_level = self.llm_service.classify_kind(
+ self, "severity_level", severity_level_options
+ )
# Classify remediation type
remediation_type_options = list(RemediationType)
- self.category.remediation_type = self.llm_service.classify_kind(self, "remediation_type",
- remediation_type_options)
+ self.category.remediation_type = self.llm_service.classify_kind(
+ self, "remediation_type", remediation_type_options
+ )
# Classify affected component
affected_component_options = list(AffectedComponent)
- self.category.affected_component = self.llm_service.classify_kind(self, "affected_component",
- affected_component_options)
+ self.category.affected_component = self.llm_service.classify_kind(
+ self, "affected_component", affected_component_options
+ )
# Classify compliance
compliance_options = list(Compliance)
- self.category.compliance = self.llm_service.classify_kind(self, "compliance", compliance_options)
+ self.category.compliance = self.llm_service.classify_kind(
+ self, "compliance", compliance_options
+ )
# Classify environment
environment_options = list(Environment)
- self.category.environment = self.llm_service.classify_kind(self, "environment", environment_options)
+ self.category.environment = self.llm_service.classify_kind(
+ self, "environment", environment_options
+ )
return self
@@ -213,9 +237,7 @@ def to_html(self, table=False):
result += "Name | Value |
"
result += f"Title | {', '.join(self.title)} |
"
result += f"Source | {', '.join(self.source)} |
"
- result += (
- f"Description | {self.description} |
"
- )
+ result += f"Description | {self.description} |
"
if len(self.location_list) > 0:
result += f"Location List | {' & '.join(map(str, self.location_list))} |
"
result += f"CWE IDs | {', '.join(self.cwe_ids)} |
"
@@ -223,7 +245,11 @@ def to_html(self, table=False):
result += f"Severity | {self.severity} |
"
result += f"Priority | {self.priority} |
"
if self.category is not None:
- result += 'Category | ' + str(self.category).replace("\n", " ") + ' |
'
+ result += (
+ "Category | "
+ + str(self.category).replace("\n", " ")
+ + " |
"
+ )
if self.unsupervised_cluster is not None:
result += f"Unsupervised Cluster | {self.unsupervised_cluster} |
"
result += ""
diff --git a/src/data/VulnerabilityReport.py b/src/data/VulnerabilityReport.py
index 3ec9b87..21a7346 100644
--- a/src/data/VulnerabilityReport.py
+++ b/src/data/VulnerabilityReport.py
@@ -7,7 +7,6 @@
from data.AggregatedSolution import AggregatedSolution
from data.Finding import Finding
from ai.LLM.LLMServiceStrategy import LLMServiceStrategy
-from ai.Clustering.AgglomerativeClusterer import AgglomerativeClusterer
import logging
@@ -51,9 +50,15 @@ def add_unsupervised_category(self, use_solution=True):
This function adds a category to each finding in the report.
:return: The VulnerabilityReport object and the AgglomerativeClustering object (in case you wanna have a look at the graph)
"""
+ from ai.Clustering.AgglomerativeClusterer import AgglomerativeClusterer
+
clustering = AgglomerativeClusterer(self)
- if use_solution and not all([finding.solution is not None for finding in self.findings]):
- logger.warning("Not all findings have a solution, falling back to using the description.")
+ if use_solution and not all(
+ [finding.solution is not None for finding in self.findings]
+ ):
+ logger.warning(
+ "Not all findings have a solution, falling back to using the description."
+ )
use_solution = False
clustering.add_unsupervised_category(use_solution=use_solution)
@@ -98,21 +103,27 @@ def sort(self, by: str = "severity", reverse: bool = True):
def to_dict(self):
findings = [f.to_dict() for f in self.findings]
if len(self.get_aggregated_solutions()) > 0:
- aggregated_solutions = [f.to_dict() for f in self.get_aggregated_solutions()]
+ aggregated_solutions = [
+ f.to_dict() for f in self.get_aggregated_solutions()
+ ]
return {"findings": findings, "aggregated_solutions": aggregated_solutions}
return {"findings": findings}
def __str__(self):
findings_str = "\n".join([str(f) for f in self.findings])
if len(self.get_aggregated_solutions()) > 0:
- aggregated_solutions_str = "\n".join([str(f) for f in self.get_aggregated_solutions()])
+ aggregated_solutions_str = "\n".join(
+ [str(f) for f in self.get_aggregated_solutions()]
+ )
return findings_str + "\n\n" + aggregated_solutions_str
return findings_str
def to_html(self, table=False):
my_str = "
".join([f.to_html(table) for f in self.findings])
if len(self.get_aggregated_solutions()) > 0:
- my_str += "
" + "
".join([f.to_html() for f in self.get_aggregated_solutions()])
+ my_str += "
" + "
".join(
+ [f.to_html() for f in self.get_aggregated_solutions()]
+ )
return my_str
def export_to_json(self, filename="VulnerabilityReport.json"):
@@ -142,7 +153,11 @@ def import_from_json(filename="VulnerabilityReport.json"):
def create_from_flama_json(
- json_data, n=-1, llm_service: "LLMServiceStrategy" = None, shuffle_data=False, combine_descriptions=True
+ json_data,
+ n=-1,
+ llm_service: "LLMServiceStrategy" = None,
+ shuffle_data=False,
+ combine_descriptions=True,
) -> VulnerabilityReport:
"""
This function creates a VulnerabilityReport object from a JSON object.
@@ -167,7 +182,9 @@ def create_from_flama_json(
# shuffle json_data to get a random sample
shuffle(json_data)
- for d in conditional_tqdm(json_data[:n], combine_descriptions, desc="Combining descriptions"):
+ for d in conditional_tqdm(
+ json_data[:n], combine_descriptions, desc="Combining descriptions"
+ ):
title = [x["element"] for x in d["title_list"]]
source = set([x["source"] for x in d["title_list"]])
description = [x["element"] for x in d.get("description_list", [])]
diff --git a/src/data/apischema.py b/src/data/apischema.py
index 90acc85..bfaaf25 100644
--- a/src/data/apischema.py
+++ b/src/data/apischema.py
@@ -1,7 +1,8 @@
from typing import Any, Dict, List, Literal, Optional
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, Field
+from data.AggregatedSolution import AggregatedSolution
from data.Categories import Category
from data.Finding import Finding
from data.pagination import Pagination, PaginationInput
@@ -13,16 +14,26 @@
class SeverityFilter(BaseModel):
minValue: int
maxValue: int
+
+
class FindingInputFilter(BaseModel):
- severity: Optional[SeverityFilter] = None # ['low', 'high']
- priority: Optional[SeverityFilter] = None # ['low', 'high']
+    severity: Optional[SeverityFilter] = None  # inclusive [minValue, maxValue]
+    priority: Optional[SeverityFilter] = None  # inclusive [minValue, maxValue]
cve_ids: Optional[List[str]] = None
cwe_ids: Optional[List[str]] = None
source: Optional[List[str]] = None
-
+
+
+class UploadPreference(BaseModel):
+ long_description: Optional[bool] = Field(default=True)
+ search_terms: Optional[bool] = Field(default=True)
+ aggregated_solutions: Optional[bool] = Field(default=True)
+
+
class StartRecommendationTaskRequest(BaseModel):
- user_id: Optional[int] = None
- strategy: Optional[Literal["OLLAMA", "ANTHROPIC", "OPENAI"]] = "OLLAMA"
+ preferences: Optional[UploadPreference] = UploadPreference(
+ long_description=True, search_terms=True, aggregated_solutions=True
+ )
data: InputData
force_update: Optional[bool] = False
filter: Optional[FindingInputFilter] = None
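+
+
+# Example StartRecommendationTaskRequest body (illustrative values; "data"
+# carries the uploaded findings payload, see InputData in data/types.py):
+# {
+#     "preferences": {
+#         "long_description": true,
+#         "search_terms": false,
+#         "aggregated_solutions": true
+#     },
+#     "data": {"status": "...", "message": {...}},
+#     "force_update": false
+# }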
@@ -36,7 +47,7 @@ class GetRecommendationFilter(BaseModel):
task_id: Optional[int] = None
date: Optional[str] = None
location: Optional[str] = None
- severity: Optional[SeverityFilter] = None # ['low', 'high']
+    severity: Optional[SeverityFilter] = None  # inclusive [minValue, maxValue]
cve_id: Optional[str] = None
source: Optional[str] = None
@@ -47,14 +58,7 @@ class GetRecommendationRequest(BaseModel):
pagination: Optional[PaginationInput] = PaginationInput(offset=0, limit=10)
-class SolutionItem(BaseModel):
- short_description: str
- long_description: str
- metadata: dict
- search_terms: str
-
-
-class GetRecommendationResponseItem(Finding): # TODO adapt needed fields
+class GetRecommendationResponseItem(Finding):
pass
class GetRecommendationResponseItems(BaseModel):
@@ -70,11 +74,17 @@ class GetRecommendationTaskStatusResponse(BaseModel):
status: TaskStatus
-class GetSummarizedRecommendationRequest(BaseModel):
- user_id: Optional[int]
- pagination: PaginationInput = PaginationInput(offset=0, limit=10)
+class GetAggregatedRecommendationFilter(BaseModel):
+ task_id: Optional[int] = None
+
+
+class GetAggregatedRecommendationRequest(BaseModel):
+ filter: Optional[GetAggregatedRecommendationFilter] = None
+
+
+class GetAggregatedRecommendationResponseItem(AggregatedSolution):
+ pass
-# class GetSummarizedRecommendationResponse(BaseModel):
-# recommendation: list[Recommendation]
-# pagination: Pagination
+class GetAggregatedRecommendationResponse(BaseModel):
+ items: List[GetAggregatedRecommendationResponseItem]
diff --git a/src/data/helper.py b/src/data/helper.py
index b7e4c35..bc2cba0 100644
--- a/src/data/helper.py
+++ b/src/data/helper.py
@@ -3,41 +3,41 @@
from data.apischema import FindingInputFilter
from data.types import Content, InputData
-# ##TODO:maybe using pydantic should be enough
-# def validate_json(data: any) -> bool:
-# try:
-# json_data = data
-# try:
-# validate(instance=json_data, schema=schema)
-# print("JSON data adheres to the schema.")
-# except jsonschema.exceptions.ValidationError as e:
-# print("JSON data does not adhere to the schema.")
-# print(e)
-# except ValueError as e:
-# return False
-
-# return True
-
def get_content_list(json_data: InputData) -> list[Content]:
return json_data.message.content
-def filter_findings(findings: List[Content], filter: FindingInputFilter) -> List[Content]:
+
+def filter_findings(
+ findings: List[Content], filter: FindingInputFilter
+) -> List[Content]:
def matches(content: Content) -> bool:
- if filter.source and not any(title.source in filter.source for title in content.title_list):
+ if filter.source and not any(
+ title.source in filter.source for title in content.title_list
+ ):
return False
- if filter.severity and not (filter.severity.minValue <= content.severity <= filter.severity.maxValue):
+ if filter.severity and not (
+ filter.severity.minValue <= content.severity <= filter.severity.maxValue
+ ):
return False
- if filter.priority and not (filter.priority.minValue <= content.priority <= filter.priority.maxValue):
+ if filter.priority and not (
+ filter.priority.minValue <= content.priority <= filter.priority.maxValue
+ ):
return False
- if filter.cve_ids and not any(cve.element in filter.cve_ids for cve in content.cve_id_list):
+ if filter.cve_ids and not any(
+ cve.element in filter.cve_ids for cve in content.cve_id_list
+ ):
return False
- if filter.cwe_ids and not any(element in filter.cwe_ids for cwe in content.cwe_id_list or [] for element in cwe.element):
+ if filter.cwe_ids and not any(
+ element in filter.cwe_ids
+ for cwe in content.cwe_id_list or []
+ for element in cwe.element
+ ):
return False
return True
- return [content for content in findings if matches(content)]
\ No newline at end of file
+ return [content for content in findings if matches(content)]
diff --git a/src/data/types.py b/src/data/types.py
index f0e3f75..62582c3 100644
--- a/src/data/types.py
+++ b/src/data/types.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Any, Union,Optional
+from typing import List, Dict, Any, Union, Optional
from typing import Any, Optional, Tuple
from typing_extensions import Annotated
@@ -6,8 +6,7 @@
from pydantic import BaseModel
-
-
+# deprecated: kept for reference; validation now relies on the pydantic models below
schema = {
"type": "object",
"properties": {
@@ -25,17 +24,47 @@
"doc_type": {"type": "string"},
"criticality_tag": {"type": ["array", "object"]},
"knowledge_type": {"type": "string"},
- "requirement_list": {"type": "array", "items": {"type": "string"}},
- "title_list": {"type": "array", "items": {"type": "object"}},
- "location_list": {"type": "array", "items": {"type": "object"}},
- "description_list": {"type": "array", "items": {"type": "object"}},
- "internal_rating_list": {"type": "array", "items": {"type": "object"}},
- "internal_ratingsource_list": {"type": "array", "items": {"type": "object"}},
- "cvss_rating_list": {"type": "array", "items": {"type": "object"}},
+ "requirement_list": {
+ "type": "array",
+ "items": {"type": "string"},
+ },
+ "title_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
+ "location_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
+ "description_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
+ "internal_rating_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
+ "internal_ratingsource_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
+ "cvss_rating_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
"rule_list": {"type": "array", "items": {"type": "object"}},
- "cwe_id_list": {"type": "array", "items": {"type": "object"}},
- "cve_id_list": {"type": "array", "items": {"type": "object"}},
- "activity_list": {"type": "array", "items": {"type": "object"}},
+ "cwe_id_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
+ "cve_id_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
+ "activity_list": {
+ "type": "array",
+ "items": {"type": "object"},
+ },
"first_found": {"type": "string"},
"last_found": {"type": "string"},
"report_amount": {"type": "integer"},
@@ -46,25 +75,45 @@
"priority_explanation": {"type": "string"},
"sum_id": {"type": "string"},
"prio_id": {"type": "string"},
- "element_tag": {"type": "string"}
+ "element_tag": {"type": "string"},
},
- "required": ["doc_type", "criticality_tag", "knowledge_type", "requirement_list", "title_list",
- "location_list", "description_list", "internal_rating_list",
- "internal_ratingsource_list", "cvss_rating_list", "rule_list",
- "cwe_id_list", "cve_id_list", "activity_list", "first_found", "last_found",
- "report_amount", "content_hash", "severity", "severity_explanation",
- "priority", "priority_explanation", "sum_id", "prio_id", "element_tag"]
- }
- }
+ "required": [
+ "doc_type",
+ "criticality_tag",
+ "knowledge_type",
+ "requirement_list",
+ "title_list",
+ "location_list",
+ "description_list",
+ "internal_rating_list",
+ "internal_ratingsource_list",
+ "cvss_rating_list",
+ "rule_list",
+ "cwe_id_list",
+ "cve_id_list",
+ "activity_list",
+ "first_found",
+ "last_found",
+ "report_amount",
+ "content_hash",
+ "severity",
+ "severity_explanation",
+ "priority",
+ "priority_explanation",
+ "sum_id",
+ "prio_id",
+ "element_tag",
+ ],
+ },
+ },
},
- "required": ["version", "utc_age", "content"]
- }
+ "required": ["version", "utc_age", "content"],
+ },
},
- "required": ["status", "message"]
+ "required": ["status", "message"],
}
-
class Tag(BaseModel):
action: str
user_mail: str
@@ -73,8 +122,6 @@ class Tag(BaseModel):
valid_until: str
-
-
class Location(BaseModel):
location: str
amount: int
@@ -84,53 +131,46 @@ class Location(BaseModel):
tags: List[Tag]
-
class Title(BaseModel):
element: str
source: str
-
class Description(BaseModel):
element: str
source: str
-
class Rating(BaseModel):
element: str
source: str
-
class CvssRating(BaseModel):
element: str
source: str
-
class Rule(BaseModel):
element: str
source: str
-
class CveId(BaseModel):
element: str
source: str
+
class CweId(BaseModel):
element: List[str]
source: str
-
class Activity(BaseModel):
element: str
source: str
-
class Content(BaseModel):
doc_type: str
criticality_tag: Union[List[Union[str, int]], Dict[str, Any]]
@@ -138,13 +178,13 @@ class Content(BaseModel):
requirement_list: List[str]
title_list: List[Title]
location_list: List[Location]
- description_list: Optional[List[Description]] = []
- internal_rating_list: Optional[ List[Rating]] = []
+ description_list: Optional[List[Description]] = []
+ internal_rating_list: Optional[List[Rating]] = []
internal_ratingsource_list: List[Rating]
- cvss_rating_list: Optional[ List[CvssRating]] = []
+ cvss_rating_list: Optional[List[CvssRating]] = []
rule_list: List[Rule]
- cve_id_list: Optional[ List[CveId]] = []
- cwe_id_list: Optional[ List[CweId]] = []
+ cve_id_list: Optional[List[CveId]] = []
+ cwe_id_list: Optional[List[CweId]] = []
activity_list: List[Activity]
first_found: str
last_found: str
@@ -159,24 +199,21 @@ class Content(BaseModel):
element_tag: str
-
class Message(BaseModel):
version: str
utc_age: str
content: List[Content]
-
class InputData(BaseModel):
status: str
message: Message
-
class Finding(BaseModel):
content: Content
solutions: List[Dict[str, Any]] = []
-
+
class Recommendation(BaseModel):
recommendation: str
diff --git a/src/db/models.py b/src/db/models.py
index 62aec28..f604a2e 100644
--- a/src/db/models.py
+++ b/src/db/models.py
@@ -1,7 +1,7 @@
from enum import Enum as PyEnum
from typing import List, Optional
-from sqlalchemy import JSON, Enum, ForeignKey, Integer, String
+from sqlalchemy import JSON, Enum, ForeignKey, Integer, String, Table
from sqlalchemy.orm import Mapped, relationship
from data.types import Content
@@ -9,13 +9,29 @@
from .base import BaseModel, Column
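+
+# Association table backing the many-to-many relationship between findings
+# and aggregated recommendations (ON DELETE CASCADE on both foreign keys).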
+findings_aggregated_association_table = Table(
+ "findings_aggregated_association",
+ BaseModel.metadata,
+ Column(
+ "finding_id",
+ ForeignKey("findings.id", ondelete="CASCADE"),
+ primary_key=True,
+ ),
+ Column(
+ "aggregated_recommendation_id",
+ ForeignKey("aggregated_recommendations.id", ondelete="CASCADE"),
+ primary_key=True,
+ ),
+)
+
+
class Recommendation(BaseModel):
__tablename__ = "recommendations"
description_short = Column(String, nullable=True)
description_long = Column(String, nullable=True)
search_terms = Column(String, nullable=True)
meta = Column(JSON, default={}, nullable=True)
- category = Column(String, nullable=True)
+ category = Column(JSON, nullable=True)
finding_id: int = Column(Integer, ForeignKey("findings.id"), nullable=True)
recommendation_task_id = Column(
Integer,
@@ -27,7 +43,25 @@ class Recommendation(BaseModel):
)
def __repr__(self):
- return f""
+ return f""
+
+
+class AggregatedRecommendation(BaseModel):
+ __tablename__ = "aggregated_recommendations"
+ solution = Column(String, nullable=True)
+ meta = Column(JSON, default={}, nullable=True)
+ recommendation_task_id = Column(
+ Integer,
+ ForeignKey("recommendation_task.id", ondelete="CASCADE"),
+ nullable=False,
+ )
+ findings: Mapped[List["Finding"]] = relationship(
+ secondary=findings_aggregated_association_table,
+ back_populates="aggregated_recommendations",
+ )
+
+ def __repr__(self):
+ return f""
class Finding(BaseModel):
@@ -37,7 +71,7 @@ class Finding(BaseModel):
title_list = Column(JSON, default=None, nullable=True)
description_list = Column(JSON, default=[], nullable=True)
locations_list = Column(JSON, default=[], nullable=True)
- category = Column(String, nullable=True)
+ category = Column(JSON, nullable=True)
cwe_id_list = Column(JSON, default=[], nullable=True)
cve_id_list = Column(JSON, default=[], nullable=True)
priority = Column(Integer, default=None, nullable=True)
@@ -62,6 +96,10 @@ class Finding(BaseModel):
"Recommendation", back_populates="finding"
)
+ aggregated_recommendations: Mapped[List["AggregatedRecommendation"]] = relationship(
+ secondary=findings_aggregated_association_table, back_populates="findings"
+ )
+
def from_data(self, data: Content):
self.cve_id_list = (
[x.dict() for x in data.cve_id_list] if data.cve_id_list else []
diff --git a/src/db/my_db.py b/src/db/my_db.py
index 6b1e554..fb524da 100644
--- a/src/db/my_db.py
+++ b/src/db/my_db.py
@@ -1,4 +1,4 @@
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, NullPool
import os
from sqlalchemy.orm import sessionmaker
@@ -7,7 +7,9 @@
from config import config
engine = create_engine(
- config.get_db_url(), echo=os.getenv("DB_DEBUG", "false") == "true"
+ config.get_db_url(),
+ echo=os.getenv("DB_DEBUG", "false") == "true",
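+    # NullPool disables pooling so each session gets a fresh connection;
+    # a plausible motive (assumption) is avoiding pooled connections being
+    # shared across forked Celery workers.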
+ poolclass=NullPool,
)
diff --git a/src/dto/finding.py b/src/dto/finding.py
index 6e9486b..d03f701 100644
--- a/src/dto/finding.py
+++ b/src/dto/finding.py
@@ -7,12 +7,15 @@
def db_finding_to_response_item(
find: DBFinding,
) -> GetRecommendationResponseItem:
+    try:
+        recommendation = find.recommendations[0]
+        category = Category.model_validate_json(recommendation.category)
+    except Exception:
+        # no recommendation yet, or the stored category JSON is unparsable
+        category = Category()
+
return GetRecommendationResponseItem(
- category=Category(
- affected_component=(
- AffectedComponent(find.category) if find.category else None
- )
- ),
+ category=category,
solution=Solution(
short_description=(
find.recommendations[0].description_short
diff --git a/src/extract-docs.py b/src/extract-docs.py
index 04df878..6fb6dfb 100644
--- a/src/extract-docs.py
+++ b/src/extract-docs.py
@@ -1,10 +1,20 @@
import app as app
from fastapi.openapi.docs import get_swagger_ui_html
import os
+import argparse
if __name__ == "__main__":
- output_path = os.path.join(os.getcwd(), ".docs")
-
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-o",
+ "--output",
+ help="Output path for documentation",
+ required=False,
+ default=os.path.join(os.getcwd(), ".docs"),
+ )
+ args = parser.parse_args()
+ output_path = args.output
+ print(f"Generating documentation in {output_path}")
if not os.path.exists(output_path):
os.makedirs(output_path)
diff --git a/src/migrations/versions/06253ed3ac28_change_category.py b/src/migrations/versions/06253ed3ac28_change_category.py
new file mode 100644
index 0000000..1389402
--- /dev/null
+++ b/src/migrations/versions/06253ed3ac28_change_category.py
@@ -0,0 +1,50 @@
+"""change_category
+
+Revision ID: 06253ed3ac28
+Revises: 9e31625c7978
+Create Date: 2024-07-26 11:07:19.507484
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+import db
+
+
+# revision identifiers, used by Alembic.
+revision = "06253ed3ac28"
+down_revision = "9e31625c7978"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+
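+    # USING category::json tells PostgreSQL how to convert each existing row;
+    # values must already be valid JSON text or the migration fails.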
+ op.execute(
+ "ALTER TABLE recommendations ALTER COLUMN category TYPE json using category::json;"
+ )
+
+    op.execute(
+        "ALTER TABLE findings ALTER COLUMN category TYPE json USING category::json;"
+    )
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.alter_column(
+ "recommendations",
+ "category",
+ existing_type=sa.JSON(),
+ type_=sa.VARCHAR(),
+ existing_nullable=True,
+ )
+ op.alter_column(
+ "findings",
+ "category",
+ existing_type=sa.JSON(),
+ type_=sa.VARCHAR(),
+ existing_nullable=True,
+ )
+ # ### end Alembic commands ###
diff --git a/src/migrations/versions/9e31625c7978_aggregated.py b/src/migrations/versions/9e31625c7978_aggregated.py
new file mode 100644
index 0000000..13fe48e
--- /dev/null
+++ b/src/migrations/versions/9e31625c7978_aggregated.py
@@ -0,0 +1,47 @@
+"""aggregated
+
+Revision ID: 9e31625c7978
+Revises: 8e2452c67d2d
+Create Date: 2024-07-26 08:56:17.779435
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+import db
+
+
+# revision identifiers, used by Alembic.
+revision = '9e31625c7978'
+down_revision = '8e2452c67d2d'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('aggregated_recommendations',
+ sa.Column('solution', sa.String(), nullable=True),
+ sa.Column('meta', sa.JSON(), nullable=True),
+ sa.Column('recommendation_task_id', sa.Integer(), nullable=False),
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
+ sa.ForeignKeyConstraint(['recommendation_task_id'], ['recommendation_task.id'], ondelete='CASCADE'),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('findings_aggregated_association',
+ sa.Column('finding_id', sa.Integer(), nullable=False),
+ sa.Column('aggregated_recommendation_id', sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(['aggregated_recommendation_id'], ['aggregated_recommendations.id'], ),
+ sa.ForeignKeyConstraint(['finding_id'], ['findings.id'], ),
+ sa.PrimaryKeyConstraint('finding_id', 'aggregated_recommendation_id')
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('findings_aggregated_association')
+ op.drop_table('aggregated_recommendations')
+ # ### end Alembic commands ###
\ No newline at end of file
diff --git a/src/migrations/versions/cae76af4e118_on_delete_aggregated.py b/src/migrations/versions/cae76af4e118_on_delete_aggregated.py
new file mode 100644
index 0000000..6479987
--- /dev/null
+++ b/src/migrations/versions/cae76af4e118_on_delete_aggregated.py
@@ -0,0 +1,36 @@
+"""on_delete_aggregated
+
+Revision ID: cae76af4e118
+Revises: 06253ed3ac28
+Create Date: 2024-07-26 19:17:05.829930
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+import db
+
+
+# revision identifiers, used by Alembic.
+revision = 'cae76af4e118'
+down_revision = '06253ed3ac28'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_constraint('findings_aggregated_association_finding_id_fkey', 'findings_aggregated_association', type_='foreignkey')
+ op.drop_constraint('findings_aggregated_associati_aggregated_recommendation_id_fkey', 'findings_aggregated_association', type_='foreignkey')
+ op.create_foreign_key(None, 'findings_aggregated_association', 'findings', ['finding_id'], ['id'], ondelete='CASCADE')
+ op.create_foreign_key(None, 'findings_aggregated_association', 'aggregated_recommendations', ['aggregated_recommendation_id'], ['id'], ondelete='CASCADE')
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_constraint(None, 'findings_aggregated_association', type_='foreignkey')
+ op.drop_constraint(None, 'findings_aggregated_association', type_='foreignkey')
+ op.create_foreign_key('findings_aggregated_associati_aggregated_recommendation_id_fkey', 'findings_aggregated_association', 'aggregated_recommendations', ['aggregated_recommendation_id'], ['id'])
+ op.create_foreign_key('findings_aggregated_association_finding_id_fkey', 'findings_aggregated_association', 'findings', ['finding_id'], ['id'])
+ # ### end Alembic commands ###
\ No newline at end of file
diff --git a/src/repository/finding.py b/src/repository/finding.py
index 88ac7f1..db2cb27 100644
--- a/src/repository/finding.py
+++ b/src/repository/finding.py
@@ -1,4 +1,5 @@
from fastapi import Depends
+
from sqlalchemy import Date, Integer, cast, func
from sqlalchemy.orm import Session
@@ -6,6 +7,7 @@
from data.apischema import SeverityFilter
from data.pagination import PaginationInput
from db.my_db import get_db
+from repository.types import GetFindingsByFilterInput
class FindingRepository:
@@ -18,6 +20,19 @@ def get_findings(self, num_findings: int) -> list[db_models.Finding]:
findings = self.session.query(db_models.Finding).limit(num_findings).all()
return findings
+ def get_all_findings_by_task_id_for_processing(self, task_id: int, limit: int = -1):
+ query = (
+ self.session.query(db_models.Finding)
+ .join(db_models.RecommendationTask)
+ .filter(db_models.RecommendationTask.id == task_id)
+ )
+
+        # a non-positive limit means "no limit"
+ if limit > 0:
+ query = query.limit(limit)
+
+ return query.all()
+
def get_findings_by_task_id(
self, task_id: int, pagination: PaginationInput
) -> list[db_models.Finding]:
@@ -34,25 +49,38 @@ def get_findings_by_task_id(
)
return findings
-
- def get_findings_by_task_id_and_severity(
- self, task_id: int, severity: SeverityFilter, pagination: PaginationInput
+
+ def get_findings_by_task_id_and_filter(
+ self, input: GetFindingsByFilterInput
) -> list[db_models.Finding]:
- findings = (
+ task_id = input.task_id
+ pagination = input.pagination
+ severity = input.severityFilter
+ query = (
self.session.query(db_models.Finding)
.join(db_models.RecommendationTask)
.where(
db_models.RecommendationTask.status == db_models.TaskStatus.COMPLETED,
(db_models.RecommendationTask.id == task_id),
- db_models.Finding.severity >= severity.minValue,
- db_models.Finding.severity <= severity.maxValue,
+                (
+                    db_models.Finding.severity >= severity.minValue
+                    if severity and severity.minValue is not None
+                    else True
+                ),
+                (
+                    db_models.Finding.severity <= severity.maxValue
+                    if severity and severity.maxValue is not None
+                    else True
+                ),
)
- .offset(pagination.offset)
- .limit(pagination.limit)
- .all()
)
- return findings
+        total = query.count()
+        if pagination:
+            query = query.offset(pagination.offset).limit(pagination.limit)
+        findings = query.all()
+
+        return findings, total
def get_findings_count_by_task_id(self, task_id: int) -> int:
count = (
@@ -74,5 +102,5 @@ def create_findings(
self.session.commit()
-def get_finding_repository(session: Depends = Depends(get_db)):
+def get_finding_repository(session: Session = Depends(get_db)):
return FindingRepository(session)
diff --git a/src/repository/recommendation.py b/src/repository/recommendation.py
index 3c3e36a..5f2147e 100644
--- a/src/repository/recommendation.py
+++ b/src/repository/recommendation.py
@@ -1,3 +1,4 @@
+from data.Finding import Finding
import db.models as db_models
from sqlalchemy.orm import Session, sessionmaker
@@ -6,6 +7,8 @@
from db.my_db import get_db
from fastapi import Depends
+from repository.types import CreateAggregatedRecommendationInput
+
class RecommendationRepository:
session: Session
@@ -33,6 +36,83 @@ def get_recommendation_by_id(
)
return recommendation
+ def create_recommendations(
+ self,
+ finding_with_solution: list[tuple[str, Finding]],
+ recommendation_task_id: int,
+ ):
+ session = self.session
+ for finding_id, f in finding_with_solution:
+ finding = (
+ session.query(db_models.Finding)
+ .filter(db_models.Finding.id == finding_id)
+ .first()
+ )
+ if finding is None:
+ print(f"Finding with id {finding_id} not found")
+ continue
+ recommendation = db_models.Recommendation(
+ description_short=(
+ f.solution.short_description
+ if f.solution.short_description
+ else "No short description available"
+ ),
+ description_long=(
+ f.solution.long_description
+ if f.solution.long_description
+ else "No long description available"
+ ),
+ meta=f.solution.metadata if f.solution.metadata else {},
+ search_terms=f.solution.search_terms if f.solution.search_terms else [],
+ finding_id=finding_id,
+ recommendation_task_id=recommendation_task_id,
+ category=(f.category.model_dump_json() if f.category else None),
+ )
+ session.add(recommendation)
+
+ session.commit()
+
+ def create_aggregated_solutions(
+ self,
+ input: CreateAggregatedRecommendationInput,
+ ):
+
+ for solution in input.aggregated_solutions:
+ aggregated_rec = db_models.AggregatedRecommendation(
+ solution=solution.solution.solution,
+ meta=solution.solution.metadata,
+ recommendation_task_id=input.recommendation_task_id,
+ )
+ self.session.add(aggregated_rec)
+ self.session.commit()
+ self.session.refresh(aggregated_rec)
+ for findings_id in solution.findings_db_ids:
+
+ res = self.session.execute(
+ db_models.findings_aggregated_association_table.insert().values(
+ finding_id=findings_id,
+ aggregated_recommendation_id=aggregated_rec.id,
+ )
+ )
+ res.close()
+
+ self.session.commit()
+
+ def get_aggregated_solutions(
+ self, recommendation_task_id: int
+ ) -> list[db_models.AggregatedRecommendation]:
+
+ aggregated_solutions = (
+ self.session.query(db_models.AggregatedRecommendation)
+ .filter(
+ db_models.AggregatedRecommendation.recommendation_task_id
+ == recommendation_task_id
+ )
+ .all()
+ )
+ return aggregated_solutions
+
def get_recommendation_repository(session: Session = Depends(get_db)):
return RecommendationRepository(session)
diff --git a/src/repository/task.py b/src/repository/task.py
index b9ef75f..85603a0 100644
--- a/src/repository/task.py
+++ b/src/repository/task.py
@@ -29,15 +29,23 @@ def update_task(self, task: db_models.RecommendationTask, celery_task_id: str):
self.session.refresh(task)
return task
+ def update_task_completed(self, task_id: int):
+ status = db_models.TaskStatus.COMPLETED
+ self.session.query(db_models.RecommendationTask).filter(
+ db_models.RecommendationTask.id == task_id
+ ).update({db_models.RecommendationTask.status: status})
+
+ self.session.commit()
+
def get_tasks(
self,
) -> list[db_models.RecommendationTask]:
-
+ print("get_tasks")
tasks = self.session.query(db_models.RecommendationTask).all()
return tasks
def get_task_by_id(self, task_id: int) -> db_models.RecommendationTask | None:
-
task = (
self.session.query(db_models.RecommendationTask)
.where(db_models.RecommendationTask.id == task_id)
diff --git a/src/repository/types.py b/src/repository/types.py
new file mode 100644
index 0000000..b7ef472
--- /dev/null
+++ b/src/repository/types.py
@@ -0,0 +1,22 @@
+from typing import Optional
+from pydantic import BaseModel
+
+from data.AggregatedSolution import AggregatedSolution
+from data.apischema import SeverityFilter
+from data.pagination import PaginationInput
+
+
+class AggregatedSolutionInput(BaseModel):
+ solution: AggregatedSolution
+ findings_db_ids: list[int]
+
+
+class CreateAggregatedRecommendationInput(BaseModel):
+ aggregated_solutions: list[AggregatedSolutionInput]
+ recommendation_task_id: int
+
+
+class GetFindingsByFilterInput(BaseModel):
+ task_id: int
+ severityFilter: Optional[SeverityFilter] = None
+ pagination: Optional[PaginationInput] = None
diff --git a/src/routes/v1/recommendations.py b/src/routes/v1/recommendations.py
index e05db88..cf39f92 100644
--- a/src/routes/v1/recommendations.py
+++ b/src/routes/v1/recommendations.py
@@ -1,5 +1,5 @@
import datetime
-from typing import Annotated
+from typing import Annotated, Optional
from fastapi import Body, Depends, HTTPException, Response
from fastapi.routing import APIRouter
@@ -8,10 +8,14 @@
import data.apischema as apischema
import db.models as db_models
+from data.AggregatedSolution import AggregatedSolution
from db.my_db import get_db
from dto.finding import db_finding_to_response_item
from repository.finding import get_finding_repository
+from repository.recommendation import (
+    RecommendationRepository,
+    get_recommendation_repository,
+)
from repository.task import TaskRepository, get_task_repository
+from repository.types import GetFindingsByFilterInput
router = APIRouter(
prefix="/recommendations",
@@ -63,26 +67,65 @@ def recommendations(
status_code=400,
detail="Recommendation task failed",
)
- if severityFilter:
- findings = finding_repository.get_findings_by_task_id_and_severity(task.id, severityFilter, request.pagination)
- else:
- findings = finding_repository.get_findings_by_task_id(task.id, request.pagination)
+ findings, total = finding_repository.get_findings_by_task_id_and_filter(
+ GetFindingsByFilterInput(
+ task_id=task.id,
+ severityFilter=severityFilter if severityFilter else None,
+ pagination=request.pagination,
+ )
+ )
- total_count = finding_repository.get_findings_count_by_task_id(task.id)
-
response = apischema.GetRecommendationResponse(
- items=apischema.GetRecommendationResponseItems(
- findings= [db_finding_to_response_item(find) for find in findings],
- aggregated_solutions= []),
+ items=[db_finding_to_response_item(find) for find in findings],
pagination=apischema.Pagination(
offset=request.pagination.offset,
limit=request.pagination.limit,
- total=total_count,
+ total=total,
count=len(findings),
),
)
+ return response
- if not response or len(response.items.findings) == 0:
- return Response(status_code=204, headers={"Retry-After": "120"})
- return response
+@router.post("/aggregated")
+def aggregated_solutions(
+ request: Annotated[
+ Optional[apischema.GetAggregatedRecommendationRequest], Body(...)
+ ],
+ task_repository: TaskRepository = Depends(get_task_repository),
+ recommendation_repository: RecommendationRepository = (
+ Depends(get_recommendation_repository)
+ ),
+) -> apischema.GetAggregatedRecommendationResponse:
+
+ today = datetime.datetime.now().date()
+    if request.filter and request.filter.task_id:
+        task = task_repository.get_task_by_id(request.filter.task_id)
+    else:
+        task = task_repository.get_task_by_date(today)
+
+ if not task:
+ raise HTTPException(
+ status_code=404,
+ detail=(
+ f"Task with id {request.filter.task_id} not found"
+ if request.filter and request.filter.task_id
+ else "Task for today not found"
+ ),
+ )
+ if task.status != db_models.TaskStatus.COMPLETED:
+ raise HTTPException(
+ status_code=400,
+ detail="Recommendation status:" + task.status.value,
+ )
+ agg_recs = recommendation_repository.get_aggregated_solutions(task.id)
+ return apischema.GetAggregatedRecommendationResponse(
+ items=[
+ apischema.GetAggregatedRecommendationResponseItem(
+ solution=rec.solution,
+ findings=[db_finding_to_response_item(x) for x in rec.findings],
+ metadata=rec.meta,
+ )
+ for rec in agg_recs
+ ]
+ )
diff --git a/src/routes/v1/upload.py b/src/routes/v1/upload.py
index 0aa02dd..02aedd9 100644
--- a/src/routes/v1/upload.py
+++ b/src/routes/v1/upload.py
@@ -33,7 +33,7 @@ async def upload(
content_list = get_content_list(data.data)
if data.filter:
content_list = filter_findings(content_list, data.filter)
-
+
today = datetime.datetime.now().date()
existing_task = task_repository.get_task_by_date(today)
if existing_task and not data.force_update:
@@ -59,7 +59,13 @@ async def upload(
finding_repository.create_findings(findings)
celery_result = worker.send_task(
- "worker.generate_report", args=[recommendation_task.id]
+ "worker.generate_report",
+ args=[
+ recommendation_task.id,
+ data.preferences.long_description,
+ data.preferences.search_terms,
+ data.preferences.aggregated_solutions,
+ ],
)
# update the task with the celery task id
diff --git a/src/worker/worker.py b/src/worker/worker.py
index f2f4178..9994406 100644
--- a/src/worker/worker.py
+++ b/src/worker/worker.py
@@ -1,16 +1,31 @@
import logging
from celery import Celery
+from celery.signals import worker_init
import db.models as db_models
-from db.my_db import Session
+from db.my_db import engine, sessionmaker
+from repository.finding import FindingRepository
+from repository.recommendation import RecommendationRepository
+from repository.task import TaskRepository
+from repository.types import (
+ AggregatedSolutionInput,
+ CreateAggregatedRecommendationInput,
+)
+
+
+from ai.LLM.LLMServiceStrategy import LLMServiceStrategy
+from ai.LLM.Strategies.OLLAMAService import OLLAMAService
+from data.VulnerabilityReport import create_from_flama_json
+from ai.Grouping.FindingGrouper import FindingGrouper
+
logger = logging.getLogger(__name__)
from config import config
-
+Session = sessionmaker(engine)
redis_url = config.redis_endpoint
print(f"Redis URL: {redis_url}")
@@ -25,11 +40,17 @@ def error(self, exc, task_id, args, kwargs, einfo):
@worker.task(name="worker.generate_report", on_failure=error)
-def generate_report(recommendation_task_id: int, stragegy: str = "OLLAMA"):
+def generate_report(
+ recommendation_task_id: int,
+ generate_long_solution: bool = True,
+ generate_search_terms: bool = True,
+ generate_aggregate_solutions: bool = True,
+):
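+    """Generate recommendations for a single task.
+
+    The three boolean flags mirror UploadPreference from the upload route
+    and gate the optional LLM steps (long descriptions, search terms,
+    aggregated solutions).
+    """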
+
-    # importing here so importing worker does not import all the dependencies
-    from ai.LLM.LLMServiceStrategy import LLMServiceStrategy
-    from ai.LLM.Strategies.OLLAMAService import OLLAMAService
-    from data.VulnerabilityReport import create_from_flama_json
ollama_strategy = OLLAMAService()
llm_service = LLMServiceStrategy(ollama_strategy)
@@ -37,71 +58,91 @@ def generate_report(recommendation_task_id: int, stragegy: str = "OLLAMA"):
logger.info(f"Processing recommendation task with id {recommendation_task_id}")
logger.info(f"Processing recommendation task with limit {limit}")
logger.info(
- f"Processing recommendation task with model_name {ollama_strategy.model_name}"
+ f"long: {generate_long_solution}, search: {generate_search_terms}, aggregate: {generate_aggregate_solutions}"
)
- with Session() as session:
- query = (
- session.query(db_models.Finding)
- .join(db_models.RecommendationTask)
- .filter(db_models.RecommendationTask.id == recommendation_task_id)
- )
- if limit > 0:
- query = query.limit(limit)
-
- findings_from_db = query.all()
- logger.info(f"Found {len(findings_from_db)} findings for recommendation task")
- if not findings_from_db:
- logger.warn(
- f"No findings found for recommendation task {recommendation_task_id}"
+
+ recommendationTask = None
+ try:
+ with Session() as session:
+
+ taskRepo = TaskRepository(session)
+ recommendationTask = taskRepo.get_task_by_id(recommendation_task_id)
+
+ if not recommendationTask:
+ logger.error(
+ f"Recommendation task with id {recommendation_task_id} not found"
+ )
+ return
+ find_repo = FindingRepository(session)
+ findings_from_db = find_repo.get_all_findings_by_task_id_for_processing(
+ recommendation_task_id, limit
+ )
+
+ logger.info(
+ f"Found {len(findings_from_db)} findings for recommendation task"
)
- return
+ if not findings_from_db:
+            logger.warning(
+ f"No findings found for recommendation task {recommendation_task_id}"
+ )
+ return
+
+ findings = [f.raw_data for f in findings_from_db]
+ finding_ids = [f.id for f in findings_from_db]
+
+ except Exception as e:
+ logger.error(f"Error processing recommendation task: {e}")
+ return
- findings = [f.raw_data for f in findings_from_db]
- finding_ids = [f.id for f in findings_from_db]
vulnerability_report = create_from_flama_json(
- findings, n=limit, llm_service=llm_service
+ findings,
+ n=limit,
+ llm_service=llm_service,
+        shuffle_data=False,  # shuffling would break the zip-based id mapping below
)
+
+ # map findings to db ids, used for aggregated solutions
+ finding_id_map = {
+ finding.id: db_id
+ for db_id, finding in zip(finding_ids, vulnerability_report.findings)
+ }
+
vulnerability_report.add_category()
- vulnerability_report.add_solution()
+ vulnerability_report.add_solution(
+ search_term=generate_search_terms, long=generate_long_solution
+ )
- with Session() as session:
- for finding_id, f in zip(finding_ids, vulnerability_report.findings):
- finding = (
- session.query(db_models.Finding)
- .filter(db_models.Finding.id == finding_id)
- .first()
- )
- if finding is None:
- print(f"Finding with id {finding_id} not found")
- continue
- recommendation = db_models.Recommendation(
- description_short=(
- f.solution.short_description
- if f.solution.short_description
- else "No short description"
- ),
- description_long=(
- f.solution.long_description
- if f.solution.long_description
- else "No long description"
- ),
- meta=f.solution.metadata if f.solution.metadata else {},
- search_terms=f.solution.search_terms if f.solution.search_terms else [],
- finding_id=finding_id,
- recommendation_task_id=recommendation_task_id,
- # TODO: fix category changes
- category=(
- f.category.affected_component.value
- if f.category and f.category.affected_component
- else None
- ),
+ if generate_aggregate_solutions:
+        # once all findings are processed, group them into aggregated solutions
+ findingGrouper = FindingGrouper(vulnerability_report, ollama_strategy)
+ findingGrouper.generate_aggregated_solutions()
+ # map findings to db ids for aggregated solutions
+ aggregated_solutions_with_db_ids = [
+ AggregatedSolutionInput(
+ findings_db_ids=[
+ finding_id_map[finding.id] for finding in solution.findings
+ ],
+ solution=solution,
)
- session.add(recommendation)
- ## updat recommendation task status
- recommendation_task = (
- session.query(db_models.RecommendationTask)
- .filter(db_models.RecommendationTask.id == recommendation_task_id)
- .first()
+ for solution in vulnerability_report.aggregated_solutions
+ ]
+ with Session() as session:
+ recommendationRepo = RecommendationRepository(session)
+ recommendationRepo.create_recommendations(
+ list(zip(finding_ids, vulnerability_report.findings)),
+ recommendation_task_id,
+ )
+
+ if generate_aggregate_solutions:
+            # save aggregated solutions (reuse the repository created above)
+ recommendationRepo.create_aggregated_solutions(
+ input=CreateAggregatedRecommendationInput(
+ aggregated_solutions=aggregated_solutions_with_db_ids,
+ recommendation_task_id=recommendation_task_id,
+ )
)
- recommendation_task.status = db_models.TaskStatus.COMPLETED
- session.commit()
+ # finally update the task status
+        taskRepo = TaskRepository(session)
+        taskRepo.update_task_completed(recommendation_task_id)