Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add requirements and measures to task registry client #360

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ FROM development AS test

COPY ./example/ ./example/
COPY ./resources/ ./resources/
COPY ./example_registry/ ./example_registry/
RUN npm run build

FROM project-base AS production
Expand All @@ -82,7 +81,6 @@ COPY --chown=root:root --chmod=755 amt /app/amt
COPY --chown=root:root --chmod=755 alembic.ini /app/alembic.ini
COPY --chown=root:root --chmod=755 prod.env /app/.env
COPY --chown=root:root --chmod=755 resources /app/resources
COPY --chown=root:root --chmod=755 example_registry /app/example_registry
COPY --chown=root:root --chmod=755 LICENSE /app/LICENSE
COPY --chown=amt:amt --chmod=755 docker-entrypoint.sh /app/docker-entrypoint.sh
USER root
Expand Down
75 changes: 43 additions & 32 deletions amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,31 @@
from amt.schema.requirement import RequirementTask
from amt.schema.system_card import SystemCard
from amt.schema.task import MovedTask
from amt.services import task_registry
from amt.services.algorithms import AlgorithmsService
from amt.services.instruments_and_requirements_state import InstrumentStateService, RequirementsStateService
from amt.services.task_registry import fetch_measures, fetch_requirements
from amt.services.measures import MeasuresService, create_measures_service
from amt.services.requirements import RequirementsService, create_requirements_service
from amt.services.tasks import TasksService

router = APIRouter()
logger = logging.getLogger(__name__)


def get_instrument_state(system_card: SystemCard) -> dict[str, Any]:
async def get_instrument_state(system_card: SystemCard) -> dict[str, Any]:
instrument_state = InstrumentStateService(system_card)
instrument_states = instrument_state.get_state_per_instrument()
instrument_states = await instrument_state.get_state_per_instrument()
return {
"instrument_states": instrument_states,
"count_0": instrument_state.get_amount_completed_instruments(),
"count_1": instrument_state.get_amount_total_instruments(),
}


def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
requirements = fetch_requirements([requirement.urn for requirement in system_card.requirements])
async def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
requirements_service = create_requirements_service()
requirements = await requirements_service.fetch_requirements(
[requirement.urn for requirement in system_card.requirements]
)
requirements_state_service = RequirementsStateService(system_card)
requirements_state = requirements_state_service.get_requirements_state(requirements)

