From 3e3e9913c751818ee7e29a7be9030f1a0f929e35 Mon Sep 17 00:00:00 2001 From: martinbaerwolff Date: Fri, 1 Sep 2023 16:23:24 +0200 Subject: [PATCH 1/5] Implement CLI for exporting counts --- OTAnalytics/application/config.py | 5 +- OTAnalytics/plugin_ui/cli.py | 71 +++++++++++++++++-- .../customtkinter_gui/dummy_viewmodel.py | 3 +- .../customtkinter_gui/frame_analysis.py | 4 +- OTAnalytics/plugin_ui/main_application.py | 8 +++ tests/OTAnalytics/plugin_ui/test_cli.py | 32 ++++++++- 6 files changed, 110 insertions(+), 13 deletions(-) diff --git a/OTAnalytics/application/config.py b/OTAnalytics/application/config.py index db555a2bc..5645c5662 100644 --- a/OTAnalytics/application/config.py +++ b/OTAnalytics/application/config.py @@ -5,10 +5,13 @@ """The log save directory.""" GEOMETRY_CACHE_SIZE: int = 20000 -DEFAULT_EVENTLIST_SAVE_NAME: str = "events" +DEFAULT_EVENTLIST_FILE_STEM: str = "events" DEFAULT_EVENTLIST_FILE_TYPE: str = "otevents" +DEFAULT_COUNTS_FILE_STEM: str = "counts" +DEFAULT_COUNTS_FILE_TYPE: str = "csv" DEFAULT_TRACK_FILE_TYPE: str = "ottrk" DEFAULT_SECTIONS_FILE_TYPE: str = "otflow" +DEFAULT_COUNTING_INTERVAL_IN_MINUTES: int = 15 OS: str = platform.system() """OS OTAnalyitcs is currently running on""" diff --git a/OTAnalytics/plugin_ui/cli.py b/OTAnalytics/plugin_ui/cli.py index 37ac10841..cab15729e 100644 --- a/OTAnalytics/plugin_ui/cli.py +++ b/OTAnalytics/plugin_ui/cli.py @@ -3,15 +3,24 @@ from pathlib import Path from typing import Iterable +from OTAnalytics.application.analysis.traffic_counting import ExportCounts +from OTAnalytics.application.analysis.traffic_counting_specification import ( + CountingSpecificationDto, +) from OTAnalytics.application.config import ( + DEFAULT_COUNTING_INTERVAL_IN_MINUTES, + DEFAULT_COUNTS_FILE_STEM, + DEFAULT_COUNTS_FILE_TYPE, + DEFAULT_EVENTLIST_FILE_STEM, DEFAULT_EVENTLIST_FILE_TYPE, - DEFAULT_EVENTLIST_SAVE_NAME, DEFAULT_SECTIONS_FILE_TYPE, DEFAULT_TRACK_FILE_TYPE, ) from OTAnalytics.application.datastore import EventListParser, FlowParser, TrackParser from OTAnalytics.application.logger import logger +from OTAnalytics.application.state import TracksMetadata from OTAnalytics.application.use_cases.create_events import CreateEvents +from OTAnalytics.application.use_cases.flow_repository import AddFlow from OTAnalytics.application.use_cases.section_repository import AddSection from OTAnalytics.application.use_cases.track_repository import ( AddAllTracks, @@ -123,7 +132,9 @@ def __init__( event_list_parser: EventListParser, event_repository: EventRepository, add_section: AddSection, + add_flow: AddFlow, create_events: CreateEvents, + export_counts: ExportCounts, add_all_tracks: AddAllTracks, clear_all_tracks: ClearAllTracks, progressbar: ProgressbarBuilder, @@ -136,7 +147,9 @@ def __init__( self._event_list_parser = event_list_parser self._event_repository = event_repository self._add_section = add_section + self._add_flow = add_flow self._create_events = create_events + self._export_counts = export_counts self._add_all_tracks = add_all_tracks self._clear_all_tracks = clear_all_tracks self._progressbar = progressbar @@ -149,7 +162,7 @@ def start(self) -> None: sections, flows = self._parse_flows(sections_file) - self._run_analysis(ottrk_files, sections) + self._run_analysis(ottrk_files, sections, flows) def _parse_flows(self, flow_file: Path) -> tuple[Iterable[Section], Iterable[Flow]]: return self._flow_parser.parse(flow_file) @@ -159,18 +172,24 @@ def _add_sections(self, sections: Iterable[Section]) -> None: for section in sections: self._add_section.add(section) + def _add_flows(self, flows: Iterable[Flow]) -> None: + """Add flows to flow repository.""" + for flow in flows: + self._add_flow.add(flow) + def _parse_tracks(self, track_files: list[Path]) -> None: for track_file in self._progressbar(track_files, "Parsed track files", "files"): tracks = self._track_parser.parse(track_file) self._add_all_tracks(tracks) def _run_analysis( - self, ottrk_files: set[Path], sections: Iterable[Section] + self, ottrk_files: set[Path], sections: Iterable[Section], flows: Iterable[Flow] ) -> None: """Run analysis.""" self._clear_all_tracks() self._event_repository.clear() self._add_sections(sections) + self._add_flows(flows) ottrk_files_sorted: list[Path] = sorted( ottrk_files, key=lambda file: str(file).lower() ) @@ -180,11 +199,15 @@ def _run_analysis( self._create_events() logger().info("Event list created.") - save_path = self._determine_eventlist_save_path(ottrk_files_sorted[0]) + event_list_output_file = self._determine_eventlist_save_path( + ottrk_files_sorted[0] + ) self._event_list_parser.serialize( - self._event_repository.get_all(), sections, save_path + self._event_repository.get_all(), sections, event_list_output_file ) - logger().info(f"Event list saved at '{save_path}'") + logger().info(f"Event list saved at '{event_list_output_file}'") + + self._do_export_counts(event_list_output_file) def _determine_eventlist_save_path(self, track_file: Path) -> Path: """Determine save path of eventlist. @@ -202,7 +225,7 @@ def _determine_eventlist_save_path(self, track_file: Path) -> Path: eventlist_file_name = self.cli_args.eventlist_filename if eventlist_file_name == "": return track_file.with_name( - f"{DEFAULT_EVENTLIST_SAVE_NAME}.{DEFAULT_EVENTLIST_FILE_TYPE}" + f"{DEFAULT_EVENTLIST_FILE_STEM}.{DEFAULT_EVENTLIST_FILE_TYPE}" ) return track_file.with_name( @@ -285,3 +308,37 @@ def _get_sections_file(file: str) -> Path: ) return sections_file + + def _do_export_counts(self, event_list_output_file: Path) -> None: + logger().info("Create counts ...") + tracks_metadata = TracksMetadata(self._add_all_tracks._track_repository) + tracks_metadata.notify_tracks([]) + start = tracks_metadata.first_detection_occurrence + end = tracks_metadata.last_detection_occurrence + modes = tracks_metadata.classifications + if start is None: + raise ValueError("start is None but has to be defined for exporting counts") + if end is None: + raise ValueError("end is None but has to be defined for exporting counts") + if modes is None: + raise ValueError("modes is None but has to be defined for exporting counts") + interval: int = DEFAULT_COUNTING_INTERVAL_IN_MINUTES + if event_list_output_file.stem == DEFAULT_EVENTLIST_FILE_STEM: + output_file_stem = DEFAULT_COUNTS_FILE_STEM + else: + output_file_stem = ( + f"{event_list_output_file.stem}_{DEFAULT_COUNTS_FILE_STEM}" + ) + output_file = event_list_output_file.with_stem(output_file_stem).with_suffix( + f".{DEFAULT_COUNTS_FILE_TYPE}" + ) + counting_specificaation = CountingSpecificationDto( + start=start, + end=end, + modes=list(modes), + interval_in_minutes=interval, + output_file=str(output_file), + output_format="CSV", + ) + self._export_counts.export(specification=counting_specificaation) + logger().info(f"Counts saved at {str(output_file)}") diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py index 800263958..fa4c56cc2 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py @@ -34,6 +34,7 @@ MultipleSectionsSelected, OTAnalyticsApplication, ) +from OTAnalytics.application.config import DEFAULT_COUNTING_INTERVAL_IN_MINUTES from OTAnalytics.application.datastore import FlowParser, NoSectionsToSave from OTAnalytics.application.logger import logger from OTAnalytics.application.use_cases.config import MissingDate @@ -1386,7 +1387,7 @@ def export_counts(self) -> None: end = self._application._tracks_metadata.last_detection_occurrence modes = list(self._application._tracks_metadata.classifications) default_values: dict = { - INTERVAL: 15, + INTERVAL: DEFAULT_COUNTING_INTERVAL_IN_MINUTES, START: start, END: end, EXPORT_FORMAT: default_format, diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/frame_analysis.py b/OTAnalytics/plugin_ui/customtkinter_gui/frame_analysis.py index 5eff5be48..e6d9ed947 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/frame_analysis.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/frame_analysis.py @@ -5,8 +5,8 @@ from OTAnalytics.adapter_ui.view_model import ViewModel from OTAnalytics.application.config import ( + DEFAULT_EVENTLIST_FILE_STEM, DEFAULT_EVENTLIST_FILE_TYPE, - DEFAULT_EVENTLIST_SAVE_NAME, ) from OTAnalytics.application.logger import logger from OTAnalytics.plugin_ui.customtkinter_gui.constants import PADX, PADY, STICKY @@ -92,7 +92,7 @@ def _save_eventlist(self) -> None: title="Save event list file as", filetypes=[("events file", "*.otevents")], defaultextension=".otevents", - initialfile=f"{DEFAULT_EVENTLIST_SAVE_NAME}.{DEFAULT_EVENTLIST_FILE_TYPE}", + initialfile=f"{DEFAULT_EVENTLIST_FILE_STEM}.{DEFAULT_EVENTLIST_FILE_TYPE}", ) if not file: return diff --git a/OTAnalytics/plugin_ui/main_application.py b/OTAnalytics/plugin_ui/main_application.py index 6fff67b96..58737a10b 100644 --- a/OTAnalytics/plugin_ui/main_application.py +++ b/OTAnalytics/plugin_ui/main_application.py @@ -51,6 +51,7 @@ AddEvents, ClearEventRepository, ) +from OTAnalytics.application.use_cases.flow_repository import AddFlow from OTAnalytics.application.use_cases.generate_flows import ( ArrowFlowNameGenerator, CrossProductFlowGenerator, @@ -286,11 +287,13 @@ def start_gui(self) -> None: def start_cli(self, cli_args: CliArguments) -> None: track_repository = self._create_track_repository() section_repository = self._create_section_repository() + flow_repository = self._create_flow_repository() track_parser = self._create_track_parser(track_repository) flow_parser = self._create_flow_parser() event_list_parser = self._create_event_list_parser() event_repository = self._create_event_repository() add_section = AddSection(section_repository) + add_flow = AddFlow(flow_repository) add_events = AddEvents(event_repository) get_all_tracks = GetAllTracks(track_repository) create_events = self._create_use_case_create_events( @@ -298,6 +301,9 @@ def start_cli(self, cli_args: CliArguments) -> None: ) add_all_tracks = AddAllTracks(track_repository) clear_all_tracks = ClearAllTracks(track_repository) + export_counts = self._create_export_counts( + event_repository, flow_repository, track_repository + ) OTAnalyticsCli( cli_args, track_parser=track_parser, @@ -306,7 +312,9 @@ def start_cli(self, cli_args: CliArguments) -> None: event_repository=event_repository, add_section=add_section, create_events=create_events, + export_counts=export_counts, add_all_tracks=add_all_tracks, + add_flow=add_flow, clear_all_tracks=clear_all_tracks, progressbar=TqdmBuilder(), ).start() diff --git a/tests/OTAnalytics/plugin_ui/test_cli.py b/tests/OTAnalytics/plugin_ui/test_cli.py index 04b6bdeb0..0e50b8511 100644 --- a/tests/OTAnalytics/plugin_ui/test_cli.py +++ b/tests/OTAnalytics/plugin_ui/test_cli.py @@ -7,9 +7,16 @@ import pytest from OTAnalytics.adapter_ui.default_values import TRACK_LENGTH_LIMIT +from OTAnalytics.application.analysis.traffic_counting import ( + ExportCounts, + ExportTrafficCounting, + FilterBySectionEnterEvent, + SimpleRoadUserAssigner, + SimpleTaggerFactory, +) from OTAnalytics.application.config import ( + DEFAULT_EVENTLIST_FILE_STEM, DEFAULT_EVENTLIST_FILE_TYPE, - DEFAULT_EVENTLIST_SAVE_NAME, DEFAULT_TRACK_FILE_TYPE, ) from OTAnalytics.application.datastore import EventListParser, FlowParser, TrackParser @@ -23,6 +30,7 @@ AddEvents, ClearEventRepository, ) +from OTAnalytics.application.use_cases.flow_repository import AddFlow, FlowRepository from OTAnalytics.application.use_cases.section_repository import AddSection from OTAnalytics.application.use_cases.track_repository import ( AddAllTracks, @@ -41,6 +49,10 @@ from OTAnalytics.plugin_intersect_parallelization.multiprocessing import ( MultiprocessingIntersectParallelization, ) +from OTAnalytics.plugin_parser.export import ( + FillZerosExporterFactory, + SimpleExporterFactory, +) from OTAnalytics.plugin_parser.otvision_parser import ( OtEventListParser, OtFlowParser, @@ -124,7 +136,9 @@ class TestOTAnalyticsCli: EVENT_LIST_PARSER: str = "event_list_parser" EVENT_REPOSITORY: str = "event_repository" ADD_SECTION: str = "add_section" + ADD_FLOW: str = "add_flow" CREATE_EVENTS: str = "create_events" + EXPORT_COUNTS: str = "export_counts" ADD_ALL_TRACKS: str = "add_all_tracks" CLEAR_ALL_TRACKS: str = "clear_all_tracks" PROGRESSBAR: str = "progressbar" @@ -137,7 +151,9 @@ def mock_cli_dependencies(self) -> dict[str, Any]: self.EVENT_LIST_PARSER: Mock(spec=EventListParser), self.EVENT_REPOSITORY: Mock(spec=EventRepository), self.ADD_SECTION: Mock(spec=AddSection), + self.ADD_FLOW: Mock(spec=AddFlow), self.CREATE_EVENTS: Mock(spec=CreateEvents), + self.EXPORT_COUNTS: Mock(spec=ExportCounts), self.ADD_ALL_TRACKS: Mock(spec=AddAllTracks), self.CLEAR_ALL_TRACKS: Mock(spec=ClearAllTracks), self.PROGRESSBAR: Mock(spec=NoProgressbarBuilder), @@ -148,6 +164,7 @@ def cli_dependencies(self) -> dict[str, Any]: track_repository = TrackRepository() section_repository = SectionRepository() event_repository = EventRepository() + flow_repository = FlowRepository() add_events = AddEvents(event_repository) get_all_tracks = GetAllTracks(track_repository) @@ -172,6 +189,13 @@ def cli_dependencies(self) -> dict[str, Any]: create_events = CreateEvents( clear_event_repository, create_intersection_events, create_scene_events ) + export_counts = ExportTrafficCounting( + event_repository, + flow_repository, + FilterBySectionEnterEvent(SimpleRoadUserAssigner()), + SimpleTaggerFactory(track_repository), + FillZerosExporterFactory(SimpleExporterFactory()), + ) return { self.TRACK_PARSER: OttrkParser( CalculateTrackClassificationByMaxConfidence(), @@ -182,7 +206,9 @@ def cli_dependencies(self) -> dict[str, Any]: self.EVENT_LIST_PARSER: OtEventListParser(), self.EVENT_REPOSITORY: event_repository, self.ADD_SECTION: AddSection(section_repository), + self.ADD_FLOW: AddFlow(flow_repository), self.CREATE_EVENTS: create_events, + self.EXPORT_COUNTS: export_counts, self.ADD_ALL_TRACKS: add_all_tracks, self.CLEAR_ALL_TRACKS: clear_all_tracks, self.PROGRESSBAR: NoProgressbarBuilder(), @@ -202,7 +228,9 @@ def test_init(self, mock_cli_dependencies: dict[str, Any]) -> None: assert cli._flow_parser == mock_cli_dependencies[self.FLOW_PARSER] assert cli._event_list_parser == mock_cli_dependencies[self.EVENT_LIST_PARSER] assert cli._add_section == mock_cli_dependencies[self.ADD_SECTION] + assert cli._add_flow == mock_cli_dependencies[self.ADD_FLOW] assert cli._create_events == mock_cli_dependencies[self.CREATE_EVENTS] + assert cli._export_counts == mock_cli_dependencies[self.EXPORT_COUNTS] assert cli._add_all_tracks == mock_cli_dependencies[self.ADD_ALL_TRACKS] assert cli._clear_all_tracks == mock_cli_dependencies[self.CLEAR_ALL_TRACKS] assert cli._progressbar == mock_cli_dependencies[self.PROGRESSBAR] @@ -310,7 +338,7 @@ def test_parse_sections_file_wrong_filetype(self, test_data_tmp_dir: Path) -> No @pytest.mark.parametrize( "eventlist_filename,expected_filename", - [("my_events", "my_events"), ("", DEFAULT_EVENTLIST_SAVE_NAME)], + [("my_events", "my_events"), ("", DEFAULT_EVENTLIST_FILE_STEM)], ) def test_determine_eventlist_save_path( self, From 5716338337ee8d33d9137629d422f57ec78d2bad Mon Sep 17 00:00:00 2001 From: martinbaerwolff Date: Fri, 1 Sep 2023 18:12:13 +0200 Subject: [PATCH 2/5] Export counts with full date and time --- .../application/analysis/traffic_counting.py | 4 ++-- .../analysis/test_traffic_counting.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/OTAnalytics/application/analysis/traffic_counting.py b/OTAnalytics/application/analysis/traffic_counting.py index fa5509a9b..fb15d1733 100644 --- a/OTAnalytics/application/analysis/traffic_counting.py +++ b/OTAnalytics/application/analysis/traffic_counting.py @@ -169,8 +169,8 @@ def create_mode_tag(tag: str) -> Tag: def create_timeslot_tag(start_of_time_slot: datetime, interval: timedelta) -> Tag: end_of_time_slot = start_of_time_slot + interval - serialized_start = start_of_time_slot.strftime("%H:%M") - serialized_end = end_of_time_slot.strftime("%H:%M") + serialized_start = start_of_time_slot.strftime(r"%Y-%m-%d %H:%M:%S") + serialized_end = end_of_time_slot.strftime(r"%Y-%m-%d %H:%M:%S") return MultiTag( frozenset( [ diff --git a/tests/OTAnalytics/application/analysis/test_traffic_counting.py b/tests/OTAnalytics/application/analysis/test_traffic_counting.py index aa8400b39..e20dc1065 100644 --- a/tests/OTAnalytics/application/analysis/test_traffic_counting.py +++ b/tests/OTAnalytics/application/analysis/test_traffic_counting.py @@ -145,7 +145,7 @@ def create_event( road_user_type="car", hostname="my_hostname", occurrence=datetime( - 2020, 1, 1, 0, minute, second=real_seconds, tzinfo=timezone.utc + 2000, 1, 1, 0, minute, second=real_seconds, tzinfo=timezone.utc ), frame_number=1, section_id=section, @@ -361,8 +361,8 @@ def create_tagging_test_cases(self) -> list[tuple[RoadUserAssignment, Tag]]: first_result = MultiTag( frozenset( [ - SingleTag(level=LEVEL_START_TIME, id="00:00"), - SingleTag(level=LEVEL_END_TIME, id="00:01"), + SingleTag(level=LEVEL_START_TIME, id="2000-01-01 00:00:00"), + SingleTag(level=LEVEL_END_TIME, id="2000-01-01 00:01:00"), ] ) ) @@ -377,8 +377,8 @@ def create_tagging_test_cases(self) -> list[tuple[RoadUserAssignment, Tag]]: second_result = MultiTag( frozenset( [ - SingleTag(level=LEVEL_START_TIME, id="00:00"), - SingleTag(level=LEVEL_END_TIME, id="00:01"), + SingleTag(level=LEVEL_START_TIME, id="2000-01-01 00:00:00"), + SingleTag(level=LEVEL_END_TIME, id="2000-01-01 00:01:00"), ] ) ) @@ -393,8 +393,8 @@ def create_tagging_test_cases(self) -> list[tuple[RoadUserAssignment, Tag]]: third_result = MultiTag( frozenset( [ - SingleTag(level=LEVEL_START_TIME, id="00:01"), - SingleTag(level=LEVEL_END_TIME, id="00:02"), + SingleTag(level=LEVEL_START_TIME, id="2000-01-01 00:01:00"), + SingleTag(level=LEVEL_END_TIME, id="2000-01-01 00:02:00"), ] ) ) @@ -409,8 +409,8 @@ def create_tagging_test_cases(self) -> list[tuple[RoadUserAssignment, Tag]]: forth_result = MultiTag( frozenset( [ - SingleTag(level=LEVEL_START_TIME, id="00:02"), - SingleTag(level=LEVEL_END_TIME, id="00:03"), + SingleTag(level=LEVEL_START_TIME, id="2000-01-01 00:02:00"), + SingleTag(level=LEVEL_END_TIME, id="2000-01-01 00:03:00"), ] ) ) From 0183813a5a453fe89a802c097b1e0a8197bd88cb Mon Sep 17 00:00:00 2001 From: martinbaerwolff Date: Fri, 1 Sep 2023 18:22:23 +0200 Subject: [PATCH 3/5] Set column order in dataframe for export of counts --- OTAnalytics/plugin_parser/export.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/OTAnalytics/plugin_parser/export.py b/OTAnalytics/plugin_parser/export.py index 53f7c97e8..e889ac133 100644 --- a/OTAnalytics/plugin_parser/export.py +++ b/OTAnalytics/plugin_parser/export.py @@ -5,6 +5,10 @@ from pandas import DataFrame from OTAnalytics.application.analysis.traffic_counting import ( + LEVEL_CLASSIFICATION, + LEVEL_END_TIME, + LEVEL_FLOW, + LEVEL_START_TIME, Count, Exporter, ExporterFactory, @@ -28,8 +32,23 @@ def __init__(self, output_file: str) -> None: def export(self, counts: Count) -> None: logger().info(f"Exporting counts {counts} to {self._output_file}") dataframe = self.__create_data_frame(counts) + dataframe = self._set_column_order(dataframe) dataframe.to_csv(self.__create_path(), index=False) + def _set_column_order(self, dataframe: DataFrame) -> DataFrame: + desired_columns_order = [ + LEVEL_START_TIME, + LEVEL_END_TIME, + LEVEL_CLASSIFICATION, + LEVEL_FLOW, + ] + dataframe = dataframe[ + desired_columns_order + + [col for col in dataframe.columns if col not in desired_columns_order] + ] + + return dataframe + def __create_data_frame(self, counts: Count) -> DataFrame: transformed = counts.to_dict() indexed: list[dict] = [] From 361451bcffc9df23021af7dc2a6a8f8cb46e2f81 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Tue, 5 Sep 2023 15:43:35 +0200 Subject: [PATCH 4/5] Fix classes in tracks metadata not being filled in cli --- .../application/use_cases/track_repository.py | 21 +++++- OTAnalytics/domain/track.py | 8 +++ OTAnalytics/plugin_ui/cli.py | 11 +-- OTAnalytics/plugin_ui/main_application.py | 3 + .../use_cases/test_track_repository.py | 12 +++- tests/OTAnalytics/domain/test_track.py | 68 +++++++++---------- tests/OTAnalytics/plugin_ui/test_cli.py | 5 ++ 7 files changed, 88 insertions(+), 40 deletions(-) diff --git a/OTAnalytics/application/use_cases/track_repository.py b/OTAnalytics/application/use_cases/track_repository.py index a0b95d2e6..898886e98 100644 --- a/OTAnalytics/application/use_cases/track_repository.py +++ b/OTAnalytics/application/use_cases/track_repository.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Iterable -from OTAnalytics.domain.track import Track, TrackRepository +from OTAnalytics.domain.track import Track, TrackId, TrackRepository class GetAllTracks: @@ -18,6 +18,25 @@ def __call__(self) -> list[Track]: return self._track_repository.get_all() +class GetAllTrackIds: + """Get all track ids from the track repository. + + Args: + track_repository (TrackRepository): the track repository to get the ids from. + """ + + def __init__(self, track_repository: TrackRepository) -> None: + self._track_repository = track_repository + + def __call__(self) -> Iterable[TrackId]: + """Get all track ids from the track repository. + + Returns: + Iterable[TrackId]: the track ids. + """ + return self._track_repository.get_all_ids() + + class AddAllTracks: """Add tracks to the track repository. diff --git a/OTAnalytics/domain/track.py b/OTAnalytics/domain/track.py index ea7e9069f..662e8e93f 100644 --- a/OTAnalytics/domain/track.py +++ b/OTAnalytics/domain/track.py @@ -419,6 +419,14 @@ def get_all(self) -> list[Track]: """ return list(self._tracks.values()) + def get_all_ids(self) -> Iterable[TrackId]: + """Get all track ids in this repository. + + Returns: + Iterable[TrackId]: the track ids. + """ + return self._tracks.keys() + def clear(self) -> None: """ Clear the repository and inform the observers about the empty repository. diff --git a/OTAnalytics/plugin_ui/cli.py b/OTAnalytics/plugin_ui/cli.py index cab15729e..a59bfaf26 100644 --- a/OTAnalytics/plugin_ui/cli.py +++ b/OTAnalytics/plugin_ui/cli.py @@ -25,6 +25,7 @@ from OTAnalytics.application.use_cases.track_repository import ( AddAllTracks, ClearAllTracks, + GetAllTrackIds, ) from OTAnalytics.domain.event import EventRepository from OTAnalytics.domain.flow import Flow @@ -136,6 +137,7 @@ def __init__( create_events: CreateEvents, export_counts: ExportCounts, add_all_tracks: AddAllTracks, + get_all_track_ids: GetAllTrackIds, clear_all_tracks: ClearAllTracks, progressbar: ProgressbarBuilder, ) -> None: @@ -151,6 +153,7 @@ def __init__( self._create_events = create_events self._export_counts = export_counts self._add_all_tracks = add_all_tracks + self._get_all_track_ids = get_all_track_ids self._clear_all_tracks = clear_all_tracks self._progressbar = progressbar @@ -312,7 +315,7 @@ def _get_sections_file(file: str) -> Path: def _do_export_counts(self, event_list_output_file: Path) -> None: logger().info("Create counts ...") tracks_metadata = TracksMetadata(self._add_all_tracks._track_repository) - tracks_metadata.notify_tracks([]) + tracks_metadata.notify_tracks(list(self._get_all_track_ids())) start = tracks_metadata.first_detection_occurrence end = tracks_metadata.last_detection_occurrence modes = tracks_metadata.classifications @@ -332,7 +335,7 @@ def _do_export_counts(self, event_list_output_file: Path) -> None: output_file = event_list_output_file.with_stem(output_file_stem).with_suffix( f".{DEFAULT_COUNTS_FILE_TYPE}" ) - counting_specificaation = CountingSpecificationDto( + counting_specification = CountingSpecificationDto( start=start, end=end, modes=list(modes), @@ -340,5 +343,5 @@ def _do_export_counts(self, event_list_output_file: Path) -> None: output_file=str(output_file), output_format="CSV", ) - self._export_counts.export(specification=counting_specificaation) - logger().info(f"Counts saved at {str(output_file)}") + self._export_counts.export(specification=counting_specification) + logger().info(f"Counts saved at {output_file}") diff --git a/OTAnalytics/plugin_ui/main_application.py b/OTAnalytics/plugin_ui/main_application.py index 58737a10b..a63b778ee 100644 --- a/OTAnalytics/plugin_ui/main_application.py +++ b/OTAnalytics/plugin_ui/main_application.py @@ -75,6 +75,7 @@ AddAllTracks, ClearAllTracks, GetAllTrackFiles, + GetAllTrackIds, GetAllTracks, ) from OTAnalytics.domain.event import EventRepository, SceneEventBuilder @@ -296,6 +297,7 @@ def start_cli(self, cli_args: CliArguments) -> None: add_flow = AddFlow(flow_repository) add_events = AddEvents(event_repository) get_all_tracks = GetAllTracks(track_repository) + get_all_track_ids = GetAllTrackIds(track_repository) create_events = self._create_use_case_create_events( section_repository, event_repository, get_all_tracks, add_events ) @@ -314,6 +316,7 @@ def start_cli(self, cli_args: CliArguments) -> None: create_events=create_events, export_counts=export_counts, add_all_tracks=add_all_tracks, + get_all_track_ids=get_all_track_ids, add_flow=add_flow, clear_all_tracks=clear_all_tracks, progressbar=TqdmBuilder(), diff --git a/tests/OTAnalytics/application/use_cases/test_track_repository.py b/tests/OTAnalytics/application/use_cases/test_track_repository.py index b04abed4e..0feea9ff8 100644 --- a/tests/OTAnalytics/application/use_cases/test_track_repository.py +++ b/tests/OTAnalytics/application/use_cases/test_track_repository.py @@ -7,9 +7,10 @@ AddAllTracks, ClearAllTracks, GetAllTrackFiles, + GetAllTrackIds, GetAllTracks, ) -from OTAnalytics.domain.track import Detection, Track, TrackRepository +from OTAnalytics.domain.track import Detection, Track, TrackId, TrackRepository @pytest.fixture @@ -21,6 +22,7 @@ def tracks() -> list[Mock]: def track_repository(tracks: list[Mock]) -> Mock: repository = Mock(spec=TrackRepository) repository.get_all.return_value = tracks + repository.get_all_ids.return_value = {TrackId(1), TrackId(2)} return repository @@ -32,6 +34,14 @@ def test_get_all_tracks(self, track_repository: Mock, tracks: list[Mock]) -> Non track_repository.get_all.assert_called_once() +class TestGetAllTrackIds: + def test_get_all_tracks(self, track_repository: Mock) -> None: + get_all_track_ids = GetAllTrackIds(track_repository) + result_tracks = get_all_track_ids() + assert result_tracks == {TrackId(1), TrackId(2)} + track_repository.get_all_ids.assert_called_once() + + class TestAddAllTracks: def test_add_all_tracks(self, track_repository: Mock, tracks: list[Track]) -> None: add_all_tracks = AddAllTracks(track_repository) diff --git a/tests/OTAnalytics/domain/test_track.py b/tests/OTAnalytics/domain/test_track.py index c27b683a9..66c20c26d 100644 --- a/tests/OTAnalytics/domain/test_track.py +++ b/tests/OTAnalytics/domain/test_track.py @@ -228,18 +228,27 @@ def test_calculate(self) -> None: class TestTrackRepository: - def test_add(self) -> None: - track_id = TrackId(1) - track = Mock() - track.id = track_id + @pytest.fixture + def track_1(self) -> Mock: + track = Mock(spec=Track) + track.id = TrackId(1) + return track + + @pytest.fixture + def track_2(self) -> Mock: + track = Mock(spec=Track) + track.id = TrackId(2) + return track + + def test_add(self, track_1: Mock) -> None: observer = Mock(spec=TrackListObserver) repository = TrackRepository() repository.register_tracks_observer(observer) - repository.add(track) + repository.add(track_1) - assert track in repository.get_all() - observer.notify_tracks.assert_called_with([track_id]) + assert track_1 in repository.get_all() + observer.notify_tracks.assert_called_with([track_1.id]) def test_add_nothing(self) -> None: observer = Mock(spec=TrackListObserver) @@ -251,50 +260,41 @@ def test_add_nothing(self) -> None: assert 0 == len(repository.get_all()) observer.notify_tracks.assert_not_called() - def test_add_all(self) -> None: - first_id = TrackId(1) - second_id = TrackId(2) - first_track = Mock() - first_track.id = first_id - second_track = Mock() - second_track.id = second_id + def test_add_all(self, track_1: Mock, track_2: Mock) -> None: observer = Mock(spec=TrackListObserver) repository = TrackRepository() repository.register_tracks_observer(observer) - repository.add_all([first_track, second_track]) + repository.add_all([track_1, track_2]) - assert first_track in repository.get_all() - assert second_track in repository.get_all() - observer.notify_tracks.assert_called_with([first_id, second_id]) + assert track_1 in repository.get_all() + assert track_2 in repository.get_all() + observer.notify_tracks.assert_called_with([track_1.id, track_2.id]) - def test_get_by_id(self) -> None: - first_track = Mock() - first_track.id.return_value = TrackId(1) - second_track = Mock() + def test_get_by_id(self, track_1: Mock, track_2: Mock) -> None: repository = TrackRepository() - repository.add_all([first_track, second_track]) + repository.add_all([track_1, track_2]) - returned = repository.get_for(first_track.id) + returned = repository.get_for(track_1.id) - assert returned == first_track + assert returned == track_1 - def test_clear(self) -> None: - first_id = TrackId(1) - second_id = TrackId(2) - first_track = Mock() - first_track.id = first_id - second_track = Mock() - second_track.id = second_id + def test_clear(self, track_1: Mock, track_2: Mock) -> None: observer = Mock(spec=TrackListObserver) repository = TrackRepository() repository.register_tracks_observer(observer) - repository.add_all([first_track, second_track]) + repository.add_all([track_1, track_2]) repository.clear() assert not list(repository.get_all()) assert observer.notify_tracks.call_args_list == [ - call([first_id, second_id]), + call([track_1.id, track_2.id]), call([]), ] + + def test_get_all_ids(self, track_1: Mock, track_2: Mock) -> None: + repository = TrackRepository() + repository.add_all([track_1, track_2]) + ids = repository.get_all_ids() + assert set(ids) == {track_1.id, track_2.id} diff --git a/tests/OTAnalytics/plugin_ui/test_cli.py b/tests/OTAnalytics/plugin_ui/test_cli.py index 0e50b8511..9ab6f2e39 100644 --- a/tests/OTAnalytics/plugin_ui/test_cli.py +++ b/tests/OTAnalytics/plugin_ui/test_cli.py @@ -35,6 +35,7 @@ from OTAnalytics.application.use_cases.track_repository import ( AddAllTracks, ClearAllTracks, + GetAllTrackIds, GetAllTracks, ) from OTAnalytics.domain.event import EventRepository, SceneEventBuilder @@ -140,6 +141,7 @@ class TestOTAnalyticsCli: CREATE_EVENTS: str = "create_events" EXPORT_COUNTS: str = "export_counts" ADD_ALL_TRACKS: str = "add_all_tracks" + GET_ALL_TRACK_IDS: str = "get_all_track_ids" CLEAR_ALL_TRACKS: str = "clear_all_tracks" PROGRESSBAR: str = "progressbar" @@ -155,6 +157,7 @@ def mock_cli_dependencies(self) -> dict[str, Any]: self.CREATE_EVENTS: Mock(spec=CreateEvents), self.EXPORT_COUNTS: Mock(spec=ExportCounts), self.ADD_ALL_TRACKS: Mock(spec=AddAllTracks), + self.GET_ALL_TRACK_IDS: Mock(spec=GetAllTrackIds), self.CLEAR_ALL_TRACKS: Mock(spec=ClearAllTracks), self.PROGRESSBAR: Mock(spec=NoProgressbarBuilder), } @@ -168,6 +171,7 @@ def cli_dependencies(self) -> dict[str, Any]: add_events = AddEvents(event_repository) get_all_tracks = GetAllTracks(track_repository) + get_all_track_ids = GetAllTrackIds(track_repository) add_all_tracks = AddAllTracks(track_repository) clear_all_tracks = ClearAllTracks(track_repository) @@ -210,6 +214,7 @@ def cli_dependencies(self) -> dict[str, Any]: self.CREATE_EVENTS: create_events, self.EXPORT_COUNTS: export_counts, self.ADD_ALL_TRACKS: add_all_tracks, + self.GET_ALL_TRACK_IDS: get_all_track_ids, self.CLEAR_ALL_TRACKS: clear_all_tracks, self.PROGRESSBAR: NoProgressbarBuilder(), } From 4a2a2dc40d5028d4166ec6afd674d1d88ef754c9 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:06:29 +0200 Subject: [PATCH 5/5] Format to fix lint errors --- OTAnalytics/application/use_cases/track_repository.py | 7 ++++++- .../application/use_cases/test_track_repository.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/OTAnalytics/application/use_cases/track_repository.py b/OTAnalytics/application/use_cases/track_repository.py index 671a0b821..0e985dae8 100644 --- a/OTAnalytics/application/use_cases/track_repository.py +++ b/OTAnalytics/application/use_cases/track_repository.py @@ -1,7 +1,12 @@ from pathlib import Path from typing import Iterable -from OTAnalytics.domain.track import Track, TrackFileRepository, TrackId, TrackRepository +from OTAnalytics.domain.track import ( + Track, + TrackFileRepository, + TrackId, + TrackRepository, +) class GetAllTracks: diff --git a/tests/OTAnalytics/application/use_cases/test_track_repository.py b/tests/OTAnalytics/application/use_cases/test_track_repository.py index 2e8d12b43..60e64563b 100644 --- a/tests/OTAnalytics/application/use_cases/test_track_repository.py +++ b/tests/OTAnalytics/application/use_cases/test_track_repository.py @@ -10,7 +10,12 @@ GetAllTrackIds, GetAllTracks, ) -from OTAnalytics.domain.track import Detection, Track, TrackFileRepository, TrackId, TrackRepository +from OTAnalytics.domain.track import ( + Track, + TrackFileRepository, + TrackId, + TrackRepository, +) @pytest.fixture