Skip to content

Commit

Permalink
Merge branch 'main' into feat/dashboard-add-aggregatedSolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
niklas1531 committed Aug 1, 2024
2 parents 37d1e3e + 2c1e7c2 commit ab25a2d
Show file tree
Hide file tree
Showing 25 changed files with 795 additions and 274 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Build the API documentation on every push / pull request targeting main
# and publish the generated output as a workflow artifact.
name: API Docs

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  generate-docs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        # v3 runs on a deprecated Node.js runtime; v4 is the supported release.
        uses: actions/checkout@v4

      - name: Set up Python 3.11
        # v2 of setup-python predates Python 3.11; v5 is current and cached.
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML does not coerce the version to a float
          # (the unquoted form is the classic "3.10" -> 3.1 trap).
          python-version: "3.11"

      - name: Install dependencies
        run: |
          pip install pipenv
          pipenv install

      - name: Generate docs
        run: |
          pipenv run python src/extract-docs.py

      - name: Save docs as artifact
        # upload-artifact v2 was deprecated and disabled by GitHub; v4 is required.
        uses: actions/upload-artifact@v4
        with:
          name: docs
          path: .docs
          # Fail the job loudly if doc generation produced nothing.
          if-no-files-found: error
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ After starting the application, you can access the API at `http://localhost:8000
### Docker

Building the images:

```bash
docker compose build
```
Expand All @@ -64,23 +65,23 @@ To run the code within Docker, use the following command in the root directory o
docker compose up
```



Add the `-d` flag to run the containers in the background: `docker compose up -d`.




### Available Routes

Currently, these routes are generated by FastAPI.

```
HEAD, GET /openapi.json
HEAD, GET /docs
HEAD, GET /docs/oauth2-redirect
HEAD, GET /redoc
GET /api/v1/tasks/
DELETE /api/v1/tasks/{task_id}
DELETE /api/v1/tasks/
GET /api/v1/tasks/{task_id}/status
POST /api/v1/recommendations/
POST /api/v1/recommendations/aggregated
POST /api/v1/upload/
GET /
```
Expand Down
8 changes: 6 additions & 2 deletions src/ai/Grouping/FindingGrouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


class FindingGrouper:
def __init__(self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService):
def __init__(
self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService
):
self.vulnerability_report = vulnerability_report
self.llm_service = llm_service
self.batcher = FindingBatcher(llm_service)
Expand All @@ -20,5 +22,7 @@ def generate_aggregated_solutions(self):
for batch in tqdm(self.batches, desc="Generating Aggregated Solutions"):
result_list = self.llm_service.generate_aggregated_solution(batch)
for result in result_list:
self.aggregated_solutions.append(AggregatedSolution(result[1], result[0], result[2])) # Solution, Findings, Metadata
self.aggregated_solutions.append(
AggregatedSolution().from_result(result[1], result[0], result[2])
) # Solution, Findings, Metadata
self.vulnerability_report.set_aggregated_solutions(self.aggregated_solutions)
1 change: 1 addition & 0 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ai.LLM.Strategies.OLLAMAService
from config import config


import routes
import routes.v1.recommendations
import routes.v1.task
Expand Down
56 changes: 15 additions & 41 deletions src/config.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,33 @@
from pydantic import (
BaseModel,
Field,
RedisDsn,
ValidationInfo,
model_validator,
root_validator,
)

from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field, RedisDsn, field_validator
from pydantic import Field, field_validator

from typing import Optional


class SubModel(BaseModel):
foo: str = "bar"
apple: int = 1


