From a172c37c29c59d813bd1c4916b8cecced611349e Mon Sep 17 00:00:00 2001
From: ishwor Giri
Date: Wed, 7 Aug 2024 00:05:40 +0200
Subject: [PATCH 1/3] add test for api and llm

---
 README.md                          |   4 +
 documentation/00 - introduction.md |   4 +-
 documentation/04-testing.md        |  68 ++++++
 documentation/FAQ.md               |  22 +-
 requirements.txt                   |  25 ---
 run                                |   4 -
 run.sh                             |   4 -
 src/data/VulnerabilityReport.py    |   2 +-
 src/data/repository/finding.py     |   2 -
 src/db/models.py                   |  14 +-
 src/repository/finding.py          |   9 +-
 src/repository/recommendation.py   |   2 +-
 src/routes/v1/recommendations.py   |  14 +-
 src/test/mock_llm.py               | 165 +++++++++++++++
 src/test/test_db.py                |  70 -------
 src/test/test_db_api.py            | 322 +++++++++++++++++++++++++++++
 src/test/test_llm.py               |  58 ++++++
 17 files changed, 656 insertions(+), 133 deletions(-)
 create mode 100644 documentation/04-testing.md
 delete mode 100644 requirements.txt
 delete mode 100755 run
 delete mode 100644 run.sh
 delete mode 100644 src/data/repository/finding.py
 create mode 100644 src/test/mock_llm.py
 delete mode 100644 src/test/test_db.py
 create mode 100644 src/test/test_db_api.py
 create mode 100644 src/test/test_llm.py

diff --git a/README.md b/README.md
index 19be203..baa1ac8 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,10 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file
 
 This project was created in the context of TUM's practical course [Digital Product Innovation and Development](https://www.fortiss.org/karriere/digital-product-innovation-and-development) by fortiss in the summer semester 2024. The task was suggested by Siemens, who also supervised the project.
 
+## FAQ
+
+Please refer to the [FAQ](documentation/FAQ.md).
+
 ## Contact
 
 To contact fortiss or Siemens, please refer to their official websites.
diff --git a/documentation/00 - introduction.md b/documentation/00 - introduction.md
index de948ff..19d066a 100644
--- a/documentation/00 - introduction.md
+++ b/documentation/00 - introduction.md
@@ -67,7 +67,9 @@ For more detailed information about the API routes, refer to the [API Routes](ap
 
 - **[Prerequisites](01%20-%20prerequisites):** Environment setup and dependencies.
 - **[Installation](02%20-%20installation):** Step-by-step installation guide.
-- **[Usage](04%20-%20usage):** Instructions on how to use the system.
+- **[Usage](03%20-%20usage):** Instructions on how to use the system.
+
+- **[Testing](04-testing):** Explains the testing strategy and rules.
 
 ---
diff --git a/documentation/04-testing.md b/documentation/04-testing.md
new file mode 100644
index 0000000..4ff1c67
--- /dev/null
+++ b/documentation/04-testing.md
@@ -0,0 +1,68 @@
+# Testing
+
+We have a basic testing setup in place that covers database insertion and API retrieval.
+
+The tests are located inside the `src/test` directory and can be executed using the following command:
+
+```bash
+pipenv run pytest
+# or
+pytest
+```
+
+We also have a GitHub test workflow set up in `test.yml` that runs on every pull request and on every merge to the main branch to ensure that the tests pass.
+
+## Adding new Tests
+
+New tests can also be added to the `src/test` folder.
+
+The files must be prefixed with `test_` for pytest to recognize them.
+
+Test functions must likewise be prefixed with `test_`.
+
+e.g.:
+
+```python
+def test_create_get_task_integration():
+    assert 1 + 1 == 2
+```
+
+We use dependency overriding to give the repositories a separate database, ensuring that the tests do not interfere with your actual database. The dependencies can be overridden as shown below.
+```python
+from fastapi import Depends
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.pool import StaticPool
+
+import db.models as db_models
+from repository.task import TaskRepository, get_task_repository
+from app import app
+
+SQLALCHEMY_DATABASE_URL = "sqlite://"
+
+engine = create_engine(
+    SQLALCHEMY_DATABASE_URL,
+    connect_args={"check_same_thread": False},
+    poolclass=StaticPool,
+)
+TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+db_models.BaseModel.metadata.create_all(bind=engine)
+
+
+def override_get_db():
+    try:
+        db = TestingSessionLocal()
+        yield db
+    finally:
+        db.close()
+
+
+def override_get_task_repository(session: Session = Depends(override_get_db)):
+    return TaskRepository(session)
+
+
+app.dependency_overrides[get_task_repository] = override_get_task_repository
+```
+
+For a comprehensive guide on testing with FastAPI, please refer to the [FastAPI Testing](https://fastapi.tiangolo.com/tutorial/testing/) documentation.
+
+Note:
+Integration testing is always challenging, especially when dealing with LLMs and queues. However, the API endpoints have been thoroughly tested to ensure accurate responses.
diff --git a/documentation/FAQ.md b/documentation/FAQ.md
index 5d270c6..8942908 100644
--- a/documentation/FAQ.md
+++ b/documentation/FAQ.md
@@ -1,25 +1,29 @@
 # Frequently Asked Questions
 
-A compiled list of FAQs that may come in handy.
+A compiled list of frequently asked questions that may come in handy.
 
-## Import ot found?
+## Import not found?
 
-If you are getting an "import not found" error, it is likely because the base folder is always `src`. Always run the program from the `src` folder or use the commands inside the Makefile.
+If you are encountering an "import not found" error, it is likely because the base folder is always `src`. Make sure to run the program from the `src` folder or use the commands inside the Makefile.
 
-If you have something in `src/lib` and want to use it, import it as follows:
+If you have something in `src/lib` and want to use it, import it as shown below:
 
 ```python
 import lib
 # or
 from lib import ...
 ```
 
+# Why use a POST call to retrieve recommendations and aggregated recommendations?
+
+For these calls, we require a JSON body containing at least `{}`. This allows us to handle nested filters such as `task_id` and `severity` inside the `filter` object, e.g. `{"filter": {"task_id": 1, "severity": {"minValue": 10, "maxValue": 20}}}`. Query parameters and GET calls would be less suitable for this purpose. However, one can modify the pathname (for example, by adding a `get-` prefix) to make the intent clearer.
+
 # Why are env.docker and .env different?
 
-If you are only running the program using Docker, then you only need to worry about `.env.docker`.
+If you are running the program exclusively using Docker, then you only need to concern yourself with `.env.docker`.
 
-As the addresses tend to be different in a Docker environment compared to a local environment, you need different values to resolve the addresses.
+Since the addresses can differ between a Docker environment and a local environment, you need different values to resolve the addresses.
 
-For example, if you have your program outside Docker (locally) and want to access a database, you may use:
+For example, if your program is outside Docker (locally) and you want to access a database, you may use:
 
 ```
 POSTGRES_SERVER=localhost
 ```
@@ -50,3 +54,7 @@ We have a predefined structure that input must adhere to called `Content`. You c
 Inside the db model called `Findings`, there is a method `from_data` which can be modified to adapt the changes.
 
 `VulnerabilityReport` also has `create_from_flama_json`, which must be adjusted accordingly to make sure the generation side also works.
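For illustration, a minimal finding in this `Content` structure could look as follows. This is a sketch based on the fixtures in `generate_findings()` in `src/test/test_db_api.py` further below; the real schema has more fields, and the values here are placeholders:

```python
# Hypothetical minimal Content-style payload, mirroring the test fixtures
# in src/test/test_db_api.py; not the complete schema.
finding = {
    "doc_type": "summary",
    "title_list": [{"element": "finding_title_0", "source": "Trivy"}],
    "description_list": [{"element": "finding_description_0", "source": "Trivy"}],
    "location_list": [],
    "cve_id_list": [],
    "report_amount": 4,
    "severity": 50,
    "priority": 50,
}
```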
+
+# Does it make sense to Mock LLM?
+
+While we don't strive for accuracy, it still makes sense to mock the LLM methods to ensure that the code that builds findings and interacts with the LLM class methods works correctly. Nevertheless, it remains difficult to extract meaningful test outputs from prompts alone.
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 6217590..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,25 +0,0 @@
-#LEGACY : use pipfile instead
-
-pydantic
-pandas
-sqlmodel~=0.0.18
-alembic~=1.13.1
-psycopg
-psycopg-binary
-fastapi==0.109.1
-uvicorn==0.15.0
-python-dotenv
-jsonschema==4.22.0
-SQLAlchemy~=2.0.30
-tqdm~=4.66.4
-httpx~=0.27.0
-tenacity~=8.3.0
-celery[redis]==5.2.2
-openai~=1.33.0 # only if you plan to use the OpenAI Service
-anthropic~=0.28.0 # only if you plan to use the Anthropic Service
-pydantic-settings~=2.3.2
-
-sentence_transformers # only if you plan to use unsupervised clustering
-kneed # only if you plan to use unsupervised clustering
-scikit-learn # only if you plan to use unsupervised clustering
-plotly # only if you plan to use unsupervised clustering
\ No newline at end of file
diff --git a/run b/run
deleted file mode 100755
index 49e3add..0000000
--- a/run
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/bin/bash
-
-source venv/bin/activate
-uvicorn src.app:app --host 0.0.0.0 --port 8000
\ No newline at end of file
diff --git a/run.sh b/run.sh
deleted file mode 100644
index 49e3add..0000000
--- a/run.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/bin/bash
-
-source venv/bin/activate
-uvicorn src.app:app --host 0.0.0.0 --port 8000
\ No newline at end of file
diff --git a/src/data/VulnerabilityReport.py b/src/data/VulnerabilityReport.py
index 21a7346..b67bb6a 100644
--- a/src/data/VulnerabilityReport.py
+++ b/src/data/VulnerabilityReport.py
@@ -30,7 +30,7 @@ def set_llm_service(self, llm_service: "LLMServiceStrategy"):
             finding.llm_service = llm_service
         return self
 
-    def add_finding(self, finding):
+    def add_finding(self, finding: Finding):
         self.findings.append(finding)
 
     def get_findings(self):
diff --git a/src/data/repository/finding.py b/src/data/repository/finding.py
deleted file mode 100644
index d521f6f..0000000
--- a/src/data/repository/finding.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class FindingRepository:
-    pass
diff --git a/src/db/models.py b/src/db/models.py
index 5057b0d..a9718ec 100644
--- a/src/db/models.py
+++ b/src/db/models.py
@@ -102,23 +102,25 @@ class Finding(BaseModel):
 
     def from_data(self, data: Content):
         self.cve_id_list = (
-            [x.dict() for x in data.cve_id_list] if data.cve_id_list else []
+            [x.model_dump() for x in data.cve_id_list] if data.cve_id_list else []
         )
         self.description_list = (
-            [x.dict() for x in data.description_list] if data.description_list else []
+            [x.model_dump() for x in data.description_list]
+            if data.description_list
+            else []
         )
-        self.title_list = [x.dict() for x in data.title_list]
+        self.title_list = [x.model_dump() for x in data.title_list]
         self.locations_list = (
-            [x.dict() for x in data.location_list] if data.location_list else []
+            [x.model_dump() for x in data.location_list] if data.location_list else []
        )
-        self.raw_data = data.dict()
+        self.raw_data = data.model_dump()
         self.severity = data.severity
         self.priority = data.priority
         self.report_amount = data.report_amount
         return self
 
     def __repr__(self):
-        return f""
+        return f""
 
 
 class TaskStatus(PyEnum):
diff --git a/src/repository/finding.py b/src/repository/finding.py
index db2cb27..e390435 100644
--- a/src/repository/finding.py
+++ b/src/repository/finding.py
@@ -76,9 +76,9 @@ def get_findings_by_task_id_and_filter(
         )
 
     total = query.count()
-    findings = query.all()
     if pagination:
         query = query.offset(pagination.offset).limit(pagination.limit)
+    findings = query.all()
 
     return findings, total
 
@@ -95,10 +95,9 @@ def get_findings_count_by_task_id(self, task_id: int) -> int:
 
         return count
 
-    def create_findings(
-        self, findings: list[db_models.Finding]
-    ) -> list[db_models.Finding]:
-        self.session.bulk_save_objects(findings)
+    def create_findings(self, findings: list[db_models.Finding]):
+        self.session.add_all(findings)
+        self.session.commit()
diff --git a/src/repository/recommendation.py b/src/repository/recommendation.py
index 5f2147e..51406d5 100644
--- a/src/repository/recommendation.py
+++ b/src/repository/recommendation.py
@@ -63,7 +63,7 @@ def create_recommendations(
                     else "No long description available"
                 ),
                 meta=f.solution.metadata if f.solution.metadata else {},
-                search_terms=f.solution.search_terms if f.solution.search_terms else [],
+                search_terms=f.solution.search_terms if f.solution.search_terms else "",
                 finding_id=finding_id,
                 recommendation_task_id=recommendation_task_id,
                 category=(f.category.model_dump_json() if f.category else None),
diff --git a/src/routes/v1/recommendations.py b/src/routes/v1/recommendations.py
index cf39f92..4ef6cd6 100644
--- a/src/routes/v1/recommendations.py
+++ b/src/routes/v1/recommendations.py
@@ -1,19 +1,19 @@
 import datetime
 from typing import Annotated, Optional
 
-from fastapi import Body, Depends, HTTPException, Response
+from fastapi import Body, Depends, HTTPException
 from fastapi.routing import APIRouter
-from sqlalchemy import Date, cast
 from sqlalchemy.orm import Session
 
 import data.apischema as apischema
 import db.models as db_models
-from data.AggregatedSolution import AggregatedSolution
 from db.my_db import get_db
 from dto.finding import db_finding_to_response_item
 from repository.finding import get_finding_repository
-from repository.recommendation import (RecommendationRepository,
-                                       get_recommendation_repository)
+from repository.recommendation import (
+    RecommendationRepository,
+    get_recommendation_repository,
+)
 from repository.task import TaskRepository, get_task_repository
 from repository.types import GetFindingsByFilterInput
 
@@ -102,8 +102,8 @@ def aggregated_solutions(
     task = None
     if request.filter and request.filter.task_id:
         task = task_repository.get_task_by_id(request.filter.task_id)
-    task = task_repository.get_task_by_date(today)
-
+    else:
+        task = task_repository.get_task_by_date(today)
     if not task:
         raise HTTPException(
             status_code=404,
diff --git a/src/test/mock_llm.py b/src/test/mock_llm.py
new file mode 100644
index 0000000..84519ed
--- /dev/null
+++ b/src/test/mock_llm.py
@@ -0,0 +1,165 @@
+# This tests the interplay of the LLMServiceStrategy, VulnerabilityReport, and Finding classes.
+# It does not necessarily test the functionality of the LLM models, but rather the interactions between the classes.
+
+from typing import List, Dict, Union, Tuple
+
+from ai.LLM.BaseLLMService import BaseLLMService
+from ai.LLM.LLMServiceMixin import LLMServiceMixin
+from utils.text_tools import clean
+from data.Finding import Finding
+from ai.LLM.Strategies.ollama_prompts import (
+    CLASSIFY_KIND_TEMPLATE,
+    SHORT_RECOMMENDATION_TEMPLATE,
+    LONG_RECOMMENDATION_TEMPLATE,
+    META_PROMPT_GENERATOR_TEMPLATE,
+    GENERIC_LONG_RECOMMENDATION_TEMPLATE,
+    SEARCH_TERMS_TEMPLATE,
+    AGGREGATED_SOLUTION_TEMPLATE,
+    SUBDIVISION_PROMPT_TEMPLATE,
+    answer_in_json_prompt,
+)
+
+
+class MockLLMService(BaseLLMService, LLMServiceMixin):
+    def __init__(self):
+        self.model_url = "mock"
+        self.model_name = "mock_model"
+        self.context_size = 8000
+
+        LLMServiceMixin.__init__(
+            self, {"model_url": self.model_url, "model_name": self.model_name}
+        )
+
+        self.pull_url: str = self.model_url + "/api/pull"
+        self.generate_url: str = self.model_url + "/api/generate"
+        self.generate_payload: Dict[str, Union[str, bool]] = {
+            "model": self.model_name,
+            "stream": False,
+            "format": "json",
+        }
+
+    def init_pull_model(self) -> None:
+        pass
+
+    def get_model_name(self) -> str:
+        return self.model_name
+
+    def get_context_size(self) -> int:
+        return self.context_size
+
+    def get_url(self) -> str:
+        return self.generate_url
+
+    def _generate(self, prompt: str, json=True) -> Dict[str, str]:
+        if f"{answer_in_json_prompt('combined_description').format()}" in prompt:
+            return {"combined_description": "combined_description"}
+        if f"{answer_in_json_prompt('selected_option').format()}" in prompt:
+            return {"selected_option": "JavaScript"}
+        if f"{answer_in_json_prompt('recommendation').format()}" in prompt:
+            return {"recommendation": "recommendation_response"}
+        if f"{answer_in_json_prompt('search_terms').format()}" in prompt:
+            return {"search_terms": "search_terms_response"}
+        if "Convert the following dictionary to a descriptive string" in prompt:
+            return {"response": "dict_to_str_response"}
+
+        return {}
+
+    def _get_classification_prompt(
+        self, options: str, field_name: str, finding_str: str
+    ) -> str:
+        return CLASSIFY_KIND_TEMPLATE.format(
+            options=options, field_name=field_name, data=finding_str
+        )
+
+    def _get_recommendation_prompt(self, finding: Finding, short: bool) -> str:
+        if short:
+            return SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
+        elif finding.solution and finding.solution.short_description:
+            finding.solution.add_to_metadata("used_meta_prompt", True)
+            return self._generate_prompt_with_meta_prompts(finding)
+        else:
+            return GENERIC_LONG_RECOMMENDATION_TEMPLATE.format()
+
+    def _process_recommendation_response(
+        self, response: Dict[str, str], finding: Finding, short: bool
+    ) -> Union[str, List[str]]:
+        if "recommendation" not in response:
+            return "" if short else [""]
+        return clean(response["recommendation"], llm_service=self)
+
+    def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:
+        short_recommendation = finding.solution.short_description
+        meta_prompt_generator = META_PROMPT_GENERATOR_TEMPLATE.format(
+            finding=str(finding)
+        )
+        meta_prompt_response = self.generate(meta_prompt_generator)
+        meta_prompts = clean(
+            meta_prompt_response.get("meta_prompts", ""), llm_service=self
+        )
+
+        long_prompt = LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)
+
+        finding.solution.add_to_metadata(
+            "prompt_long_breakdown",
+            {
+                "short_recommendation": short_recommendation,
+                "meta_prompts": meta_prompts,
+            },
+        )
+
+        return long_prompt
+
+    def _get_search_terms_prompt(self, finding: Finding) -> str:
+        return SEARCH_TERMS_TEMPLATE.format(data=str(finding))
+
+    def _process_search_terms_response(
+        self, response: Dict[str, str], finding: Finding
+    ) -> str:
+        if "search_terms" not in response:
+            return ""
+        return clean(response["search_terms"], llm_service=self)
+
+    def _get_subdivision_prompt(self, findings: List[Finding]) -> str:
+        findings_str = self._get_findings_str_for_aggregation(findings)
+        return SUBDIVISION_PROMPT_TEMPLATE.format(data=findings_str)
+
+    def _process_subdivision_response(
+        self, response: Dict, findings: List[Finding]
+    ) -> List[Tuple[List[Finding], Dict]]:
+        pass
+
+    def _get_aggregated_solution_prompt(
+        self, findings: List[Finding], meta_info: Dict
+    ) -> str:
+        findings_str = self._get_findings_str_for_aggregation(findings, details=True)
+
+        return AGGREGATED_SOLUTION_TEMPLATE.format(
+            data=findings_str, meta_info=meta_info.get("reason", "")
+        )
+
+    def _process_aggregated_solution_response(self, response: Dict[str, str]) -> str:
+        if "aggregated_solution" not in response:
+            raise ValueError("Failed to generate an aggregated solution")
+        return clean(response["aggregated_solution"], llm_service=self)
+
+    def convert_dict_to_str(self, data: Dict) -> str:
+        return LLMServiceMixin.convert_dict_to_str(self, data)
+
+    def combine_descriptions(
+        self,
+        descriptions: List[str],
+    ) -> str:
+        return "combined_description"
diff --git a/src/test/test_db.py b/src/test/test_db.py
deleted file mode 100644
index a106575..0000000
--- a/src/test/test_db.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from fastapi import Depends
-from fastapi.testclient import TestClient
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import StaticPool
-
-from sqlalchemy.orm.session import Session
-
-import db.models as db_models
-from repository.finding import FindingRepository, get_finding_repository
-from repository.recommendation import (
-    RecommendationRepository,
-    get_recommendation_repository,
-)
-from repository.task import TaskRepository, get_task_repository
-from app import app
-
-SQLALCHEMY_DATABASE_URL = "sqlite://"
-
-engine = create_engine(
-    SQLALCHEMY_DATABASE_URL,
-    connect_args={"check_same_thread": False},
-    poolclass=StaticPool,
-)
-TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
-
-
-db_models.BaseModel.metadata.create_all(bind=engine)
-
-
-def override_get_db():
-    try:
-        db = TestingSessionLocal()
-        yield db
-    finally:
-        db.close()
-
-
-def override_get_task_repository(session: Session = Depends(override_get_db)):
-    return TaskRepository(session)
-
-
-def override_get_finding_repository(session: Session = Depends(override_get_db)):
-    return FindingRepository(session)
-
-
-def override_get_recommendation_repository(session: Session = Depends(override_get_db)):
-    return RecommendationRepository(session)
-
-
-app.dependency_overrides[get_task_repository] = override_get_task_repository
-app.dependency_overrides[get_finding_repository] = override_get_finding_repository
-app.dependency_overrides[get_recommendation_repository] = (
-    override_get_recommendation_repository
-)
-client = TestClient(app)
-
-
-def test_create_get_task_integration():
-
-    with TestingSessionLocal() as session:
-        task_repo = TaskRepository(session=session)
-        task = task_repo.create_task()
-        task_repo.get_task_by_id
-
-        response = client.get(
-            "api/v1/tasks/",
-        )
-        assert response.status_code == 200
-        assert len(response.json()) == 1
diff --git a/src/test/test_db_api.py b/src/test/test_db_api.py
new file mode 100644
index 0000000..c66bc36
--- /dev/null
+++ b/src/test/test_db_api.py
@@ -0,0 +1,322 @@
+import json
+from fastapi import Depends
+from fastapi.testclient import TestClient
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import StaticPool
+from data.AggregatedSolution import AggregatedSolution
+from data.Solution import Solution
+from data.types import Content
+from sqlalchemy.orm.session import Session
+
+from data.Finding import Finding
+import db.models as db_models
+from repository.finding import FindingRepository, get_finding_repository
+from repository.recommendation import (
+    RecommendationRepository,
+    get_recommendation_repository,
+)
+from repository.task import TaskRepository, get_task_repository
+from app import app
+from repository.types import (
+    AggregatedSolutionInput,
+    CreateAggregatedRecommendationInput,
+)
+
+SQLALCHEMY_DATABASE_URL = "sqlite://"
+
+engine = create_engine(
+    SQLALCHEMY_DATABASE_URL,
+    connect_args={"check_same_thread": False},
+    poolclass=StaticPool,
+)
+TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+
+db_models.BaseModel.metadata.create_all(bind=engine)
+
+
+def override_get_db():
+    try:
+        db = TestingSessionLocal()
+        yield db
+    finally:
+        db.close()
+
+
+def override_get_task_repository(session: Session = Depends(override_get_db)):
+    return TaskRepository(session)
+
+
+def override_get_finding_repository(session: Session = Depends(override_get_db)):
+    return FindingRepository(session)
+
+
+def override_get_recommendation_repository(session: Session = Depends(override_get_db)):
+    return RecommendationRepository(session)
+
+
+app.dependency_overrides[get_task_repository] = override_get_task_repository
+app.dependency_overrides[get_finding_repository] = override_get_finding_repository
+app.dependency_overrides[get_recommendation_repository] = (
+    override_get_recommendation_repository
+)
+
+
+client = TestClient(app)
+
+
+def generate_findings():
+    return [
+        {
+            "doc_type": "summary",
+            "criticality_tag": ["unrestricted", 0],
+            "knowledge_type": "derived",
+            "requirement_list": [""],
+            "title_list": [
+                {
+                    "element": "finding_title_" + str(x),
+                    "source": "Trivy",
+                }
+            ],
+            "location_list": [],
+            "description_list": [
+                {
+                    "element": "finding_description_" + str(x),
+                    "source": "Trivy",
+                }
+            ],
+            "internal_rating_list": [],
+            "internal_ratingsource_list": [],
+            "cvss_rating_list": [],
+            "rule_list": [],
+            "cwe_id_list": [],
+            "cve_id_list": [],
+            "activity_list": [],
+            "first_found": "2023-08-24T08:32:33+00:00",
+            "last_found": "2023-08-25T15:48:01+00:00",
+            "report_amount": 4,
+            "content_hash": "",
+            "severity": x * 10,
+            "severity_explanation": "",
+            "priority": x * 10,
+            "priority_explanation": "",
+            "sum_id": "",
+            "prio_id": "",
+            "element_tag": "",
+        }
+        for x in range(10)
+    ]
+
+
+def test_create_get_task_integration():
+    with TestingSessionLocal() as session:
+        task_repo = TaskRepository(session=session)
+        task_repo.create_task()
+
+        response = client.get(
+            "api/v1/tasks/",
+        )
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+
+
+def test_processing_recommendation():
+    with TestingSessionLocal() as session:
+        create_task = TaskRepository(session=session)
+        task = create_task.create_task()
+
+        response = client.post(
+            "api/v1/recommendations/", json={"filter": {"task_id": task.id}}
+        )
+        assert response.status_code == 400
+        assert "detail" in response.json()
+        assert response.json()["detail"] == "Recommendation task is still processing"
+
+
+def test_recommendation_done():
+    with TestingSessionLocal() as session:
+        repo = TaskRepository(session=session)
+        task = repo.create_task()
+        repo.update_task_completed(task.id)
+
+        response = client.post(
+            "api/v1/recommendations/", json={"filter": {"task_id": task.id}}
+        )
+        assert response.status_code == 200
+        assert "items" in response.json()
+        assert len(response.json()["items"]) == 0
+
+
+def test_recommendation_with_findings_and_solution():
+    with TestingSessionLocal() as session:
+        repo = TaskRepository(session=session)
+        task = repo.create_task()
+        repo.update_task_completed(task.id)
+        findings = generate_findings()
+
+        findings_db = [
+            db_models.Finding().from_data(
+                Content.model_validate_json(json.dumps(finding))
+            )
+            for finding in findings
+        ]
+
+        findingRepo = FindingRepository(session=session)
+        for finding in findings_db:
+            finding.recommendation_task_id = task.id
+
+        findingRepo.create_findings(findings_db)
+
+        # refresh finding ids
+        for finding in findings_db:
+            session.refresh(finding)
+
+        recommendationRepo = RecommendationRepository(session=session)
+        recommendationRepo.create_recommendations(
+            list(
+                zip(
+                    [finding.id for finding in findings_db],
+                    [
+                        Finding(
+                            solution=Solution(
+                                short_description="short_description",
+                                long_description="long_description",
+                            )
+                        )
+                        for _ in findings_db
+                    ],
+                )
+            ),
+            recommendation_task_id=task.id,
+        )
+
+        response = client.post(
+            "api/v1/recommendations/", json={"filter": {"task_id": task.id}}
+        )
+        assert response.status_code == 200
+        assert "items" in response.json()
+        assert len(response.json()["items"]) == 10
+
+        assert "solution" in response.json()["items"][0]
+
+        assert "short_description" in response.json()["items"][0]["solution"]
+        assert "long_description" in response.json()["items"][0]["solution"]
+
+
+def test_recommendation_with_findings_filter():
+    with TestingSessionLocal() as session:
+        repo = TaskRepository(session=session)
+        task = repo.create_task()
+        repo.update_task_completed(task.id)
+        findings = generate_findings()
+
+        findings_db = [
+            db_models.Finding().from_data(
+                Content.model_validate_json(json.dumps(finding))
+            )
+            for finding in findings
+        ]
+        findingRepo = FindingRepository(session=session)
+        for finding in findings_db:
+            finding.recommendation_task_id = task.id
+
+        findingRepo.create_findings(findings_db)
+
+        response = client.post(
+            "api/v1/recommendations/",
+            json={
+                "filter": {
+                    "task_id": task.id,
+                },
+                "pagination": {"limit": 5, "offset": 0},
+            },
+        )
+        assert response.status_code == 200
+        assert "items" in response.json()
+        assert len(response.json()["items"]) == 5
+
+        response = client.post(
+            "api/v1/recommendations/",
+            json={
+                "filter": {
+                    "task_id": task.id,
+                    "severity": {"minValue": 10, "maxValue": 20},
+                },
+                "pagination": {"limit": 5, "offset": 0},
+            },
+        )
+        assert response.status_code == 200
+        assert "items" in response.json()
+        assert len(response.json()["items"]) == 2
+
+
+def test_aggregated_solutions_response():
+    with TestingSessionLocal() as session:
+        repo = TaskRepository(session=session)
+        task = repo.create_task()
+        repo.update_task_completed(task.id)
+        findings = generate_findings()
+        findings_db = [
+            db_models.Finding().from_data(
+                Content.model_validate_json(json.dumps(finding))
+            )
+            for finding in findings
+        ]
+        findingRepo = FindingRepository(session=session)
+        for finding in findings_db:
+            finding.recommendation_task_id = task.id
+
+        findingRepo.create_findings(findings_db)
+
+        for finding in findings_db:
+            session.refresh(finding)
+
+        recommendationRepo = RecommendationRepository(session=session)
+
+        recommendationRepo.create_aggregated_solutions(
+            input=CreateAggregatedRecommendationInput(
+                aggregated_solutions=[
+                    AggregatedSolutionInput(
+                        solution=AggregatedSolution(
+                            findings=[],
+                            solution="Aggregated Solution",
+                            metadata={"meta": "data"},
+                        ),
+                        findings_db_ids=[finding.id for finding in findings_db],
+                    )
+                ],
+                recommendation_task_id=task.id,
+            )
+        )
+        # TODO: test based on date. Doesn't work for now with sqlite
+        response = client.post(
+            "api/v1/recommendations/aggregated/",
+            json={
+                "filter": {
+                    "task_id": task.id,
+                },
+            },
+        )
+        assert response.status_code == 200
+        assert "items" in response.json()
+        assert len(response.json()["items"]) == 1
+
+        assert "solution" in response.json()["items"][0]
+        assert "findings" in response.json()["items"][0]
+
+        # check if the aggregated response has the correct solution
+        assert response.json()["items"][0]["solution"] == "Aggregated Solution"
+        assert len(response.json()["items"][0]["findings"]) == 10
+
+        # check if the format is correct
+        assert all(
+            [
+                "finding_title_" in finding["title"][0]
+                for finding in response.json()["items"][0]["findings"]
+            ]
+        )
diff --git a/src/test/test_llm.py b/src/test/test_llm.py
new file mode 100644
index 0000000..8f0426e
--- /dev/null
+++ b/src/test/test_llm.py
@@ -0,0 +1,58 @@
+# This tests the interplay of the LLMServiceStrategy, VulnerabilityReport, and Finding classes.
+
+from ai.LLM.LLMServiceStrategy import LLMServiceStrategy
+from data.Finding import Finding
+from data.VulnerabilityReport import VulnerabilityReport
+
+from .mock_llm import MockLLMService
+
+
+# Test the LLMServiceStrategy together with VulnerabilityReport and Finding
+def setup():
+    llm = LLMServiceStrategy(MockLLMService())
+    report = VulnerabilityReport()
+    oneFinding = Finding(
+        title=["title1", "title2"], descriptions=["description1", "description2"]
+    )
+    oneFinding._llm_service = llm
+    report.add_finding(oneFinding)
+
+    return llm, report, oneFinding
+
+
+llm, report, oneFinding = setup()
+
+
+def test_llm_setup():
+    oneFinding.combine_descriptions()
+    assert oneFinding.description == "combined_description"
+
+    assert report.findings[0].solution is None
+
+
+def test_solution_short_generation():
+    report.add_solution(short=True, long=False)
+    assert report.findings[0].solution.short_description == "recommendation_response"
+    assert report.findings[0].solution.long_description is None
+
+
+def test_solution_long_generation():
+    report.add_solution(short=False, long=True)
+    assert report.findings[0].solution.short_description is None
+    assert report.findings[0].solution.long_description == "recommendation_response"
+
+
+def test_solution_short_long_generation():
+    report.add_solution(short=True, long=True)
+    assert report.findings[0].solution.short_description == "recommendation_response"
+    assert report.findings[0].solution.long_description == "recommendation_response"
+
+
+def test_solution_search_terms_generation():
+    report.add_solution(short=True, long=True)
+    assert report.findings[0].solution.search_terms == "search_terms_response"
+
+
+def test_classification():
+    report.add_category()
+    assert report.findings[0].category.technology_stack.value == "JavaScript"

From ab5e9d200a1cb90a8474913de129f77e75018d85 Mon Sep 17 00:00:00 2001
From: ishwor Giri
Date: Wed, 7 Aug 2024 00:25:41 +0200
Subject: [PATCH 2/3] fix action

---
 .github/workflows/license-compliance.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/license-compliance.yml b/.github/workflows/license-compliance.yml
index d67618e..dee2ea8 100644
--- a/.github/workflows/license-compliance.yml
+++ b/.github/workflows/license-compliance.yml
@@ -25,6 +25,8 @@ jobs:
         run: |
           python -m venv venv
           . venv/bin/activate
+          pip install pipenv
+          pipenv requirements > requirements.txt
           pip install -r requirements.txt
 
       - name: Check licenses

From 253ecbe2d425f1db2c0de44802b627585139f201 Mon Sep 17 00:00:00 2001
From: ishwor Giri
Date: Wed, 7 Aug 2024 00:51:26 +0200
Subject: [PATCH 3/3] fix input

---
 src/routes/v1/upload.py | 14 ++++++++------
 src/worker/types.py     |  8 ++++++++
 src/worker/worker.py    | 19 +++++++++++--------
 3 files changed, 27 insertions(+), 14 deletions(-)
 create mode 100644 src/worker/types.py

diff --git a/src/routes/v1/upload.py b/src/routes/v1/upload.py
index 02aedd9..4777913 100644
--- a/src/routes/v1/upload.py
+++ b/src/routes/v1/upload.py
@@ -13,6 +13,7 @@
 from db.my_db import get_db
 from repository.finding import get_finding_repository
 from repository.task import TaskRepository, get_task_repository
+from worker.types import GenerateReportInput
 from worker.worker import worker
 
 router = APIRouter(prefix="/upload")
@@ -57,15 +58,16 @@ async def upload(
         find.recommendation_task_id = recommendation_task.id
         findings.append(find)
     finding_repository.create_findings(findings)
+    worker_input = GenerateReportInput(
+        recommendation_task_id=recommendation_task.id,
+        generate_long_solution=(
+            data.preferences.long_description
+            if data.preferences.long_description is not None
+            else True
+        ),
+        generate_search_terms=(
+            data.preferences.search_terms
+            if data.preferences.search_terms is not None
+            else True
+        ),
+        generate_aggregate_solutions=(
+            data.preferences.aggregated_solutions
+            if data.preferences.aggregated_solutions is not None
+            else True
+        ),
+    )
 
     celery_result = worker.send_task(
         "worker.generate_report",
-        args=[
-            recommendation_task.id,
-            data.preferences.long_description,
-            data.preferences.search_terms,
-            data.preferences.aggregated_solutions,
-        ],
+        args=[worker_input.model_dump()],
     )
 
     # update the task with the celery task id
diff --git a/src/worker/types.py b/src/worker/types.py
new file mode 100644
index 0000000..13504cb
--- /dev/null
+++ b/src/worker/types.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class GenerateReportInput(BaseModel):
+    recommendation_task_id: int
+    generate_long_solution: bool = True
+    generate_search_terms: bool = True
+    generate_aggregate_solutions: bool = True
diff --git a/src/worker/worker.py b/src/worker/worker.py
index 9994406..e309338 100644
--- a/src/worker/worker.py
+++ b/src/worker/worker.py
@@ -19,6 +19,7 @@
 from ai.LLM.Strategies.OLLAMAService import OLLAMAService
 from data.VulnerabilityReport import create_from_flama_json
 from ai.Grouping.FindingGrouper import FindingGrouper
+from worker.types import GenerateReportInput
 
 logger = logging.getLogger(__name__)
 
@@ -41,16 +42,18 @@ def error(self, exc, task_id, args, kwargs, einfo):
 
 @worker.task(name="worker.generate_report", on_failure=error)
 def generate_report(
-    recommendation_task_id: int,
-    generate_long_solution: bool = True,
-    generate_search_terms: bool = True,
-    generate_aggregate_solutions: bool = True,
+    input: dict,
 ):
+    try:
+        input = GenerateReportInput.model_validate(input)
+    except Exception as e:
+        logger.error(f"Error validating input: {e}")
+        return
 
-    # importing here so importing worker does not import all the dependencies
-    # from ai.LLM.LLMServiceStrategy import LLMServiceStrategy
-    # from ai.LLM.Strategies.OLLAMAService import OLLAMAService
-    # from data.VulnerabilityReport import create_from_flama_json
+    recommendation_task_id = input.recommendation_task_id
+    generate_long_solution = input.generate_long_solution
+    generate_search_terms = input.generate_search_terms
+    generate_aggregate_solutions = input.generate_aggregate_solutions
 
     ollama_strategy = OLLAMAService()
     llm_service = LLMServiceStrategy(ollama_strategy)
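To illustrate the worker contract introduced in [PATCH 3/3], here is a minimal, self-contained sketch of the payload round-trip. It assumes only `pydantic` and mirrors the `GenerateReportInput` model from `src/worker/types.py`; the model and method names are taken from the patch, while the surrounding scaffolding is illustrative:

```python
from pydantic import BaseModel


class GenerateReportInput(BaseModel):
    # Mirrors src/worker/types.py from PATCH 3/3.
    recommendation_task_id: int
    generate_long_solution: bool = True
    generate_search_terms: bool = True
    generate_aggregate_solutions: bool = True


# API side (src/routes/v1/upload.py): serialize the typed input into a
# plain dict before handing it to the Celery queue.
payload = GenerateReportInput(recommendation_task_id=1).model_dump()

# Worker side (src/worker/worker.py): re-validate the raw dict back into
# a typed model before using it.
restored = GenerateReportInput.model_validate(payload)
assert restored.recommendation_task_id == 1
assert restored.generate_long_solution is True
```

This keeps the Celery task signature down to a single JSON-serializable argument while preserving typed, validated access on both sides of the queue.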