Skip to content

Commit

Permalink
Add filter
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 19, 2024
1 parent 6a4824b commit db65852
Show file tree
Hide file tree
Showing 12 changed files with 321 additions and 53 deletions.
4 changes: 1 addition & 3 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
"DEBUG": "True",
"AUTO_CREATE_SCHEMA": "True",
"ENVIRONMENT": "local",
"LOGGING_LEVEL": "DEBUG",
"OIDC_CLIENT_SECRET": "uIeFiKFazNEIbpJ3wzj0lZLLSJXefeld",
"OIDC_CLIENT_ID": "AMT"
"LOGGING_LEVEL": "DEBUG"
}
},
{
Expand Down
17 changes: 9 additions & 8 deletions amt/repositories/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TaskRegistryRepository:

def __init__(self, client: TaskRegistryAPIClient) -> None:
self.client = client
self._urn_cache: dict[TaskType, list[str]] = {}

async def fetch_tasks(self, task_type: TaskType, urns: str | Sequence[str] | None = None) -> list[dict[str, Any]]:
"""
Expand All @@ -41,10 +40,8 @@ async def _fetch_valid_urns(self, task_type: TaskType) -> list[str]:
"""
Fetches all valid URNs for the given task type.
"""
if task_type not in self._urn_cache:
content_list = await self.client.get_list_of_task(task_type)
self._urn_cache[task_type] = [content.urn for content in content_list.root]
return self._urn_cache[task_type]
content_list = await self.client.get_list_of_task(task_type)
return [content.urn for content in content_list.root]

async def _fetch_tasks_by_urns(self, task_type: TaskType, urns: Sequence[str]) -> list[dict[str, Any]]:
"""
Expand All @@ -55,12 +52,16 @@ async def _fetch_tasks_by_urns(self, task_type: TaskType, urns: Sequence[str]) -
results = await asyncio.gather(*get_tasks, return_exceptions=True)

tasks: list[dict[str, Any]] = []
for result in results:
failed_urns: list[str] = []
for urn, result in zip(urns, results, strict=True):
if isinstance(result, dict):
tasks.append(result)
elif isinstance(result, AMTNotFound):
logger.warning(f"Cannot find {task_type.value}")
else:
failed_urns.append(urn)
elif isinstance(result, Exception):
raise result

if failed_urns:
logger.warning(f"Cannot find tasks for URNs: {failed_urns}")

return tasks
14 changes: 7 additions & 7 deletions amt/schema/ai_act_profile.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from pydantic import field_validator
from pydantic import Field, field_validator

from amt.schema.shared import BaseModel


class AiActProfile(BaseModel):
type: str | None
open_source: str | None
publication_category: str | None
systemic_risk: str | None
transparency_obligations: str | None
role: list[str] | str | None
type: str | None = Field(default=None)
open_source: str | None = Field(default=None)
risk_category: str | None = Field(default=None)
systemic_risk: str | None = Field(default=None)
transparency_obligations: str | None = Field(default=None)
role: list[str] | str | None = Field(default=None)

@field_validator("role")
def compute_role(cls, v: list[str] | None) -> str | None:
Expand Down
5 changes: 3 additions & 2 deletions amt/schema/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ class MeasureBase(BaseModel):


class MeasureTask(MeasureBase):
state: str
state: str = Field(default="")
value: str = Field(default="")
version: str
value: str


class Measure(MeasureBase):
name: str
schema_version: str
description: str
links: list[str] = Field(default=[])
url: str
Expand Down
13 changes: 7 additions & 6 deletions amt/schema/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ class RoleEnum(Enum):
distributeur = "distributeur"


class AiActProfileItem(BaseModel):
class RequirementAiActProfile(BaseModel):
type: list[TypeEnum]
open_source: list[OpenSourceEnum] | None = None
open_source: list[OpenSourceEnum]
risk_category: list[RiskCategoryEnum]
systemic_risk: list[SystemicRiskEnum] | None = None
transparency_obligations: list[TransparencyObligation] | None = None
systemic_risk: list[SystemicRiskEnum]
transparency_obligations: list[TransparencyObligation]
role: list[RoleEnum]


Expand All @@ -54,15 +54,16 @@ class RequirementBase(BaseModel):


class RequirementTask(RequirementBase):
state: str
state: str = Field(default="")
version: str


class Requirement(RequirementBase):
name: str
description: str
schema_version: str
links: list[str] = Field(default=[])
ai_act_profile: list[AiActProfileItem]
ai_act_profile: list[RequirementAiActProfile]
always_applicable: int = Field(
...,
description="1 if requirements applies to every system, 0 if only for specific systems",
Expand Down
4 changes: 2 additions & 2 deletions amt/services/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ async def create(self, algorithm_new: AlgorithmNew) -> Algorithm:
ai_act_profile = AiActProfile(
type=algorithm_new.type,
open_source=algorithm_new.open_source,
publication_category=algorithm_new.publication_category,
risk_category=algorithm_new.publication_category,
systemic_risk=algorithm_new.systemic_risk,
transparency_obligations=algorithm_new.transparency_obligations,
role=algorithm_new.role,
)

requirements, measures = get_requirements_and_measures(ai_act_profile)
requirements, measures = await get_requirements_and_measures(ai_act_profile)

system_card = SystemCard(
name=algorithm_new.name,
Expand Down
59 changes: 45 additions & 14 deletions amt/services/task_registry.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,57 @@
import logging

from amt.schema.measure import MeasureTask
from amt.schema.requirement import RequirementTask
from amt.schema.requirement import Requirement, RequirementTask
from amt.schema.system_card import AiActProfile
from amt.services.measures import create_measures_service
from amt.services.requirements import create_requirements_service

logger = logging.getLogger(__name__)


def get_requirements(ai_act_profile: AiActProfile) -> list[RequirementTask]:
requirements: list[RequirementTask] = []
def is_requirement_applicable(requirement: Requirement, ai_act_profile: AiActProfile) -> bool:
if requirement.always_applicable == 1:
return True

return requirements
# We can assume the ai_act_profile field always is of length 1.
requirement_profile = requirement.ai_act_profile[0]
comparison_attrs = ["type", "risk_category", "type", "open_source", "systemic_risk", "transparency_obligations"]

for attr in comparison_attrs:
requirement_attr_values = getattr(requirement_profile, attr, [])

def get_requirements_and_measures(
if not requirement_attr_values:
continue

# In the system card the field role has values "gebruiksverantwoordelijke", "aanbieder" and
# "gebruiksverantwoordelijke + aanbieder", so we need to split the latter into a list of two
# strings.
raw_input_value = getattr(ai_act_profile, attr)
input_value = {raw_input_value} if attr != "role" else {s.strip() for s in raw_input_value.split("+")}

if not input_value & {attr_value.value for attr_value in requirement_attr_values}:
return False

return True


async def get_requirements_and_measures(
ai_act_profile: AiActProfile,
) -> tuple[
list[RequirementTask],
list[MeasureTask],
]:
# TODO (Robbert): the body of this method will be added later (another ticket)
measure_card: list[MeasureTask] = []
requirements_card: list[RequirementTask] = []

return requirements_card, measure_card
) -> tuple[list[RequirementTask], list[MeasureTask]]:
requirements_service = create_requirements_service()
measure_service = create_measures_service()
all_requirements = await requirements_service.fetch_requirements()

applicable_requirements: list[RequirementTask] = []
applicable_measures: dict[str, MeasureTask] = {}

for requirement in all_requirements:
if is_requirement_applicable(requirement, ai_act_profile):
applicable_requirements.append(RequirementTask(urn=requirement.urn, version=requirement.schema_version))

for measure_urn in requirement.links:
if measure_urn not in applicable_measures:
measure = await measure_service.fetch_measures(measure_urn)
applicable_measures[measure_urn] = MeasureTask(urn=measure_urn, version=measure[0].schema_version)

return applicable_requirements, [*applicable_measures.values()]
4 changes: 2 additions & 2 deletions tests/api/routes/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,13 @@ async def test_post_new_algorithms_write_system_card(
ai_act_profile = AiActProfile(
type=algorithm_new.type,
open_source=algorithm_new.open_source,
publication_category=algorithm_new.publication_category,
risk_category=algorithm_new.publication_category,
systemic_risk=algorithm_new.systemic_risk,
transparency_obligations=algorithm_new.transparency_obligations,
role=algorithm_new.role,
)

requirements, measures = get_requirements_and_measures(ai_act_profile)
requirements, measures = await get_requirements_and_measures(ai_act_profile)

system_card = SystemCard(
name=algorithm_new.name,
Expand Down
18 changes: 9 additions & 9 deletions tests/schema/test_schema_ai_act_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ def test_ai_act_profile_schema_create_new():
algorithm_new = AiActProfile(
type="AI-systeem",
open_source="open-source",
publication_category="hoog-risico AI",
risk_category="hoog-risico AI",
systemic_risk="systeemrisico",
transparency_obligations="transparantieverplichtingen",
role="aanbieder",
)
assert algorithm_new.type == "AI-systeem"
assert algorithm_new.open_source == "open-source"
assert algorithm_new.publication_category == "hoog-risico AI"
assert algorithm_new.risk_category == "hoog-risico AI"
assert algorithm_new.systemic_risk == "systeemrisico"
assert algorithm_new.transparency_obligations == "transparantieverplichtingen"
assert algorithm_new.role == "aanbieder"
Expand All @@ -23,14 +23,14 @@ def test_ai_act_profile_schema_create_new_no_role():
algorithm_new = AiActProfile(
type="AI-systeem",
open_source="open-source",
publication_category="hoog-risico AI",
risk_category="hoog-risico AI",
systemic_risk="systeemrisico",
transparency_obligations="transparantieverplichtingen",
role=None,
)
assert algorithm_new.type == "AI-systeem"
assert algorithm_new.open_source == "open-source"
assert algorithm_new.publication_category == "hoog-risico AI"
assert algorithm_new.risk_category == "hoog-risico AI"
assert algorithm_new.systemic_risk == "systeemrisico"
assert algorithm_new.transparency_obligations == "transparantieverplichtingen"
assert algorithm_new.role is None
Expand All @@ -40,14 +40,14 @@ def test_ai_act_profile_schema_create_new_empty_role_list():
algorithm_new = AiActProfile(
type="AI-systeem",
open_source="open-source",
publication_category="hoog-risico AI",
risk_category="hoog-risico AI",
systemic_risk="systeemrisico",
transparency_obligations="transparantieverplichtingen",
role=[],
)
assert algorithm_new.type == "AI-systeem"
assert algorithm_new.open_source == "open-source"
assert algorithm_new.publication_category == "hoog-risico AI"
assert algorithm_new.risk_category == "hoog-risico AI"
assert algorithm_new.systemic_risk == "systeemrisico"
assert algorithm_new.transparency_obligations == "transparantieverplichtingen"
assert algorithm_new.role is None
Expand All @@ -57,14 +57,14 @@ def test_ai_act_profile_schema_create_new_double_role():
algorithm_new = AiActProfile(
type="AI-systeem",
open_source="open-source",
publication_category="hoog-risico AI",
risk_category="hoog-risico AI",
systemic_risk="systeemrisico",
transparency_obligations="transparantieverplichtingen",
role=["aanbieder", "gebruiksverantwoordelijke"],
)
assert algorithm_new.type == "AI-systeem"
assert algorithm_new.open_source == "open-source"
assert algorithm_new.publication_category == "hoog-risico AI"
assert algorithm_new.risk_category == "hoog-risico AI"
assert algorithm_new.systemic_risk == "systeemrisico"
assert algorithm_new.transparency_obligations == "transparantieverplichtingen"
assert algorithm_new.role == "aanbieder + gebruiksverantwoordelijke"
Expand All @@ -75,7 +75,7 @@ def test_ai_act_profile_schema_create_new_too_many_roles():
AiActProfile(
type="AI-systeem",
open_source="open-source",
publication_category="hoog-risico AI",
risk_category="hoog-risico AI",
systemic_risk="systeemrisico",
transparency_obligations="transparantieverplichtingen",
role=["aanbieder", "gebruiksverantwoordelijke", "I am too much of a role"],
Expand Down
2 changes: 2 additions & 0 deletions tests/services/test_measures_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ def sample_data() -> list[dict[str, Any]]:
{
"urn": "urn:measure:1",
"name": "Measure 1",
"schema_version": "1.1.0",
"description": "description_1",
"links": [],
"url": "url_1",
},
{
"urn": "urn:measure:2",
"name": "Measure 2",
"schema_version": "1.1.0",
"description": "description_2",
"url": "url_2",
"language": "nl",
Expand Down
2 changes: 2 additions & 0 deletions tests/services/test_requirements_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def sample_data() -> list[dict[str, Any]]:
"urn": "urn:requirement:1",
"name": "Requirement 1",
"description": "description_1",
"schema_version": "1.1.0",
"links": [],
"always_applicable": 1,
"ai_act_profile": [],
Expand All @@ -33,6 +34,7 @@ def sample_data() -> list[dict[str, Any]]:
"urn": "urn:requirement:2",
"name": "Requirement 2",
"description": "description_2",
"schema_version": "1.1.0",
"url": "url_2",
"always_applicable": 1,
"ai_act_profile": [],
Expand Down
Loading

0 comments on commit db65852

Please sign in to comment.