class Config(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

ollama_url: str = Field(
json_schema_extra="OLLAMA_URL", default="http://localhost:11434"
)
ollama_model: str = Field(json_schema_extra="OLLAMA_MODEL", default="phi3:mini")
ollama_url: str = Field(default="http://localhost:11434")
ollama_model: str = Field(default="phi3:mini")

ai_strategy: Optional[str] = Field(
json_schema_extra="AI_STRATEGY", default="OLLAMA"
)
anthropic_api_key: Optional[str] = Field(
json_schema_extra="ANTHROPIC_API_KEY", default=None
)
openai_api_key: Optional[str] = Field(
json_schema_extra="OPENAI_API_KEY", default=None
)
ai_strategy: Optional[str] = Field(default="OLLAMA")
anthropic_api_key: Optional[str] = Field(default=None)
openai_api_key: Optional[str] = Field(default=None)

postgres_server: str = Field(
json_schema_extra="POSTGRES_SERVER", default="localhost"
)
postgres_port: int = Field(json_schema_extra="POSTGRES_PORT", default=5432)
postgres_db: str = Field(json_schema_extra="POSTGRES_DB", default="app")
postgres_user: str = Field(json_schema_extra="POSTGRES_USER", default="postgres")
postgres_password: str = Field(
json_schema_extra="POSTGRES_PASSWORD", default="postgres"
)
queue_processing_limit: int = Field(
json_schema_extra="QUEUE_PROCESSING_LIMIT", default=10
)
redis_endpoint: str = Field(
json_schema_extra="REDIS_ENDPOINT", default="redis://localhost:6379/0"
)
environment: str = Field(env="ENVIRONMENT", default="development")
db_debug: bool = Field(env="DB_DEBUG", default=False)
postgres_server: str = Field(default="localhost")
postgres_port: int = Field(default=5432)
postgres_db: str = Field(default="app")
postgres_user: str = Field(default="postgres")
postgres_password: str = Field(default="postgres")
queue_processing_limit: int = Field(default=10)
redis_endpoint: str = Field(default="redis://localhost:6379/0")
environment: str = Field(default="development")
db_debug: bool = Field(default=False)

@field_validator(
"ai_strategy",
Expand All @@ -66,7 +41,6 @@ def check_ai_strategy(cls, ai_strategy, values):

@field_validator("openai_api_key")
def check_api_key(cls, api_key, info: ValidationInfo):
print(info)
if info.data["ai_strategy"] == "OPENAI" and not api_key:
raise ValueError("OPENAI_API_KEY is required when ai_strategy is OPENAI")
return api_key
Expand Down
11 changes: 6 additions & 5 deletions src/data/AggregatedSolution.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from typing import List

from data.Finding import Finding
from db.base import BaseModel
from pydantic import BaseModel


class AggregatedSolution:
class AggregatedSolution(BaseModel):
findings: List[Finding] = None
solution: str = ""
metadata: dict = {}

def __init__(self, findings: List[Finding], solution: str, metadata=None):
def from_result(self, findings: List[Finding], solution: str, metadata=None):
self.findings = findings
self.solution = solution
self.metadata = metadata
return self

def __str__(self):
return self.solution
Expand All @@ -21,8 +22,8 @@ def to_dict(self):
return {
"findings": [finding.to_dict() for finding in self.findings],
"solution": self.solution,
"metadata": self.metadata
"metadata": self.metadata,
}

def to_html(self):
return f"<p>{self.solution}</p>"
return f"<p>{self.solution}</p>"
62 changes: 44 additions & 18 deletions src/data/Finding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from typing import List, Set, Optional, Any, get_args
from enum import Enum, auto
import uuid

from pydantic import BaseModel, Field, PrivateAttr
from data.Solution import Solution
from data.Categories import Category, TechnologyStack, SecurityAspect, SeverityLevel, RemediationType, \
AffectedComponent, Compliance, Environment
from data.Categories import (
Category,
TechnologyStack,
SecurityAspect,
SeverityLevel,
RemediationType,
AffectedComponent,
Compliance,
Environment,
)
import json

import logging
Expand All @@ -12,6 +22,7 @@


class Finding(BaseModel):
id: str = Field(default_factory=lambda: f"{str(uuid.uuid4())}")
title: List[str] = Field(default_factory=list)
source: Set[str] = Field(default_factory=set)
descriptions: List[str] = Field(default_factory=list)
Expand All @@ -21,7 +32,7 @@ class Finding(BaseModel):
severity: Optional[int] = None
priority: Optional[int] = None
location_list: List[str] = Field(default_factory=list)
category: Category = None
category: Optional[Category] = None
unsupervised_cluster: Optional[int] = None
solution: Optional["Solution"] = None
_llm_service: Optional[Any] = PrivateAttr(default=None)
Expand All @@ -35,7 +46,9 @@ def combine_descriptions(self) -> "Finding":
logger.error("LLM Service not set, cannot combine descriptions.")
return self

self.description = self.llm_service.combine_descriptions(self.descriptions, self.cve_ids, self.cwe_ids)
self.description = self.llm_service.combine_descriptions(
self.descriptions, self.cve_ids, self.cwe_ids
)
return self

def add_category(self) -> "Finding":
Expand All @@ -47,34 +60,45 @@ def add_category(self) -> "Finding":

# Classify technology stack
technology_stack_options = list(TechnologyStack)
self.category.technology_stack = self.llm_service.classify_kind(self, "technology_stack",
technology_stack_options)
self.category.technology_stack = self.llm_service.classify_kind(
self, "technology_stack", technology_stack_options
)

# Classify security aspect
security_aspect_options = list(SecurityAspect)
self.category.security_aspect = self.llm_service.classify_kind(self, "security_aspect", security_aspect_options)
self.category.security_aspect = self.llm_service.classify_kind(
self, "security_aspect", security_aspect_options
)

# Classify severity level
severity_level_options = list(SeverityLevel)
self.category.severity_level = self.llm_service.classify_kind(self, "severity_level", severity_level_options)
self.category.severity_level = self.llm_service.classify_kind(
self, "severity_level", severity_level_options
)

# Classify remediation type
remediation_type_options = list(RemediationType)
self.category.remediation_type = self.llm_service.classify_kind(self, "remediation_type",
remediation_type_options)
self.category.remediation_type = self.llm_service.classify_kind(
self, "remediation_type", remediation_type_options
)

# Classify affected component
affected_component_options = list(AffectedComponent)
self.category.affected_component = self.llm_service.classify_kind(self, "affected_component",
affected_component_options)
self.category.affected_component = self.llm_service.classify_kind(
self, "affected_component", affected_component_options
)

# Classify compliance
compliance_options = list(Compliance)
self.category.compliance = self.llm_service.classify_kind(self, "compliance", compliance_options)
self.category.compliance = self.llm_service.classify_kind(
self, "compliance", compliance_options
)

# Classify environment
environment_options = list(Environment)
self.category.environment = self.llm_service.classify_kind(self, "environment", environment_options)
self.category.environment = self.llm_service.classify_kind(
self, "environment", environment_options
)

return self

Expand Down Expand Up @@ -213,17 +237,19 @@ def to_html(self, table=False):
result += "<tr><th>Name</th><th>Value</th></tr>"
result += f"<tr><td>Title</td><td>{', '.join(self.title)}</td></tr>"
result += f"<tr><td>Source</td><td>{', '.join(self.source)}</td></tr>"
result += (
f"<tr><td>Description</td><td>{self.description}</td></tr>"
)
result += f"<tr><td>Description</td><td>{self.description}</td></tr>"
if len(self.location_list) > 0:
result += f"<tr><td>Location List</td><td>{' & '.join(map(str, self.location_list))}</td></tr>"
result += f"<tr><td>CWE IDs</td><td>{', '.join(self.cwe_ids)}</td></tr>"
result += f"<tr><td>CVE IDs</td><td>{', '.join(self.cve_ids)}</td></tr>"
result += f"<tr><td>Severity</td><td>{self.severity}</td></tr>"
result += f"<tr><td>Priority</td><td>{self.priority}</td></tr>"
if self.category is not None:
result += '<tr><td>Category</td><td>' + str(self.category).replace("\n", "<br />") + '</td></tr>'
result += (
"<tr><td>Category</td><td>"
+ str(self.category).replace("\n", "<br />")
+ "</td></tr>"
)
if self.unsupervised_cluster is not None:
result += f"<tr><td>Unsupervised Cluster</td><td>{self.unsupervised_cluster}</td></tr>"
result += "</table>"
Expand Down
Loading

0 comments on commit ab25a2d

Please sign in to comment.