Expand Down Expand Up @@ -106,8 +109,8 @@ async def get_tasks(
tasks_service: Annotated[TasksService, Depends(TasksService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)
tab_items = get_algorithm_details_tabs(request)
tasks_by_status = await gather_algorithm_tasks(algorithm_id, task_service=tasks_service)

Expand Down Expand Up @@ -168,8 +171,8 @@ async def get_algorithm_context(
algorithm_id: int, algorithms_service: AlgorithmsService, request: Request
) -> tuple[Algorithm, dict[str, Any]]:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)
tab_items = get_algorithm_details_tabs(request)
return algorithm, {
"last_edited": algorithm.last_edited,
Expand Down Expand Up @@ -275,8 +278,8 @@ async def get_system_card(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand Down Expand Up @@ -320,8 +323,8 @@ async def get_algorithm_inference(
request,
)

instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand All @@ -344,10 +347,12 @@ async def get_system_card_requirements(
request: Request,
algorithm_id: int,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
measures_service: Annotated[MeasuresService, Depends(create_measures_service)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)
tab_items = get_algorithm_details_tabs(request)

breadcrumbs = resolve_base_navigation_items(
Expand All @@ -359,13 +364,15 @@ async def get_system_card_requirements(
request,
)

requirements = fetch_requirements([requirement.urn for requirement in algorithm.system_card.requirements])
requirements = await requirements_service.fetch_requirements(
[requirement.urn for requirement in algorithm.system_card.requirements]
)

# Get measures that correspond to the requirements and merge them with the measuretasks
requirements_and_measures = []
for requirement in requirements:
completed_measures_count = 0
linked_measures = fetch_measures(requirement.links)
linked_measures = await measures_service.fetch_measures(requirement.links)
extended_linked_measures: list[ExtendedMeasureTask] = []
for measure in linked_measures:
measure_task = find_measure_task(algorithm.system_card, measure.urn)
Expand Down Expand Up @@ -410,16 +417,18 @@ def find_requirement_task(system_card: SystemCard, requirement_urn: str) -> Requ
return None


def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure_urn: str) -> list[RequirementTask]:
async def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure_urn: str) -> list[RequirementTask]:
requirement_mapper: dict[str, RequirementTask] = {}
for requirement_task in system_card.requirements:
requirement_mapper[requirement_task.urn] = requirement_task

requirement_tasks: list[RequirementTask] = []
measure = fetch_measures([measure_urn])
measures_service = create_measures_service()
requirements_service = create_requirements_service()
measure = await measures_service.fetch_measures(measure_urn)
for requirement_urn in measure[0].links:
# TODO: This is because measure are linked to too many requirement not applicable in our use case
if len(fetch_requirements([requirement_urn])) > 0:
if len(await requirements_service.fetch_requirements(requirement_urn)) > 0:
requirement_tasks.append(requirement_mapper[requirement_urn])

return requirement_tasks
Expand All @@ -443,7 +452,8 @@ async def get_measure(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
measure = task_registry.fetch_measures([measure_urn])
measures_service = create_measures_service()
measure = await measures_service.fetch_measures([measure_urn])
measure_task = find_measure_task(algorithm.system_card, measure_urn)

context = {
Expand All @@ -468,6 +478,7 @@ async def update_measure_value(
measure_urn: str,
measure_update: MeasureUpdate,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)

Expand All @@ -476,9 +487,9 @@ async def update_measure_value(
measure_task.value = measure_update.measure_value # pyright: ignore [reportOptionalMemberAccess]

# update for the linked requirements the state based on all it's measures
requirement_tasks = find_requirement_tasks_by_measure_urn(algorithm.system_card, measure_urn)
requirement_tasks = await find_requirement_tasks_by_measure_urn(algorithm.system_card, measure_urn)
requirement_urns = [requirement_task.urn for requirement_task in requirement_tasks]
requirements = fetch_requirements(requirement_urns)
requirements = await requirements_service.fetch_requirements(requirement_urns)

for requirement in requirements:
count_completed = 0
Expand Down Expand Up @@ -512,8 +523,8 @@ async def get_system_card_data_page(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand Down Expand Up @@ -550,8 +561,8 @@ async def get_system_card_instruments(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand Down Expand Up @@ -584,8 +595,8 @@ async def get_assessment_card(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

request.state.path_variables.update({"assessment_card": assessment_card})

Expand Down Expand Up @@ -635,8 +646,8 @@ async def get_model_card(
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
request.state.path_variables.update({"model_card": model_card})
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand Down
8 changes: 5 additions & 3 deletions amt/api/routes/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from amt.schema.algorithm import AlgorithmNew
from amt.schema.localized_value_item import LocalizedValueItem
from amt.services.algorithms import AlgorithmsService, get_template_files
from amt.services.instruments import InstrumentsService
from amt.services.instruments import InstrumentsService, create_instrument_service

router = APIRouter()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -125,7 +125,7 @@ async def get_root(
@router.get("/new")
async def get_new(
request: Request,
instrument_service: Annotated[InstrumentsService, Depends(InstrumentsService)],
instrument_service: Annotated[InstrumentsService, Depends(create_instrument_service)],
) -> HTMLResponse:
sub_menu_items = resolve_navigation_items([Navigation.ALGORITHMS_OVERVIEW], request) # pyright: ignore [reportUnusedVariable] # noqa
breadcrumbs = resolve_base_navigation_items([Navigation.ALGORITHMS_ROOT, Navigation.ALGORITHM_NEW], request)
Expand All @@ -134,8 +134,10 @@ async def get_new(

template_files = get_template_files()

instruments = await instrument_service.fetch_instruments()

context: dict[str, Any] = {
"instruments": instrument_service.fetch_instruments(),
"instruments": instruments,
"ai_act_profile": ai_act_profile,
"breadcrumbs": breadcrumbs,
"sub_menu_items": {}, # sub_menu_items disabled for now,
Expand Down
7 changes: 4 additions & 3 deletions amt/cli/check_state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from pathlib import Path
from typing import Any
Expand All @@ -7,7 +8,7 @@

from amt.schema.instrument import Instrument
from amt.schema.system_card import SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.instruments import create_instrument_service
from amt.services.instruments_and_requirements_state import all_lifecycles, get_all_next_tasks
from amt.services.storage import StorageFactory

Expand All @@ -29,8 +30,8 @@ def get_requested_instruments(all_instruments: list[Instrument], urns: list[str]
def get_tasks_by_priority(urns: list[str], system_card_path: Path) -> None:
try:
system_card = get_system_card(system_card_path)
instruments_service = InstrumentsService()
all_instruments = instruments_service.fetch_instruments()
instruments_service = create_instrument_service()
all_instruments = asyncio.run(instruments_service.fetch_instruments())
instruments = get_requested_instruments(all_instruments, urns)
next_tasks = get_all_next_tasks(instruments, system_card)

Expand Down
59 changes: 34 additions & 25 deletions amt/clients/clients.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,54 @@
import logging
from enum import StrEnum
from typing import Any

import httpx
from amt.core.exceptions import AMTInstrumentError, AMTNotFound
from amt.schema.github import RepositoryContent
from amt.schema.instrument import Instrument

logger = logging.getLogger(__name__)


class TaskRegistryAPIClient:
"""
This class interacts with the Task Registry API.
class TaskType(StrEnum):
INSTRUMENTS = "instruments"
REQUIREMENTS = "requirements"
MEASURES = "measures"


Currently it supports:
- Retrieving the list of instruments.
- Getting an instrument by URN.
class APIClient:
"""
Base API client with common HTTP functionality.
"""

base_url = "https://task-registry.apps.digilab.network"
def __init__(self, base_url: str, max_retries: int = 3, timeout: int = 5) -> None:
self.base_url = base_url
transport = httpx.AsyncHTTPTransport(retries=max_retries)
self.client = httpx.AsyncClient(timeout=timeout, transport=transport)

def __init__(self, max_retries: int = 3, timeout: int = 5) -> None:
transport = httpx.HTTPTransport(retries=max_retries)
self.client = httpx.Client(timeout=timeout, transport=transport)

def get_instrument_list(self) -> RepositoryContent:
response = self.client.get(f"{TaskRegistryAPIClient.base_url}/instruments/")
async def _make_request(self, endpoint: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
response = await self.client.get(f"{self.base_url}/{endpoint}", params=params)
Dismissed Show dismissed Hide dismissed
if response.status_code != 200:
raise AMTNotFound()
return RepositoryContent.model_validate(response.json()["entries"])
return response.json()


def get_instrument(self, urn: str, version: str = "latest") -> Instrument:
response = self.client.get(
f"{TaskRegistryAPIClient.base_url}/instruments/urn/{urn}", params={"version": version}
class TaskRegistryAPIClient(APIClient):
"""
Client for interacting with the Task Registry API.
"""

def __init__(self, max_retries: int = 3, timeout: int = 5) -> None:
super().__init__(
base_url="https://task-registry.apps.digilab.network", max_retries=max_retries, timeout=timeout
)

if response.status_code != 200:
raise AMTNotFound()
async def get_list_of_task(self, task: TaskType = TaskType.INSTRUMENTS) -> RepositoryContent:
response_data = await self._make_request(f"{task.value}/")
return RepositoryContent.model_validate(response_data["entries"])

data = response.json()
if "urn" not in data:
logger.exception("Invalid instrument fetched: key 'urn' must occur in instrument.")
async def get_task_by_urn(self, task_type: TaskType, urn: str, version: str = "latest") -> dict[str, Any]:
response_data = await self._make_request(f"{task_type.value}/urn/{urn}", params={"version": version})
if "urn" not in response_data:
logger.exception(f"Invalid task {task_type.value} fetched: key 'urn' must occur in task {task_type.value}.")
raise AMTInstrumentError()

return Instrument(**data)
return response_data
Loading
Loading