diff --git a/hours/filters.py b/hours/filters.py index f8bc66b0..85f70f45 100644 --- a/hours/filters.py +++ b/hours/filters.py @@ -121,6 +121,7 @@ def filter(self, qs, value): class DatePeriodFilter(filters.FilterSet): + data_source = filters.CharFilter(field_name="origins__data_source") resource = filters.CharFilter(method="resource_filter") start_date = MaybeRelativeNullableDateFilter() end_date = MaybeRelativeNullableDateFilter() diff --git a/hours/tests/conftest.py b/hours/tests/conftest.py index b1844e3d..a8b6deeb 100644 --- a/hours/tests/conftest.py +++ b/hours/tests/conftest.py @@ -16,6 +16,7 @@ from hours.models import ( DataSource, DatePeriod, + PeriodOrigin, Resource, ResourceOrigin, Rule, @@ -74,11 +75,38 @@ class Meta: @register class DatePeriodFactory(factory.django.DjangoModelFactory): + class Meta: + model = DatePeriod + name = factory.LazyAttribute(lambda x: "DP-" + faker.pystr()) start_date = factory.LazyAttribute(lambda x: faker.date()) + @factory.post_generation + def origins(self, create, extracted, **__): + if not create or not extracted: + return + + for origin in extracted: + self.origins.add(origin) + + @factory.post_generation + def data_sources(self, create, extracted, **__): + if not create or not extracted: + return + + for data_source in extracted: + # Create a new origin for each data source, since data sources + # are accessed through origins. + self.origins.add(PeriodOriginFactory(data_source=data_source, period=self)) + + +@register +class PeriodOriginFactory(factory.django.DjangoModelFactory): class Meta: - model = DatePeriod + model = PeriodOrigin + + origin_id = factory.LazyAttribute(lambda x: "OID-" + faker.pystr()) + data_source = factory.SubFactory(DataSourceFactory) @register diff --git a/hours/tests/test_dateperiod_api.py b/hours/tests/test_dateperiod_api.py index f83d2ce7..ed5d4e2f 100644 --- a/hours/tests/test_dateperiod_api.py +++ b/hours/tests/test_dateperiod_api.py @@ -7,6 +7,7 @@ from hours.enums import RuleContext, RuleSubject, State, Weekday from hours.models import DatePeriod +from hours.tests.utils import assert_response_status_code @pytest.mark.django_db @@ -106,6 +107,31 @@ def test_list_date_periods_filter_by_resource( assert response.data[0]["id"] == date_period.id +@pytest.mark.django_db +def test_list_date_periods_filter_by_data_source( + admin_client, resource, data_source_factory, date_period_factory +): + expected_data_source = data_source_factory() + expected_date_period = date_period_factory( + resource=resource, + data_sources=[expected_data_source], + ) + date_period_factory( + resource=resource, + data_sources=[data_source_factory()], + ) + + url = reverse("date_period-list") + + response = admin_client.get( + url, data={"resource": resource.id, "data_source": [expected_data_source.id]} + ) + + assert_response_status_code(response, 200) + assert len(response.data) == 1 + assert response.data[0]["id"] == expected_date_period.id + + @pytest.mark.django_db def test_list_date_periods_filter_start_date_lte( admin_client, resource, date_period_factory diff --git a/hours/viewsets.py b/hours/viewsets.py index aa0422a6..2e9a680b 100644 --- a/hours/viewsets.py +++ b/hours/viewsets.py @@ -656,6 +656,12 @@ def copy_date_periods(self, request, pk=None): OpenApiParameter.QUERY, description="Filter by resource id or multiple resource ids (comma-separated)", # noqa ), + OpenApiParameter( + "data_source", + OpenApiTypes.STR, + OpenApiParameter.QUERY, + description="Filter by data source", + ), OpenApiParameter( "end_date", OpenApiTypes.DATE,