Skip to content

Commit

Permalink
feat(api): refactor create_variant and add integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
salemsd committed Dec 6, 2024
1 parent 1256468 commit 2c01407
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 57 deletions.
17 changes: 3 additions & 14 deletions src/antares/model/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def create_study_api(
url = f"{base_url}/studies?name={study_name}&version={version}"
response = wrapper.post(url)
study_id = response.json()

study_settings = _returns_study_settings(base_url, study_id, wrapper, False, settings)

except APIError as e:
Expand Down Expand Up @@ -199,10 +198,8 @@ def create_variant_api(api_config: APIconf, study_id: str, variant_name: str) ->
"""
factory = ServiceFactory(api_config, study_id)
api_service = factory.create_study_service()
variant_id = api_service.create_variant(variant_name)
variant = read_study_api(api_config, variant_id)

return variant
return api_service.create_variant(variant_name)


class Study:
Expand Down Expand Up @@ -327,15 +324,7 @@ def create_variant(self, variant_name: str) -> "Study":
variant_name: the name of the new variant
Returns: The variant in the form of a Study object
"""
variant_id = self._study_service.create_variant(variant_name)
config = self._study_service.config

if isinstance(config, APIconf):
variant = read_study_api(config, variant_id)
else:
raise TypeError("Expected config to be of type APIconf")

return variant
return self._study_service.create_variant(variant_name)


def _verify_study_already_exists(study_directory: Path) -> None:
Expand Down Expand Up @@ -386,4 +375,4 @@ def _create_correlation_ini_files(local_settings: StudySettingsLocal, study_dire
season_correlation=getattr(local_settings.time_series_parameters, field).season_correlation,
),
)
ini_file.write_ini_file()
ini_file.write_ini_file()
14 changes: 10 additions & 4 deletions src/antares/service/api_services/study_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
from typing import Optional
from typing import TYPE_CHECKING, Optional

import antares.model.study as study

from antares.api_conf.api_conf import APIconf
from antares.api_conf.request_wrapper import RequestWrapper
Expand All @@ -31,6 +33,9 @@
from antares.model.settings.time_series import TimeSeriesParameters
from antares.service.base_services import BaseStudyService

if TYPE_CHECKING:
from antares.model.study import Study


