Merge pull request #55 from DigitalProductInnovationAndDevelopment/test/llm

add test for api and llm
ishworgiri1999 authored Aug 7, 2024
2 parents 1c69ac9 + 253ecbe commit fb3bfdd
Showing 21 changed files with 685 additions and 147 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/license-compliance.yml
@@ -25,6 +25,8 @@ jobs:
run: |
  python -m venv venv
  . venv/bin/activate
  pip install pipenv
  pipenv requirements > requirements.txt
  pip install -r requirements.txt
- name: Check licenses
4 changes: 4 additions & 0 deletions README.md
@@ -151,6 +151,10 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file
This project was created in the context of TUM's practical course [Digital Product Innovation and Development](https://www.fortiss.org/karriere/digital-product-innovation-and-development) by fortiss in the summer semester 2024.
The task was suggested by Siemens, who also supervised the project.

## FAQ

Please refer to the [FAQ](documentation/FAQ.md).

## Contact

To contact fortiss or Siemens, please refer to their official websites.
4 changes: 3 additions & 1 deletion documentation/00 - introduction.md
@@ -67,7 +67,9 @@ For more detailed information about the API routes, refer to the [API Routes](ap

- **[Prerequisites](01%20-%20prerequisites):** Environment setup and dependencies.
- **[Installation](02%20-%20installation):** Step-by-step installation guide.
- **[Usage](04%20-%20usage):** Instructions on how to use the system.
- **[Usage](03%20-%20usage):** Instructions on how to use the system.

- **[Testing](04-testing):** Provides an explanation of the testing strategy and rules.

---

68 changes: 68 additions & 0 deletions documentation/04-testing.md
@@ -0,0 +1,68 @@
# Testing

We have a basic testing setup in place that covers database insertion and API retrieval.

The tests are located inside the `src/tests` directory and can be executed using the following command:

```bash
pipenv run pytest
# or
pytest
```

We also have a GitHub test workflow, defined in `test.yml`, that runs on every pull request and on every merge to the main branch to ensure that the tests pass.
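
A minimal sketch of what such a workflow could look like is shown below; the repository's actual `test.yml` may differ, and the Python version and step layout here are assumptions.

```yaml
# Hypothetical sketch of test.yml; the actual workflow in the repository may differ.
name: Tests

on:
  pull_request:
  push:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11" # assumption
      - name: Install dependencies
        run: |
          pip install pipenv
          pipenv install --dev
      - name: Run tests
        run: pipenv run pytest
```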

## Adding new Tests

New tests should be added to the `src/tests` folder.

The files must be prefixed with `test_` for pytest to recognize them.

Test functions must likewise be prefixed with `test_`.

For example:

```python
def test_create_get_task_integration():
    assert 1 + 1 == 2
```

We override the repository dependencies to use a separate test database, ensuring that tests do not interfere with your actual database. The dependencies can be overridden as shown below.

```python
from fastapi import Depends
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

import db.models as db_models
from app import app
from repository.task import TaskRepository, get_task_repository

# Use an in-memory SQLite database for tests.
SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Create all tables in the test database.
db_models.BaseModel.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


def override_get_task_repository(session: Session = Depends(override_get_db)):
    return TaskRepository(session)


app.dependency_overrides[get_task_repository] = override_get_task_repository
```
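
With these overrides in place, tests can exercise the API through FastAPI's `TestClient`. The snippet below is a minimal sketch; the route path is an assumption for illustration, not an actual endpoint of this project.

```python
from fastapi.testclient import TestClient

from app import app

client = TestClient(app)


def test_list_tasks_returns_ok():
    # Hypothetical route path; substitute a real route from the app.
    response = client.get("/api/v1/tasks")
    assert response.status_code == 200
```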

For a comprehensive guide on testing with FastAPI, please refer to the [FASTAPI Testing](https://fastapi.tiangolo.com/tutorial/testing/) documentation.

Note:
Integration testing is always challenging, especially when dealing with LLMs and queues. However, the API endpoints have been thoroughly tested to ensure accurate responses.
22 changes: 15 additions & 7 deletions documentation/FAQ.md
@@ -1,25 +1,29 @@
# Frequently Asked Questions

A compiled list of FAQs that may come in handy.
A compiled list of frequently asked questions that may come in handy.

## Import ot found?
## Import not found?

If you are getting an "import not found" error, it is likely because the base folder is always `src`. Always run the program from the `src` folder or use the commands inside the Makefile.
If you are encountering an "import not found" error, it is likely because the base folder is always `src`. Make sure to run the program from the `src` folder or use the commands inside the Makefile.

If you have something in `src/lib` and want to use it, import it as follows:
If you have something in `src/lib` and want to use it, import it as shown below:

```python
import lib # or
from lib import ...
```

## Why use a POST call to retrieve recommendations and aggregated recommendations?

For these calls, we require a JSON body of at least `{}`. This allows us to handle nested filters such as `task_id` and `severity` inside the filter object; query parameters on a GET call would be less suitable for this purpose. However, one can rename the path, for example changing the `get-` prefix, to make it more conventional.
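
As an illustration, such a request could look like the sketch below; the endpoint path is an assumption, while the nested `task_id` and `severity` filters follow the description above.

```python
import requests

# Hypothetical endpoint path and filter shape; adjust to the actual deployment.
response = requests.post(
    "http://localhost:8000/api/v1/recommendations/get-recommendations",
    json={"filter": {"task_id": 1, "severity": ["HIGH"]}},
)
print(response.json())
```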

## Why are env.docker and .env different?

If you are only running the program using Docker, then you only need to worry about `.env.docker`.
If you are running the program exclusively using Docker, then you only need to concern yourself with `.env.docker`.

As the addresses tend to be different in a Docker environment compared to a local environment, you need different values to resolve the addresses.
Since the addresses can differ between a Docker environment and a local environment, you need different values to resolve the addresses.

For example, if you have your program outside Docker (locally) and want to access a database, you may use:
For example, if your program is outside Docker (locally) and you want to access a database, you may use:

```
POSTGRES_SERVER=localhost
@@ -50,3 +54,7 @@ We have a predefined structure that input must adhere to called `Content`. You c

Inside the db model called `Findings`, there is a method `from_data` which can be modified to adapt to the changes.
`VulnerabilityReport` also has `create_from_flama_json`, which must be adjusted accordingly so that the generation side keeps working.

## Does it make sense to mock the LLM?

While we don't strive for accuracy, it still makes sense to mock LLM methods to verify that the code that looks up and interacts with the LLM class methods works correctly. Nevertheless, it remains difficult to derive meaningful test outputs from prompts alone.
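
As a sketch of what such a mock could look like (the `generate_recommendation` method name is an assumption, not the project's actual LLM API):

```python
from unittest.mock import MagicMock

# Hypothetical stand-in for the project's LLM service class.
llm_service = MagicMock()
llm_service.generate_recommendation.return_value = "Upgrade the affected package."

# Code under test can now call the mocked method and get a deterministic answer.
assert llm_service.generate_recommendation("prompt") == "Upgrade the affected package."
```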
25 changes: 0 additions & 25 deletions requirements.txt

This file was deleted.

4 changes: 0 additions & 4 deletions run

This file was deleted.

4 changes: 0 additions & 4 deletions run.sh

This file was deleted.

2 changes: 1 addition & 1 deletion src/data/VulnerabilityReport.py
@@ -30,7 +30,7 @@ def set_llm_service(self, llm_service: "LLMServiceStrategy"):
        finding.llm_service = llm_service
    return self

def add_finding(self, finding):
def add_finding(self, finding: Finding):
    self.findings.append(finding)

def get_findings(self):
2 changes: 0 additions & 2 deletions src/data/repository/finding.py

This file was deleted.

14 changes: 8 additions & 6 deletions src/db/models.py
@@ -102,23 +102,25 @@ class Finding(BaseModel):

    def from_data(self, data: Content):
        self.cve_id_list = (
            [x.dict() for x in data.cve_id_list] if data.cve_id_list else []
            [x.model_dump() for x in data.cve_id_list] if data.cve_id_list else []
        )
        self.description_list = (
            [x.dict() for x in data.description_list] if data.description_list else []
            [x.model_dump() for x in data.description_list]
            if data.description_list
            else []
        )
        self.title_list = [x.dict() for x in data.title_list]
        self.title_list = [x.model_dump() for x in data.title_list]
        self.locations_list = (
            [x.dict() for x in data.location_list] if data.location_list else []
            [x.model_dump() for x in data.location_list] if data.location_list else []
        )
        self.raw_data = data.dict()
        self.raw_data = data.model_dump()
        self.severity = data.severity
        self.priority = data.priority
        self.report_amount = data.report_amount
        return self

    def __repr__(self):
        return f"<Finding {self.finding}>"
        return f"<Finding {self.title_list}>"


class TaskStatus(PyEnum):
9 changes: 4 additions & 5 deletions src/repository/finding.py
@@ -76,9 +76,9 @@ def get_findings_by_task_id_and_filter(
)

    total = query.count()
    findings = query.all()
    if pagination:
        query = query.offset(pagination.offset).limit(pagination.limit)
    findings = query.all()

    return findings, total

@@ -95,10 +95,9 @@ def get_findings_count_by_task_id(self, task_id: int) -> int:

    return count

def create_findings(
    self, findings: list[db_models.Finding]
) -> list[db_models.Finding]:
    self.session.bulk_save_objects(findings)
def create_findings(self, findings: list[db_models.Finding]):
    self.session.add_all(findings)

    self.session.commit()


2 changes: 1 addition & 1 deletion src/repository/recommendation.py
@@ -63,7 +63,7 @@ def create_recommendations(
else "No long description available"
),
meta=f.solution.metadata if f.solution.metadata else {},
search_terms=f.solution.search_terms if f.solution.search_terms else [],
search_terms=f.solution.search_terms if f.solution.search_terms else "",
finding_id=finding_id,
recommendation_task_id=recommendation_task_id,
category=(f.category.model_dump_json() if f.category else None),
14 changes: 7 additions & 7 deletions src/routes/v1/recommendations.py
@@ -1,19 +1,19 @@
import datetime
from typing import Annotated, Optional

from fastapi import Body, Depends, HTTPException, Response
from fastapi import Body, Depends, HTTPException
from fastapi.routing import APIRouter
from sqlalchemy import Date, cast
from sqlalchemy.orm import Session

import data.apischema as apischema
import db.models as db_models
from data.AggregatedSolution import AggregatedSolution
from db.my_db import get_db
from dto.finding import db_finding_to_response_item
from repository.finding import get_finding_repository
from repository.recommendation import (RecommendationRepository,
                                        get_recommendation_repository)
from repository.recommendation import (
    RecommendationRepository,
    get_recommendation_repository,
)
from repository.task import TaskRepository, get_task_repository
from repository.types import GetFindingsByFilterInput

@@ -102,8 +102,8 @@ def aggregated_solutions(
    task = None
    if request.filter and request.filter.task_id:
        task = task_repository.get_task_by_id(request.filter.task_id)
        task = task_repository.get_task_by_date(today)

    else:
        task = task_repository.get_task_by_date(today)
    if not task:
        raise HTTPException(
            status_code=404,
14 changes: 8 additions & 6 deletions src/routes/v1/upload.py
@@ -13,6 +13,7 @@
from db.my_db import get_db
from repository.finding import get_finding_repository
from repository.task import TaskRepository, get_task_repository
from worker.types import GenerateReportInput
from worker.worker import worker

router = APIRouter(prefix="/upload")
@@ -57,15 +58,16 @@ async def upload(
        find.recommendation_task_id = recommendation_task.id
        findings.append(find)
    finding_repository.create_findings(findings)
    worker_input = GenerateReportInput(
        recommendation_task_id=recommendation_task.id,
        generate_long_solution=data.preferences.long_description or True,
        generate_search_terms=data.preferences.search_terms or True,
        generate_aggregate_solutions=data.preferences.aggregated_solutions or True,
    )

    celery_result = worker.send_task(
        "worker.generate_report",
        args=[
            recommendation_task.id,
            data.preferences.long_description,
            data.preferences.search_terms,
            data.preferences.aggregated_solutions,
        ],
        args=[worker_input.model_dump()],
    )

    # update the task with the celery task id