diff --git a/ebl/fragmentarium/application/fragment_repository.py b/ebl/fragmentarium/application/fragment_repository.py index 2aa66840a..8ad8fcf8e 100644 --- a/ebl/fragmentarium/application/fragment_repository.py +++ b/ebl/fragmentarium/application/fragment_repository.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Sequence, Optional +from typing import List, Sequence, Optional, Literal from ebl.common.domain.scopes import Scope from ebl.common.query.query_result import QueryResult, AfORegisterToFragmentQueryResult @@ -10,6 +10,18 @@ from ebl.transliteration.domain.museum_number import MuseumNumber from ebl.fragmentarium.domain.date import Date +UpdatableField = Literal[ + "introduction", + "text", + "genres", + "references", + "script", + "notes", + "archaeology", + "date", + "dates_in_text", +] + class FragmentRepository(ABC): @abstractmethod @@ -90,7 +102,11 @@ def query_next_and_previous_fragment( ... @abstractmethod - def update_field(self, field: str, fragment: Fragment) -> None: + def update_transliteration(self, fragment: Fragment) -> None: + ... + + @abstractmethod + def update_field(self, field: UpdatableField, fragment: Fragment) -> None: ... @abstractmethod diff --git a/ebl/fragmentarium/application/fragment_updater.py b/ebl/fragmentarium/application/fragment_updater.py index 888d179d9..6b0b464f7 100644 --- a/ebl/fragmentarium/application/fragment_updater.py +++ b/ebl/fragmentarium/application/fragment_updater.py @@ -48,7 +48,7 @@ def update_transliteration( else fragment.update_lowest_join_transliteration(transliteration, user) ) self._create_changelog(user, fragment, updated_fragment) - self._repository.update_field("transliteration", updated_fragment) + self._repository.update_transliteration(updated_fragment) return self._create_result(updated_fragment) @@ -123,7 +123,7 @@ def update_lemmatization( updated_fragment = fragment.update_lemmatization(lemmatization) self._create_changelog(user, fragment, updated_fragment) - self._repository.update_field("lemmatization", updated_fragment) + self._repository.update_field("text", updated_fragment) return self._create_result(updated_fragment) diff --git a/ebl/fragmentarium/infrastructure/mongo_fragment_repository.py b/ebl/fragmentarium/infrastructure/mongo_fragment_repository.py index 8e111269b..b2176db90 100644 --- a/ebl/fragmentarium/infrastructure/mongo_fragment_repository.py +++ b/ebl/fragmentarium/infrastructure/mongo_fragment_repository.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Sequence, Iterator +from typing import List, Optional, Sequence, Iterator, get_args import pymongo from marshmallow import EXCLUDE @@ -13,7 +13,10 @@ ) from ebl.errors import NotFoundError from ebl.fragmentarium.application.fragment_info_schema import FragmentInfoSchema -from ebl.fragmentarium.application.fragment_repository import FragmentRepository +from ebl.fragmentarium.application.fragment_repository import ( + FragmentRepository, + UpdatableField, +) from ebl.fragmentarium.application.fragment_schema import FragmentSchema, ScriptSchema from ebl.fragmentarium.application.joins_schema import JoinSchema from ebl.fragmentarium.application.line_to_vec import LineToVecEntry @@ -274,33 +277,30 @@ def query_by_transliterated_not_revised_by_other( ) return FragmentInfoSchema(many=True).load(cursor) - def update_field(self, field, fragment): - fields_to_update = { - "introduction": ("introduction",), - "lemmatization": ("text",), - "genres": ("genres",), - "references": ("references",), - "script": ("script",), - "notes": ("notes",), - "archaeology": ("archaeology",), - "transliteration": ( - "text", - "signs", - "record", - "line_to_vec", - ), - "date": ("date",), - "dates_in_text": ("dates_in_text",), - } + def update_transliteration(self, fragment: Fragment): + self._fragments.update_one( + fragment_is(fragment), + { + "$set": FragmentSchema( + only=( + "text", + "signs", + "record", + "line_to_vec", + ) + ).dump(fragment) + }, + ) - if field not in fields_to_update: + def update_field(self, field: UpdatableField, fragment: Fragment): + if field not in (valid_fields := get_args(UpdatableField)): raise ValueError( - f"Unexpected update field {field}, must be one of {','.join(fields_to_update)}" + f"Unexpected update field {field}, must be one of {valid_fields}" ) - query = FragmentSchema(only=fields_to_update[field]).dump(fragment) + query = FragmentSchema(only=(field,)).dump(fragment) self._fragments.update_one( fragment_is(fragment), - {"$set": query if query else {field: None}}, + {"$set": query or {field: None}}, ) def query_next_and_previous_folio(self, folio_name, folio_number, number): diff --git a/ebl/tests/fragmentarium/test_fragment_repository.py b/ebl/tests/fragmentarium/test_fragment_repository.py index 2286cfb2f..463c02f10 100644 --- a/ebl/tests/fragmentarium/test_fragment_repository.py +++ b/ebl/tests/fragmentarium/test_fragment_repository.py @@ -358,7 +358,7 @@ def test_update_transliteration_with_record(fragment_repository, user): TransliterationUpdate(parse_atf_lark("$ (the transliteration)")), user ) - fragment_repository.update_field("transliteration", updated_fragment) + fragment_repository.update_transliteration(updated_fragment) result = fragment_repository.query_by_museum_number(fragment.number) assert result == updated_fragment @@ -367,7 +367,7 @@ def test_update_transliteration_with_record(fragment_repository, user): def test_update_update_transliteration_not_found(fragment_repository): transliterated_fragment = TransliteratedFragmentFactory.build() with pytest.raises(NotFoundError): - fragment_repository.update_field("transliteration", transliterated_fragment) + fragment_repository.update_transliteration(transliterated_fragment) def test_update_genres(fragment_repository): @@ -410,7 +410,7 @@ def test_update_lemmatization(fragment_repository): lemmatization = Lemmatization(tokens) updated_fragment = transliterated_fragment.update_lemmatization(lemmatization) - fragment_repository.update_field("lemmatization", updated_fragment) + fragment_repository.update_field("text", updated_fragment) result = fragment_repository.query_by_museum_number(transliterated_fragment.number) assert result == updated_fragment @@ -446,10 +446,17 @@ def test_update_script(fragment_repository: FragmentRepository): assert result == updated_fragment -def test_update_update_lemmatization_not_found(fragment_repository): +def test_update_lemmatization_not_found(fragment_repository): transliterated_fragment = TransliteratedFragmentFactory.build() with pytest.raises(NotFoundError): - fragment_repository.update_field("lemmatization", transliterated_fragment) + fragment_repository.update_field("text", transliterated_fragment) + + +def test_update_invalid_field_fails(fragment_repository): + with pytest.raises(ValueError, match="Unexpected update field"): + fragment_repository.update_field( + "some invalid field name", FragmentFactory.build() + ) def test_statistics(database, fragment_repository): diff --git a/ebl/tests/fragmentarium/test_fragment_updater.py b/ebl/tests/fragmentarium/test_fragment_updater.py index 2069e365d..52f7769c1 100644 --- a/ebl/tests/fragmentarium/test_fragment_updater.py +++ b/ebl/tests/fragmentarium/test_fragment_updater.py @@ -64,7 +64,7 @@ def test_update_transliteration( ).thenReturn() ( when(fragment_repository) - .update_field("transliteration", transliterated_fragment) + .update_transliteration(transliterated_fragment) .thenReturn() ) @@ -208,9 +208,7 @@ def test_update_lemmatization( {"_id": str(number), **SCHEMA.dump(transliterated_fragment)}, {"_id": str(number), **SCHEMA.dump(lemmatized_fragment)}, ).thenReturn() - when(fragment_repository).update_field( - "lemmatization", lemmatized_fragment - ).thenReturn() + when(fragment_repository).update_field("text", lemmatized_fragment).thenReturn() result = fragment_updater.update_lemmatization(number, lemmatization, user) assert result == (injected_fragment, False)