def _returns_study_settings(
base_url: str, study_id: str, wrapper: RequestWrapper, update: bool, settings: Optional[StudySettings]
Expand Down Expand Up @@ -106,10 +111,11 @@ def delete(self, children: bool) -> None:
except APIError as e:
raise StudyDeletionError(self.study_id, e.message) from e

def create_variant(self, variant_name: str) -> str:
def create_variant(self, variant_name: str) -> "Study":
url = f"{self._base_url}/studies/{self.study_id}/variants?name={variant_name}"
try:
response = self._wrapper.post(url)
return response.text
variant_id = response.json()
return study.read_study_api(self.config, variant_id)
except APIError as e:
raise StudyVariantCreationError(self.study_id, e.message) from e
raise StudyVariantCreationError(self.study_id, e.message) from e
11 changes: 7 additions & 4 deletions src/antares/service/base_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from abc import ABC, abstractmethod
from types import MappingProxyType
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional

import pandas as pd

Expand All @@ -31,6 +31,9 @@
from antares.model.st_storage import STStorage, STStorageProperties
from antares.model.thermal import ThermalCluster, ThermalClusterMatrixName, ThermalClusterProperties

if TYPE_CHECKING:
from antares.model.study import Study


class BaseAreaService(ABC):
@abstractmethod
Expand Down Expand Up @@ -508,13 +511,13 @@ def delete(self, children: bool) -> None:
pass

@abstractmethod
def create_variant(self, variant_name: str) -> str:
def create_variant(self, variant_name: str) -> "Study":
"""
Creates a new variant for the study
Args:
variant_name: the name of the new variant
Returns: id of the variant
Returns: the variant
"""
pass

Expand Down Expand Up @@ -561,4 +564,4 @@ def update_st_storage_properties(

@abstractmethod
def read_st_storages(self, area_id: str) -> List[STStorage]:
pass
pass
9 changes: 6 additions & 3 deletions src/antares/service/local_services/study_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
#
# This file is part of the Antares project.

from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

from antares.config.local_configuration import LocalConfiguration
from antares.model.binding_constraint import BindingConstraint
from antares.model.settings.study_settings import StudySettings
from antares.service.base_services import BaseStudyService

if TYPE_CHECKING:
from antares.model.study import Study


class StudyLocalService(BaseStudyService):
def __init__(self, config: LocalConfiguration, study_name: str, **kwargs: Any) -> None:
Expand All @@ -41,5 +44,5 @@ def delete_binding_constraint(self, constraint: BindingConstraint) -> None:
def delete(self, children: bool) -> None:
raise NotImplementedError

def create_variant(self, variant_name: str) -> str:
raise NotImplementedError
def create_variant(self, variant_name: str) -> "Study":
raise NotImplementedError
38 changes: 6 additions & 32 deletions tests/antares/services/api_services/test_study_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_create_variant_success(self):
base_url = "https://antares.com/api/v1"
url = f"{base_url}/studies/{self.study_id}/variants?name={variant_name}"
variant_id = "variant_id"
mocker.post(url, text=variant_id, status_code=201)
mocker.post(url, json=variant_id, status_code=201)

variant_url = f"{base_url}/studies/{variant_id}"
mocker.get(variant_url, json={"id": variant_id, "name": variant_name, "version": "880"}, status_code=200)
Expand All @@ -270,33 +270,14 @@ def test_create_variant_success(self):
mocker.get(areas_url, json={}, status_code=200)

variant = self.study.create_variant(variant_name)
variant_from_api = create_variant_api(self.api, self.study_id, variant_name)

assert isinstance(variant, Study)
assert isinstance(variant_from_api, Study)
assert variant.name == variant_name
assert variant_from_api.name == variant_name
assert variant.service.study_id == variant_id

def test_create_variant_api_success(self):
variant_name = "variant_test"
with requests_mock.Mocker() as mocker:
base_url = "https://antares.com/api/v1"
url = f"{base_url}/studies/{self.study_id}/variants?name={variant_name}"
variant_id = "variant_id"
mocker.post(url, text=variant_id, status_code=201)

variant_url = f"{base_url}/studies/{variant_id}"
mocker.get(variant_url, json={"id": variant_id, "name": variant_name, "version": "880"}, status_code=200)

config_urls = re.compile(f"{base_url}/studies/{variant_id}/config/.*")
mocker.get(config_urls, json={}, status_code=200)

areas_url = f"{base_url}/studies/{variant_id}/areas?ui=true"
mocker.get(areas_url, json={}, status_code=200)

variant = create_variant_api(self.api, self.study_id, variant_name)

assert isinstance(variant, Study)
assert variant.name == variant_name
assert variant.service.study_id == variant_id
assert variant_from_api.service.study_id == variant_id

def test_create_variant_fails(self):
variant_name = "variant_test"
Expand All @@ -309,12 +290,5 @@ def test_create_variant_fails(self):
with pytest.raises(StudyVariantCreationError, match=error_message):
self.study.create_variant(variant_name)

def test_create_variant_api_fails(self):
variant_name = "variant_test"
with requests_mock.Mocker() as mocker:
base_url = "https://antares.com/api/v1"
url = f"{base_url}/studies/{self.study_id}/variants?name={variant_name}"
error_message = "Variant creation failed"
mocker.post(url, json={"description": error_message}, status_code=404)
with pytest.raises(StudyVariantCreationError, match=error_message):
create_variant_api(self.api, self.study_id, variant_name)
create_variant_api(self.api, self.study_id, variant_name)
14 changes: 14 additions & 0 deletions tests/integration/test_web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,17 @@ def test_creation_lifecycle(self, antares_web: AntaresWebDesktop):
empty_settings = StudySettings()
new_study.update_settings(empty_settings)
assert old_settings == new_study.get_settings()

def test_create_variant(self, antares_web: AntaresWebDesktop):
api_config = APIconf(api_host=antares_web.url, token="", verify=False)
study = create_study_api("antares-craft-test", "880", api_config)

variant_name = "variant_test"
variant = study.create_variant(variant_name)

assert variant.name == variant_name
assert variant.service.study_id != study.service.study_id
assert variant.get_settings() == study.get_settings()
assert list(variant.get_areas().keys()) == list(study.get_areas().keys())
assert list(variant.get_links().keys()) == list(study.get_links().keys())
assert list(variant.get_binding_constraints().keys()) == list(study.get_binding_constraints().keys())

0 comments on commit 2c01407

Please sign in to comment.