Skip to content

Commit

Permalink
Add filter
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 20, 2024
1 parent d6175ac commit 8de085a
Show file tree
Hide file tree
Showing 19 changed files with 22,611 additions and 2,146 deletions.
4 changes: 1 addition & 3 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
"DEBUG": "True",
"AUTO_CREATE_SCHEMA": "True",
"ENVIRONMENT": "local",
"LOGGING_LEVEL": "DEBUG",
"OIDC_CLIENT_SECRET": "uIeFiKFazNEIbpJ3wzj0lZLLSJXefeld",
"OIDC_CLIENT_ID": "AMT"
"LOGGING_LEVEL": "DEBUG"
}
},
{
Expand Down
19 changes: 11 additions & 8 deletions amt/repositories/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TaskRegistryRepository:

def __init__(self, client: TaskRegistryAPIClient) -> None:
self.client = client
self._urn_cache: dict[TaskType, list[str]] = {}

async def fetch_tasks(self, task_type: TaskType, urns: str | Sequence[str] | None = None) -> list[dict[str, Any]]:
"""
Expand All @@ -41,10 +40,8 @@ async def _fetch_valid_urns(self, task_type: TaskType) -> list[str]:
"""
Fetches all valid URNs for the given task type.
"""
if task_type not in self._urn_cache:
content_list = await self.client.get_list_of_task(task_type)
self._urn_cache[task_type] = [content.urn for content in content_list.root]
return self._urn_cache[task_type]
content_list = await self.client.get_list_of_task(task_type)
return [content.urn for content in content_list.root]

async def _fetch_tasks_by_urns(self, task_type: TaskType, urns: Sequence[str]) -> list[dict[str, Any]]:
"""
Expand All @@ -55,12 +52,18 @@ async def _fetch_tasks_by_urns(self, task_type: TaskType, urns: Sequence[str]) -
results = await asyncio.gather(*get_tasks, return_exceptions=True)

tasks: list[dict[str, Any]] = []
for result in results:
failed_urns: list[str] = []
for urn, result in zip(urns, results, strict=True):
if isinstance(result, dict):
tasks.append(result)
elif isinstance(result, AMTNotFound):
logger.warning(f"Cannot find {task_type.value}")
else:
failed_urns.append(urn)
elif isinstance(result, Exception):
raise result

if failed_urns:
# Sonar cloud does not like displaying the failed urns, so the warning is now
# generic without specification of the urns.
logger.warning("Cannot find all tasks")

return tasks
14 changes: 7 additions & 7 deletions amt/schema/ai_act_profile.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from pydantic import field_validator
from pydantic import Field, field_validator

from amt.schema.shared import BaseModel


class AiActProfile(BaseModel):
type: str | None
open_source: str | None
publication_category: str | None
systemic_risk: str | None
transparency_obligations: str | None
role: list[str] | str | None
type: str | None = Field(default=None)
open_source: str | None = Field(default=None)
publication_category: str | None = Field(default=None)
systemic_risk: str | None = Field(default=None)
transparency_obligations: str | None = Field(default=None)
role: list[str] | str | None = Field(default=None)

@field_validator("role")
def compute_role(cls, v: list[str] | None) -> str | None:
Expand Down
5 changes: 3 additions & 2 deletions amt/schema/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ class MeasureBase(BaseModel):


class MeasureTask(MeasureBase):
state: str
state: str = Field(default="")
value: str = Field(default="")
version: str
value: str


class Measure(MeasureBase):
name: str
schema_version: str
description: str
links: list[str] = Field(default=[])
url: str
Expand Down
13 changes: 7 additions & 6 deletions amt/schema/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ class RoleEnum(Enum):
distributeur = "distributeur"


class AiActProfileItem(BaseModel):
class RequirementAiActProfile(BaseModel):
type: list[TypeEnum]
open_source: list[OpenSourceEnum] | None = None
open_source: list[OpenSourceEnum]
risk_category: list[RiskCategoryEnum]
systemic_risk: list[SystemicRiskEnum] | None = None
transparency_obligations: list[TransparencyObligation] | None = None
systemic_risk: list[SystemicRiskEnum]
transparency_obligations: list[TransparencyObligation]
role: list[RoleEnum]


Expand All @@ -54,15 +54,16 @@ class RequirementBase(BaseModel):


class RequirementTask(RequirementBase):
state: str
state: str = Field(default="")
version: str


class Requirement(RequirementBase):
name: str
description: str
schema_version: str
links: list[str] = Field(default=[])
ai_act_profile: list[AiActProfileItem]
ai_act_profile: list[RequirementAiActProfile]
always_applicable: int = Field(
...,
description="1 if requirements applies to every system, 0 if only for specific systems",
Expand Down
2 changes: 1 addition & 1 deletion amt/services/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def create(self, algorithm_new: AlgorithmNew) -> Algorithm:
role=algorithm_new.role,
)

requirements, measures = get_requirements_and_measures(ai_act_profile)
requirements, measures = await get_requirements_and_measures(ai_act_profile)

system_card = SystemCard(
name=algorithm_new.name,
Expand Down
88 changes: 74 additions & 14 deletions amt/services/task_registry.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,86 @@
import logging

from amt.schema.measure import MeasureTask
from amt.schema.requirement import RequirementTask
from amt.schema.requirement import Requirement, RequirementTask
from amt.schema.system_card import AiActProfile
from amt.services.measures import create_measures_service
from amt.services.requirements import create_requirements_service

logger = logging.getLogger(__name__)


def get_requirements(ai_act_profile: AiActProfile) -> list[RequirementTask]:
requirements: list[RequirementTask] = []
def is_requirement_applicable(requirement: Requirement, ai_act_profile: AiActProfile) -> bool:
"""
Determine if a specific requirement is applicable to a given AI Act profile.
return requirements
Evaluation Criteria:
- Always applicable requirements automatically return True.
- For the 'role' attribute, handles compound values like
"gebruiksverantwoordelijke + aanbieder".
- For the 'systemic_risk' attribute, handles the old name 'publication_category'.
- A requirement is applicable if all specified attributes match or have no
specific restrictions.
"""
if requirement.always_applicable == 1:
return True

# We can assume the ai_act_profile field always contains exactly 1 element.
requirement_profile = requirement.ai_act_profile[0]
comparison_attrs = (
"type",
"risk_category",
"type",
"open_source",
"systemic_risk",
"transparency_obligations",
)

def get_requirements_and_measures(
for attr in comparison_attrs:
requirement_attr_values = getattr(requirement_profile, attr, [])

if not requirement_attr_values:
continue

input_value = _parse_attribute_values(attr, ai_act_profile)

if not input_value & {attr_value.value for attr_value in requirement_attr_values}:
return False

return True


async def get_requirements_and_measures(
ai_act_profile: AiActProfile,
) -> tuple[
list[RequirementTask],
list[MeasureTask],
]:
# TODO (Robbert): the body of this method will be added later (another ticket)
measure_card: list[MeasureTask] = []
requirements_card: list[RequirementTask] = []

return requirements_card, measure_card
) -> tuple[list[RequirementTask], list[MeasureTask]]:
requirements_service = create_requirements_service()
measure_service = create_measures_service()
all_requirements = await requirements_service.fetch_requirements()

applicable_requirements: list[RequirementTask] = []
applicable_measures: list[MeasureTask] = []
measure_urns: set[str] = set()

for requirement in all_requirements:
if is_requirement_applicable(requirement, ai_act_profile):
applicable_requirements.append(RequirementTask(urn=requirement.urn, version=requirement.schema_version))

for measure_urn in requirement.links:
if measure_urn not in measure_urns:
measure = await measure_service.fetch_measures(measure_urn)
applicable_measures.append(MeasureTask(urn=measure_urn, version=measure[0].schema_version))
measure_urns.add(measure_urn)

return applicable_requirements, applicable_measures


def _parse_attribute_values(attr: str, ai_act_profile: AiActProfile) -> set[str]:
"""
Helper function needed in `is_requirement_applicable`, handling special case for 'role'
and 'publication_category'.
"""
if attr == "role":
return {s.strip() for s in getattr(ai_act_profile, attr, "").split("+")}
if attr == "risk_category":
return {getattr(ai_act_profile, "publication_category", "")}

return {getattr(ai_act_profile, attr, "")}
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ liccheck = "^0.9.2"
authlib = "^1.3.2"
aiosqlite = "^0.20.0"
asyncpg = "^0.30.0"
vcrpy = "^6.0.2"


[tool.poetry.group.test.dependencies]
Expand All @@ -55,6 +54,7 @@ playwright = "^1.47.0"
pytest-playwright = "^0.5.2"
pytest-httpx = "^0.33.0"
freezegun = "^1.5.1"
vcrpy = "^6.0.2"


[tool.poetry.group.dev.dependencies]
Expand Down
2 changes: 1 addition & 1 deletion tests/api/routes/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def test_post_new_algorithms_write_system_card(
role=algorithm_new.role,
)

requirements, measures = get_requirements_and_measures(ai_act_profile)
requirements, measures = await get_requirements_and_measures(ai_act_profile)

system_card = SystemCard(
name=algorithm_new.name,
Expand Down
2 changes: 2 additions & 0 deletions tests/e2e/test_create_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_e2e_create_algorithm(page: Page):
button = page.locator("#button-new-algorithm-create")
button.click()

page.wait_for_timeout(10000)
expect(page.get_by_text("My new algorithm").first).to_be_visible()


Expand Down Expand Up @@ -78,6 +79,7 @@ def test_e2e_create_algorithm_with_tasks(page: Page):
button = page.locator("#button-new-algorithm-create")
button.click()

page.wait_for_timeout(10000)
expect(page.get_by_text("My new filled algorithm").first).to_be_visible()
card_1 = page.get_by_text("Geef een korte beschrijving van het beoogde AI-systeem")
expect(card_1).to_be_visible()
Expand Down
Loading

0 comments on commit 8de085a

Please sign in to comment.