diff --git a/app/app/settings.py b/app/app/settings.py index 12cee6d..ba4f1a5 100644 --- a/app/app/settings.py +++ b/app/app/settings.py @@ -67,6 +67,7 @@ def environ_list(key: str, default=""): "core", "users", "users.authentication", + "querycsv", "analytics", "clubs", "clubs.polls", @@ -228,10 +229,22 @@ def environ_list(key: str, default=""): LOGIN_URL = "/auth/login/" AUTHENTICATION_BACKENDS = ["core.backend.CustomBackend"] + +######################## +# == AWS S3 Config == # +######################## +S3_STORAGE_BACKEND = bool(int(os.environ.get("S3_STORAGE_BACKEND", 1))) +if S3_STORAGE_BACKEND is True: + DEFAULT_FILE_STORAGE = "storages.backends.s3boto3.S3Boto3Storage" + +AWS_DEFAULT_ACL = "public-read" +AWS_STORAGE_BUCKET_NAME = os.environ.get("S3_STORAGE_BUCKET_NAME", "") +AWS_S3_REGION_NAME = os.environ.get("S3_STORAGE_BUCKET_REGION", "us-east-1") +AWS_QUERYSTRING_AUTH = False + ###################### # == Email Config == # ###################### - CONSOLE_EMAIL_BACKEND = environ_bool("CONSOLE_EMAIL_BACKEND", 0) if CONSOLE_EMAIL_BACKEND: @@ -264,6 +277,14 @@ def environ_list(key: str, default=""): # Custom schedules CELERY_BEAT_SCHEDULE = {} +CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.redis.RedisCache", + "LOCATION": os.environ.get("DJANGO_REDIS_URL"), + }, +} + + ############################### # == Environment Overrides == # ############################### @@ -291,7 +312,12 @@ def environ_list(key: str, default=""): if TESTING: - INSTALLED_APPS.append("core.mock") + # Ensure tasks execute immediately + CELERY_TASK_ALWAYS_EAGER = True # Force disable notifications EMAIL_HOST_PASSWORD = None + +if DEV or TESTING: + # Allow for migrations during dev mode + INSTALLED_APPS.append("core.mock") diff --git a/app/clubs/admin.py b/app/clubs/admin.py index 9887dba..e3a0442 100644 --- a/app/clubs/admin.py +++ b/app/clubs/admin.py @@ -12,7 +12,9 @@ Team, TeamMembership, ) +from clubs.serializers import ClubCsvSerializer, ClubMembershipCsvSerializer from clubs.services import ClubService +from core.abstracts.admin import ModelAdminBase class ClubMembershipInlineAdmin(admin.StackedInline): @@ -29,9 +31,11 @@ class ClubRoleInlineAdmin(admin.StackedInline): extra = 0 -class ClubAdmin(admin.ModelAdmin): +class ClubAdmin(ModelAdminBase): """Admin config for Clubs.""" + csv_serializer_class = ClubCsvSerializer + inlines = ( ClubRoleInlineAdmin, ClubMembershipInlineAdmin, @@ -46,16 +50,6 @@ class ClubAdmin(admin.ModelAdmin): def members_count(self, obj): return obj.memberships.count() - def get_queryset(self, request): - user_club_ids = request.user.club_memberships.all().values_list("club__id") - - queryset = super().get_queryset(request) - - if request.user.is_superuser: - return queryset - - return queryset.filter(id__in=user_club_ids) - class RecurringEventAdmin(admin.ModelAdmin): @@ -139,7 +133,24 @@ class TeamAdmin(admin.ModelAdmin): inlines = (TeamMembershipInlineAdmin,) +class ClubMembershipAdmin(ModelAdminBase): + """Manage club memberships in admin.""" + + csv_serializer_class = ClubMembershipCsvSerializer + + list_display = ( + "__str__", + "club", + "club_roles", + "created_at", + ) + + def club_roles(self, obj): + return ", ".join(str(role) for role in list(obj.roles.all())) + + admin.site.register(Club, ClubAdmin) admin.site.register(Event, EventAdmin) admin.site.register(RecurringEvent, RecurringEventAdmin) admin.site.register(Team, TeamAdmin) +admin.site.register(ClubMembership, ClubMembershipAdmin) diff --git 
a/app/clubs/polls/migrations/0003_alter_choiceinput_options_and_more.py b/app/clubs/polls/migrations/0003_alter_choiceinput_options_and_more.py new file mode 100644 index 0000000..fbee622 --- /dev/null +++ b/app/clubs/polls/migrations/0003_alter_choiceinput_options_and_more.py @@ -0,0 +1,58 @@ +# Generated by Django 4.2.19 on 2025-02-24 03:18 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ( + "polls", + "0002_remove_choiceinputoption_unique_choiceoption_order_per_input_and_more", + ), + ] + + operations = [ + migrations.AlterModelOptions( + name="choiceinput", + options={"ordering": ["question__field", "-id"]}, + ), + migrations.AlterModelOptions( + name="choiceinputoption", + options={"ordering": ["order", "-id"]}, + ), + migrations.AlterModelOptions( + name="pollfield", + options={"ordering": ["order", "-id"]}, + ), + migrations.AlterModelOptions( + name="pollquestion", + options={"ordering": ["field", "-id"]}, + ), + migrations.AlterField( + model_name="choiceinput", + name="multiple_choice_type", + field=models.CharField( + blank=True, + choices=[ + ("select", "Multi Select Box"), + ("checkbox", "Multi Checkbox Select"), + ], + default="checkbox", + null=True, + ), + ), + migrations.AlterField( + model_name="choiceinput", + name="single_choice_type", + field=models.CharField( + blank=True, + choices=[ + ("select", "Single Dropdown Select"), + ("radio", "Single Radio Select"), + ], + default="radio", + null=True, + ), + ), + ] diff --git a/app/clubs/serializers.py b/app/clubs/serializers.py index db945f8..62cd24d 100644 --- a/app/clubs/serializers.py +++ b/app/clubs/serializers.py @@ -2,6 +2,7 @@ from clubs.models import Club, ClubMembership from core.abstracts.serializers import ModelSerializerBase +from querycsv.serializers import CsvModelSerializer from users.models import User @@ -39,6 +40,14 @@ class Meta: ] +class ClubCsvSerializer(CsvModelSerializer): + """Represents clubs in csvs.""" + + class Meta: + model = Club + fields = "__all__" + + class ClubMembershipSerializer(ModelSerializerBase): """Represents a club membership to use for CRUD operations.""" @@ -59,3 +68,17 @@ class Meta: "owner", "points", ] + + +class ClubMembershipCsvSerializer(CsvModelSerializer, ClubMembershipSerializer): + """Serialize club memberships for a csv.""" + + class Meta: + model = ClubMembership + fields = [ + *ModelSerializerBase.default_fields, + "user_id", + "club_id", + "owner", + "points", + ] diff --git a/app/core/abstracts/admin.py b/app/core/abstracts/admin.py new file mode 100644 index 0000000..20c7404 --- /dev/null +++ b/app/core/abstracts/admin.py @@ -0,0 +1,217 @@ +import logging +from functools import update_wrapper +from typing import Literal, Optional + +from django.contrib import admin +from django.db import models +from django.http import FileResponse, HttpRequest +from django.shortcuts import redirect +from django.template.response import TemplateResponse +from django.urls import reverse +from django.urls.resolvers import URLPattern +from django.utils.safestring import mark_safe + +from querycsv.serializers import CsvModelSerializer +from querycsv.services import QueryCsvService +from querycsv.views import QueryCsvViewSet +from utils.admin import get_admin_context, get_model_admin_reverse + + +class AdminBase: + """Common fields, utilities for model admin, inline admin, etc.""" + + admin_name = "admin" + + def get_admin_url( + self, + model: models.Model, + url_context: Literal[ + "changelist", "add", "history", 
"delete", "change" + ] = "changelist", + admin_name=None, + as_link=False, + link_text=None, + ): + """Given a model, return a link to the appropriate admin page.""" + admin_name = admin_name or self.admin_name + url = get_model_admin_reverse(admin_name, model, url_context) + + if url_context in ["change", "history", "delete"]: + url = reverse(url, args=[model.id]) + else: + url = reverse(url) + + if as_link: + return self.as_link(url, link_text or url) + + return url + + def as_link(self, url, text): + """Create anchor tag for a url.""" + + return mark_safe(f'{text}') + + +class ModelAdminBase(AdminBase, admin.ModelAdmin): + """Base class for all model admins.""" + + prefetch_related_fields = () + select_related_fields = () + readonly_fields = ( + "id", + "created_at", + "updated_at", + ) + object_tools = () + + formfield_overrides = {} + + change_list_template = "admin/core/change_list.html" + csv_serializer_class: Optional[CsvModelSerializer] = None + """Serializer to use for csv uploads""" + + ################################# + # == Django Method Overrides == # + ################################# + + def __init__(self, model: type, admin_site: admin.AdminSite | None) -> None: + super().__init__(model, admin_site) + + # If serializer is set, enable certain features + if self.csv_serializer_class is not None: + self.csv_svc = QueryCsvService(self.csv_serializer_class) + + def get_reverse(name: str = "changelist"): + return f"{self.admin_name}:{self._url_name(name)}" + + self.csv_views = QueryCsvViewSet( + serializer_class=self.csv_serializer_class, get_reverse=get_reverse + ) + + self.actions += ("download_csv",) + self.object_tools += ( + { + "url": "%s:%s_%s_upload" + % (self.admin_name, self.opts.app_label, self.opts.model_name), + "label": "Upload CSV", + }, + ) + + def get_queryset(self, request): + qs = super().get_queryset(request) + + if len(self.prefetch_related_fields) > 0: + return qs.prefetch_related(*self.prefetch_related_fields).select_related( + *self.select_related_fields + ) + else: + return qs + + def changelist_view( + self, request: HttpRequest, extra_context: dict[str, str] | None = None + ) -> TemplateResponse: + context = { + **(extra_context or {}), + "object_tools": self.object_tools, + } + + return super().changelist_view(request, extra_context=context) + + def get_urls(self) -> list[URLPattern]: + """Extends django's default functionality to add custom urls.""" + from django.urls import path + + # Start duplicated django code ############### + def wrap(view): + def wrapper(*args, **kwargs): + return self.admin_site.admin_view(view)(*args, **kwargs) + + wrapper.model_admin = self + return update_wrapper(wrapper, view) + + # End duplicated django code ################# + + # Custom Url Definitions + ######################## + urls = [ + path("upload/", wrap(self.upload_csv), name=self._url_name("upload")), + path( + "upload//headermapping", + wrap(self.map_upload_csv_headers), + name=self._url_name("upload_headermapping"), + ), + path( + "csv-template/", + wrap(self.download_csv_template), + name=self._url_name("csv_template"), + ), + path( + "csv-template/", + wrap(self.download_csv_template), + name=self._url_name("csv_template"), + ), + ] + super(ModelAdminBase, self).get_urls() + + return urls + + ############################## + # == Custom Admin Methods == # + ############################## + + def _url_name(self, url_context="changelist"): + """Get url name to reverse for this admin class.""" + info = self.opts.app_label, self.opts.model_name, url_context 
+
+    ##############################
+    # == Custom Admin Methods == #
+    ##############################
+
+    def _url_name(self, url_context="changelist"):
+        """Get url name to reverse for this admin class."""
+        info = self.opts.app_label, self.opts.model_name, url_context
+
+        return "%s_%s_%s" % info
+
+    def upload_csv(self, request: HttpRequest, extra_context=None):
+        """Custom action for uploading csvs through admin."""
+
+        context = {**get_admin_context(request, extra_context)}
+        return self.csv_views.upload_csv(request, extra_context=context)
+
+    def map_upload_csv_headers(self, request: HttpRequest, id: int, extra_context=None):
+        """Given a csv upload job, allow admin to define custom field mappings."""
+
+        context = {**get_admin_context(request, extra_context)}
+        return self.csv_views.map_upload_csv_headers(request, id, extra_context=context)
+
+    def download_csv_template(self, request: HttpRequest):
+        """Get template for csv uploads."""
+
+        include_fields = request.GET.get("fields", None)
+
+        filepath = self.csv_svc.get_csv_template(field_types=include_fields)
+        return FileResponse(open(filepath, "rb"))
+
+    ##############################
+    # == Custom Admin Actions == #
+    ##############################
+
+    @admin.action(description="Download selection as CSV")
+    def download_csv(self, request, queryset):
+        """Download queryset of objects."""
+
+        if self.csv_serializer_class is None:
+            self.message_user(
+                request,
+                "Unable to download objects without a serializer.",
+                logging.WARNING,
+            )
+            return redirect(f"{self.admin_name}:{self._url_name()}")
+
+        filepath = self.csv_svc.download_csv(queryset)
+        return FileResponse(open(filepath, "rb"))
+
+
+class InlineBase(AdminBase):
+    extra = 0
+    formfield_overrides = {}
+
+
+class StackedInlineBase(InlineBase, admin.StackedInline):
+    """Display fk related objects as cards, form flowing down."""
+
+
+class TabularInlineBase(InlineBase, admin.TabularInline):
+    """Display fk related objects in a table, fields flowing horizontally in rows."""
diff --git a/app/core/abstracts/models.py b/app/core/abstracts/models.py
index 1a042f8..7afe783 100644
--- a/app/core/abstracts/models.py
+++ b/app/core/abstracts/models.py
@@ -155,6 +155,20 @@ def get_content_type(cls):
         """
         return ContentType.objects.get_for_model(cls)

+    @classmethod
+    def get_fields_list(
+        cls, include_parents=True, exclude_read_only=False
+    ) -> list[str]:
+        """Return a list of field names, optionally limited to editable fields."""
+
+        fields = [
+            str(field.name)
+            for field in cls._meta.get_fields(include_parents=include_parents)
+            if not exclude_read_only or field.editable
+        ]
+
+        return fields
+
     class Meta:
         abstract = True
diff --git a/app/core/abstracts/serializers.py b/app/core/abstracts/serializers.py
index 4d241dd..cbf244b 100644
--- a/app/core/abstracts/serializers.py
+++ b/app/core/abstracts/serializers.py
@@ -1,10 +1,85 @@
+from enum import Enum
+from typing import Type
+
+from django.db import models
 from rest_framework import serializers

+
+class FieldType(Enum):
+    READONLY = "readonly"
+    WRITABLE = "writable"
+    REQUIRED = "required"
+    UNIQUE = "unique"
+
+
+class SerializerBase(serializers.Serializer):
+    """Wrapper around the base drf serializer."""
+
+    datetime_format = "%Y-%m-%d %H:%M:%S"
+
+    @property
+    def all_fields(self) -> list[str]:
+        """Get list of all fields in serializer."""
+
+        return [key for key in self.get_fields().keys()]
+
+    @property
+    def readable_fields(self) -> list[str]:
+        """Get list of all fields in serializer that can be read."""
+
+        return self.all_fields
+
+    @property
+    def writable_fields(self) -> list[str]:
+        """Get list of all fields that can be written to."""
+
+        return [
+            key for key, value in self.get_fields().items() if value.read_only is False
+        ]
+
+    @property
+    def readonly_fields(self) -> list[str]:
+
"""Get list of all fields that can only be read, not written.""" + + return [ + key for key, value in self.get_fields().items() if value.read_only is True + ] + + @property + def required_fields(self) -> list[str]: + """Get list of all fields that must be written to on object creation.""" + + return [ + key + for key, value in self.fields.items() + if value.required is True and value.read_only is False + ] + + def get_field_types(self, field_name: str, serializer=None) -> list[FieldType]: + """Get ``FieldType`` for a given field.""" + serializer = serializer if serializer is not None else self + + field_types = [] + + if field_name in serializer.writable_fields: + field_types.append(FieldType.WRITABLE) + + if field_name in serializer.readonly_fields: + field_types.append(FieldType.READONLY) + + if field_name in serializer.required_fields: + field_types.append(FieldType.REQUIRED) + + if field_name in serializer.unique_fields: + field_types.append(FieldType.UNIQUE) + + return field_types + + class ModelSerializerBase(serializers.ModelSerializer): """Default functionality for model serializer.""" - datetime_format = "%Y-%m-%d %H:%M:%S" + datetime_format = SerializerBase.datetime_format id = serializers.IntegerField(label="ID", read_only=True) created_at = serializers.DateTimeField( @@ -16,6 +91,65 @@ class ModelSerializerBase(serializers.ModelSerializer): default_fields = ["id", "created_at", "updated_at"] + class Meta: + model = None + + @property + def model_class(self) -> Type[models.Model]: + return self.Meta.model + + @property + def unique_fields(self) -> list[str]: + """Get list of all fields that can be used to unique identify models.""" + + model_fields = self.model_class._meta.get_fields() + unique_fields = [ + field + for field in model_fields + if getattr(field, "primary_key", False) or getattr(field, "_unique", False) + ] + unique_fields = [field.name for field in unique_fields] + + return [field for field in self.readable_fields if field in unique_fields] + + @property + def related_fields(self) -> list[str]: + """List of fields that inherit RelatedField, representing foreign key relations.""" + + return [ + key + for key, value in self.get_fields().items() + if isinstance(value, serializers.RelatedField) + ] + + @property + def many_related_fields(self) -> list[str]: + """List of fields that inherit ManyRelatedField, representing M2M relations.""" + + return [ + key + for key, value in self.get_fields().items() + if isinstance(value, serializers.ManyRelatedField) + ] + + @property + def any_related_fields(self) -> list[str]: + """List of fields that are single or many related.""" + + return self.related_fields + self.many_related_fields + + @property + def unique_together_fields(self): + """List of tuples of fields that must be unique together.""" + + constraints = self.model_class._meta.constraints + + return [ + constraint.fields + for constraint in constraints + if isinstance(constraint, models.UniqueConstraint) + ] + class ModelSerializer(ModelSerializerBase): """Base fields for model serializer.""" diff --git a/app/core/abstracts/tests.py b/app/core/abstracts/tests.py index 48c5bf4..601b51c 100644 --- a/app/core/abstracts/tests.py +++ b/app/core/abstracts/tests.py @@ -1,10 +1,11 @@ +import os from typing import Optional, Type from django import forms from django.http import HttpResponse from django.test import TestCase from django.urls import reverse -from rest_framework import status +from rest_framework import serializers, status from rest_framework.status import 
HTTP_200_OK from rest_framework.test import APIClient @@ -31,6 +32,34 @@ def assertLength(self, target: list, length=1, msg=None): self.assertEqual(len(target), length, msg) + def assertStartsWith(self, string: str, substring: str): + """Target string should start with substring.""" + + self.assertIsInstance(string, str) + self.assertTrue( + string.startswith(substring), + f"String {string or 'EMPTY'} does not start with {substring}.", + ) + + def assertEndsWith(self, string: str, substring: str): + """Target string should end with substring.""" + + self.assertIsInstance(string, str) + self.assertTrue( + string.endswith(substring), + f"String {string} does not end with {substring}.", + ) + + def assertFileExists(self, path): + """File with path should exist.""" + + self.assertTrue(os.path.exists(path), f"File does not exist at {path}.") + + def assertValidSerializer(self, serializer: serializers.Serializer): + """Check `.is_valid()` function on serializer, prints errors if invalid.""" + + self.assertTrue(serializer.is_valid(), serializer.errors) + class ApiTestsBase(TestsBase): """Abstract testing utilities for api testing.""" diff --git a/app/core/mock/migrations/0001_initial.py b/app/core/mock/migrations/0001_initial.py index c55f882..324249e 100644 --- a/app/core/mock/migrations/0001_initial.py +++ b/app/core/mock/migrations/0001_initial.py @@ -1,7 +1,8 @@ -# Generated by Django 4.2.17 on 2024-12-24 06:47 +# Generated by Django 4.2.19 on 2025-02-24 03:55 from django.db import migrations, models import django.db.models.deletion +import uuid class Migration(migrations.Migration): @@ -46,6 +47,7 @@ class Migration(migrations.Migration): ("created_at", models.DateTimeField(auto_now_add=True)), ("updated_at", models.DateTimeField(auto_now=True)), ("name", models.CharField()), + ("unique_name", models.CharField(default=uuid.uuid4, unique=True)), ( "many_tags", models.ManyToManyField( diff --git a/app/core/mock/models.py b/app/core/mock/models.py index 1f40f57..f98d376 100644 --- a/app/core/mock/models.py +++ b/app/core/mock/models.py @@ -1,3 +1,5 @@ +import uuid + from django.db import models from core.abstracts.models import ModelBase @@ -16,6 +18,7 @@ class Buster(ModelBase): """ name = models.CharField() + unique_name = models.CharField(unique=True, default=uuid.uuid4) one_tag = models.ForeignKey( BusterTag, on_delete=models.SET_NULL, null=True, blank=True ) diff --git a/app/core/mock/serializers.py b/app/core/mock/serializers.py new file mode 100644 index 0000000..cc7f31f --- /dev/null +++ b/app/core/mock/serializers.py @@ -0,0 +1,42 @@ +from core.mock.models import Buster, BusterTag +from querycsv.serializers import CsvModelSerializer, WritableSlugRelatedField + + +class BusterCsvSerializer(CsvModelSerializer): + """Serialize dummy model for testing.""" + + many_tags_int = WritableSlugRelatedField( + source="many_tags", + slug_field="id", + queryset=BusterTag.objects.all(), + many=True, + required=False, + allow_null=True, + ) + many_tags = WritableSlugRelatedField( + slug_field="name", + queryset=BusterTag.objects.all(), + many=True, + required=False, + allow_null=True, + ) + one_tag = WritableSlugRelatedField( + slug_field="name", + required=False, + queryset=BusterTag.objects.all(), + allow_null=True, + ) + + class Meta(CsvModelSerializer.Meta): + model = Buster + fields = [ + "id", + "created_at", + "updated_at", + "name", + "unique_name", + "one_tag", + "many_tags", + "many_tags_int", + ] + exclude = None diff --git a/app/core/templates/admin/core/change_list.html 
b/app/core/templates/admin/core/change_list.html new file mode 100644 index 0000000..9818e76
--- /dev/null
+++ b/app/core/templates/admin/core/change_list.html
@@ -0,0 +1,15 @@
+{% extends "admin/change_list.html" %}
+{% block object-tools-items %}
+
+{% for tool in object_tools %}
+<li>
+  <a href="{% url tool.url %}">{{ tool.label }}</a>
+</li>
+{% endfor %}
+
+{{ block.super }}
+{% endblock %}
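The `object_tools` this template renders are assembled in `ModelAdminBase.__init__` above. As a minimal sketch (mirroring what ClubAdmin and ClubMembershipAdmin do elsewhere in this diff), an admin opts in by setting a single attribute:

```python
# Sketch: any admin gains the csv features by subclassing ModelAdminBase and
# pointing it at a CsvModelSerializer; the "Upload CSV" object tool and the
# "Download selection as CSV" action are wired up in ModelAdminBase.__init__.
from clubs.serializers import ClubMembershipCsvSerializer
from core.abstracts.admin import ModelAdminBase


class ExampleAdmin(ModelAdminBase):
    csv_serializer_class = ClubMembershipCsvSerializer
```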
diff --git a/app/lib/spreadsheets.py b/app/lib/spreadsheets.py
new file mode 100644
index 0000000..d5a100b
--- /dev/null
+++ b/app/lib/spreadsheets.py
@@ -0,0 +1,23 @@
+import numpy as np
+import pandas as pd
+
+SPREADSHEET_EXTS = ("csv", "xls", "xlsx")
+"""Tuple of supported spreadsheet extensions."""
+
+
+def read_spreadsheet(path: str):
+    """Import spreadsheet from filepath."""
+
+    if isinstance(path, str):
+        # assert os.path.exists(path), f"File doesn't exist at {path}."
+
+        if path.endswith(".xlsx") or path.endswith(".xls"):
+            df = pd.read_excel(path, dtype=str)
+        else:
+            df = pd.read_csv(path, dtype=str)
+    else:
+        df = pd.read_csv(path, dtype=str)
+
+    df.replace(np.nan, "", inplace=True)
+
+    return df
diff --git a/app/querycsv/__init__.py b/app/querycsv/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/querycsv/admin.py b/app/querycsv/admin.py
new file mode 100644
index 0000000..7388d0b
--- /dev/null
+++ b/app/querycsv/admin.py
@@ -0,0 +1,5 @@
+from django.contrib import admin
+
+from querycsv.models import QueryCsvUploadJob
+
+admin.site.register(QueryCsvUploadJob)
diff --git a/app/querycsv/apps.py b/app/querycsv/apps.py
new file mode 100644
index 0000000..54fc4c2
--- /dev/null
+++ b/app/querycsv/apps.py
@@ -0,0 +1,6 @@
+from django.apps import AppConfig
+
+
+class QuerycsvConfig(AppConfig):
+    default_auto_field = "django.db.models.BigAutoField"
+    name = "querycsv"
diff --git a/app/querycsv/consts.py b/app/querycsv/consts.py
new file mode 100644
index 0000000..ec207f3
--- /dev/null
+++ b/app/querycsv/consts.py
@@ -0,0 +1,6 @@
+QUERYCSV_MEDIA_SUBDIR = "core/querycsv/"
+"""Nested directory to use in media storage. Excludes media path."""
+
+EXTRA_QUERYCSV_FIELDS = ("SKIP",)
+
+__all__ = ["QUERYCSV_MEDIA_SUBDIR", "EXTRA_QUERYCSV_FIELDS"]
diff --git a/app/querycsv/forms.py b/app/querycsv/forms.py
new file mode 100644
index 0000000..6b8e9b1
--- /dev/null
+++ b/app/querycsv/forms.py
@@ -0,0 +1,46 @@
+from django import forms
+
+from querycsv.consts import EXTRA_QUERYCSV_FIELDS
+from querycsv.models import QueryCsvUploadJob
+
+
+class CsvUploadForm(forms.Form):
+    """Form used to upload csv and create/update objects."""
+
+    file = forms.FileField(
+        label="Select CSV or Excel Spreadsheet to upload.",
+        widget=forms.FileInput(attrs={"class": "form-control"}),
+    )
+
+
+class CsvHeaderMappingForm(forms.Form):
+    """Map csv headers to object fields."""
+
+    csv_header = forms.CharField(
+        # disabled=True,
+        required=True,
+        widget=forms.TextInput(attrs={"readonly": True, "class": "form-control"}),
+    )
+    object_field = forms.ChoiceField(
+        choices=[], required=True, widget=forms.Select(attrs={"class": "form-control"})
+    )
+
+    def __init__(self, *args, available_fields: list[str], **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.fields["object_field"].choices = [
+            (field, field.upper()) for field in EXTRA_QUERYCSV_FIELDS
+        ] + [(field, field) for field in available_fields]
+
+
+class CsvHeaderMappingFormSet(forms.formset_factory(CsvHeaderMappingForm, extra=0)):
+    """Custom FormSet for defining csv header mappings."""
+
+    def __init__(self, *args, upload_job: QueryCsvUploadJob, **kwargs):
+        kwargs["form_kwargs"] = {
+            **kwargs.get("form_kwargs", {}),
+            "available_fields": upload_job.serializer_class().get_flat_fields().keys(),
+        }
+
+        super().__init__(*args, **kwargs)
diff --git a/app/querycsv/migrations/0001_initial.py b/app/querycsv/migrations/0001_initial.py
new file mode 100644
index 0000000..e12ba01
---
/dev/null +++ b/app/querycsv/migrations/0001_initial.py @@ -0,0 +1,89 @@ +# Generated by Django 4.2.19 on 2025-02-24 03:18 + +import django.core.validators +from django.db import migrations, models +import rest_framework.serializers +import utils.models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="QueryCsvUploadJob", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "file", + models.FileField( + upload_to=utils.models.UploadFilepathFactory( + default_extension="csv", path="core/querycsv/uploads/" + ), + validators=[ + django.core.validators.FileExtensionValidator( + allowed_extensions=("csv", "xls", "xlsx") + ) + ], + ), + ), + ( + "serializer", + models.CharField( + max_length=64, + null=True, + validators=[ + utils.models.ValidateImportString( + target_type=rest_framework.serializers.Serializer + ) + ], + ), + ), + ( + "status", + models.CharField( + choices=[ + ("pending", "Pending"), + ("processing", "Processing"), + ("failed", "Failed"), + ("success", "Success"), + ], + default="pending", + ), + ), + ( + "notify_email", + models.EmailField(blank=True, max_length=254, null=True), + ), + ( + "report", + models.FileField( + blank=True, null=True, upload_to="core/querycsv/reports/" + ), + ), + ( + "custom_field_mappings", + models.JSONField( + blank=True, + help_text="Key value pairs, column name => model field", + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/app/querycsv/migrations/0002_alter_querycsvuploadjob_serializer.py b/app/querycsv/migrations/0002_alter_querycsvuploadjob_serializer.py new file mode 100644 index 0000000..278caf1 --- /dev/null +++ b/app/querycsv/migrations/0002_alter_querycsvuploadjob_serializer.py @@ -0,0 +1,28 @@ +# Generated by Django 4.2.19 on 2025-02-24 23:43 + +from django.db import migrations, models +import querycsv.serializers +import utils.models + + +class Migration(migrations.Migration): + + dependencies = [ + ("querycsv", "0001_initial"), + ] + + operations = [ + migrations.AlterField( + model_name="querycsvuploadjob", + name="serializer", + field=models.CharField( + max_length=64, + null=True, + validators=[ + utils.models.ValidateImportString( + target_type=querycsv.serializers.CsvModelSerializer + ) + ], + ), + ), + ] diff --git a/app/querycsv/migrations/__init__.py b/app/querycsv/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/querycsv/models.py b/app/querycsv/models.py new file mode 100644 index 0000000..da2a3ee --- /dev/null +++ b/app/querycsv/models.py @@ -0,0 +1,145 @@ +""" +CSV data logging models. 
+""" + +from pathlib import Path +from typing import ClassVar, Optional, Type, TypedDict + +from django.core.files import File +from django.core.validators import FileExtensionValidator +from django.db import models +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from core.abstracts.models import ManagerBase, ModelBase +from lib.spreadsheets import SPREADSHEET_EXTS, read_spreadsheet +from querycsv.consts import QUERYCSV_MEDIA_SUBDIR +from querycsv.serializers import CsvModelSerializer +from utils.files import get_file_path +from utils.helpers import get_import_path, import_from_path +from utils.models import UploadFilepathFactory, ValidateImportString + + +class CsvUploadStatus(models.TextChoices): + """When a csv is uploaded, will have one of these statuses""" + + PENDING = "pending", _("Pending") + PROCESSING = "processing", _("Processing") + FAILED = "failed", _("Failed") + SUCCESS = "success", _("Success") + + +class FieldMappingType(TypedDict): + column_name: str + field_name: str + + +class QueryCsvUploadJobManager(ManagerBase["QueryCsvUploadJob"]): + """Model manager for queryset csvs.""" + + def create( + self, + serializer_class: Type[serializers.Serializer], + filepath: Optional[str] = None, + notify_email: Optional[str] = None, + **kwargs, + ) -> "QueryCsvUploadJob": + """ + Create new QuerySet Csv Upload Job. + """ + + kwargs["serializer"] = get_import_path(serializer_class) + + if filepath: + path = Path(filepath) + with path.open(mode="rb") as f: + kwargs["file"] = File(f, name=path.name) + job = super().create(notify_email=notify_email, **kwargs) + else: + job = super().create(notify_email=notify_email, **kwargs) + + return job + + +class QueryCsvUploadJob(ModelBase): + """Used to store meta info about csvs from querysets.""" + + validate_import_string = ValidateImportString(target_type=CsvModelSerializer) + csv_upload_path = UploadFilepathFactory( + path=QUERYCSV_MEDIA_SUBDIR + "uploads/", default_extension="csv" + ) + + # Primary fields + file = models.FileField( + upload_to=csv_upload_path, + validators=[FileExtensionValidator(allowed_extensions=SPREADSHEET_EXTS)], + ) + serializer = models.CharField( + max_length=64, validators=[validate_import_string], null=True + ) + + # Meta fields + status = models.CharField( + choices=CsvUploadStatus.choices, default=CsvUploadStatus.PENDING + ) + notify_email = models.EmailField(null=True, blank=True) + report = models.FileField( + upload_to=QUERYCSV_MEDIA_SUBDIR + "reports/", null=True, blank=True + ) + custom_field_mappings = models.JSONField( + blank=True, help_text="Key value pairs, column name => model field" + ) + + # Overrides + objects: ClassVar[QueryCsvUploadJobManager] = QueryCsvUploadJobManager() + + def save(self, *args, **kwargs) -> None: + if self.custom_field_mappings is None: + self.custom_field_mappings = {"fields": []} + + return super().save(*args, **kwargs) + + # Dynamic properties + @property + def filepath(self): + return get_file_path(self.file) + + @property + def spreadsheet(self): + return read_spreadsheet(self.filepath) + + @property + def serializer_class(self) -> Type[CsvModelSerializer]: + return import_from_path(self.serializer) + + @serializer_class.setter + def serializer_class(self, value: Type[CsvModelSerializer]): + self.serializer = get_import_path(value) + + @property + def model_class(self) -> Type[ModelBase]: + return self.serializer_class.Meta.model + + @property + def custom_fields(self) -> list[FieldMappingType]: + return 
self.custom_field_mappings["fields"]
+
+    @property
+    def csv_headers(self):
+        return list(self.spreadsheet.columns)
+
+    # Methods
+    def add_field_mapping(self, column_name: str, field_name: str, commit=True):
+        """Add custom field mapping."""
+        column_options = list(self.spreadsheet.columns)
+
+        assert (
+            column_name in column_options
+        ), f"The name {column_name} is not in available columns: {', '.join(column_options)}"
+
+        self.custom_field_mappings["fields"].append(
+            {"column_name": column_name, "field_name": field_name}
+        )
+
+        if commit:
+            self.save()
diff --git a/app/querycsv/serializers.py b/app/querycsv/serializers.py
new file mode 100644
index 0000000..f90c8cc
--- /dev/null
+++ b/app/querycsv/serializers.py
@@ -0,0 +1,315 @@
+import re
+from typing import Optional
+
+from django.db import models
+from rest_framework import serializers
+from rest_framework.fields import empty
+from rest_framework.relations import SlugRelatedField
+
+from core.abstracts.serializers import FieldType, ModelSerializerBase, SerializerBase
+from utils.helpers import str_to_list
+
+
+class FlatField:
+    key: str
+    help_text: str
+    field_types: list[FieldType]
+    is_list_item = False
+
+    def __init__(
+        self, key: str, value: serializers.Field, field_types: list[FieldType]
+    ):
+        """Initialize field object, value is from serializer.get_fields()[key]."""
+
+        self.key = key
+        self.help_text = value.help_text
+        self.field_types = field_types
+
+    def __str__(self):
+        return self.key
+
+    def __eq__(self, value):
+        return self.key == value
+
+    @property
+    def is_readonly(self):
+        return FieldType.READONLY in self.field_types
+
+    @property
+    def is_writable(self):
+        return FieldType.WRITABLE in self.field_types
+
+    @property
+    def is_required(self):
+        return FieldType.REQUIRED in self.field_types
+
+    @property
+    def is_unique(self):
+        return FieldType.UNIQUE in self.field_types
+
+
+class FlatListField(FlatField):
+    """Represents a flat field that's part of a list."""
+
+    is_list_item = True
+
+    # Additional fields for list items
+    index: Optional[int]
+    parent_key: str
+    sub_key: Optional[str]
+    generic_key: str
+
+    def __init__(self, key, value, field_types):
+        super().__init__(key, value, field_types)
+
+        self._set_list_values()
+
+    def __eq__(self, value):
+        return value == self.key or re.sub(r"\[(\d+|n)\]", "[n]", value) == self.key
+
+    def _set_list_values(self):
+        matches = re.match(r"([a-z0-9_-]+)\[(\d+|n)\]\.?(.*)?", self.key)
+        assert bool(matches), f"Invalid list item field: {self.key}"
+
+        parent_field, index, sub_field = list(matches.groups())
+
+        self.parent_key = parent_field
+        self.index = index if index != "n" else None
+        self.sub_key = sub_field if sub_field != "" else None
+        self.generic_key = re.sub(r"\[(\d+|n)\]", "[n]", self.key)
+
+    def set_index(self, index: int):
+        """Used when index is found later."""
+
+        self.index = index
+        self.key = f"{self.parent_key}[{index}]"
+
+        if self.sub_key is not None:
+            self.key += f".{self.sub_key}"
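To make the key grammar concrete, a small illustration of how a flattened column key decomposes; the `links` field name is hypothetical:

```python
# Illustration (hypothetical "links" field): how FlatListField splits a key.
from rest_framework import serializers

field = FlatListField("links[0].url", serializers.CharField(), field_types=[])
field.parent_key   # "links"
field.index        # "0" (captured as a string; None when the key uses "[n]")
field.sub_key      # "url"
field.generic_key  # "links[n].url"
```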
many=True""" + + return [ + key + for key, value in self.get_fields().items() + if isinstance(value, serializers.ManyRelatedField) + and isinstance(value.child_relation, WritableSlugRelatedField) + ] + + @property + def flat_data(self): + """Like ``serializer.data``, but returns flattened data.""" + + data = self.data + return self.json_to_flat(data) + + @classmethod + def json_to_flat(cls, data: dict): + """Convert representation to flattened struction for CSV.""" + + # TODO: Handle nested json to flat + for key, value in data.items(): + # Convert lists to string + if isinstance(value, list): + data[key] = ", ".join([str(v) for v in value]) + + return data + + @classmethod + def flat_to_json(cls, record: dict) -> dict: + """ + Convert data from csv to a nested json rep. + + Examples + -------- + IN : {"some_list[0]": "zero", "some_list[1]": "one"} + OUT: {"some_list": ["zero", "one"]} + -- + IN : {"another_list[0].first_name": "John", "another_list[0].last_name": "Doe"} + OUT: {"another_list": [{"first_name": "John", "last_name": "Doe"}]} + """ + + parsed = {} + + # Initial parsing + for key, value in record.items(): + list_objs_res = re.match(r"([a-z0-9_-]+)\[([0-9]+)\]\.?(.*)?", key) + + if bool(list_objs_res): + # Handle list of objects + field, index, nested_field = list_objs_res.groups() + index = int(index) + + if field not in parsed.keys(): + parsed[field] = [] + + assert isinstance( + parsed[field], list + ), f"Inconsistent types for field {field}" + + # Need to ensure the object is put at that specific location, + # since the other fields will expect it there. + while len(parsed[field]) <= index: + parsed[field].append({}) + + # TODO: Recurse for deeply nested objects + parsed[field][index][nested_field] = value + elif key in cls().writable_many_related_fields and isinstance(value, str): + # Handle list of literals + parsed[key] = str_to_list(value) + elif key in cls().writable_many_related_fields and not isinstance( + value, list + ): + parsed[key] = [value] + else: + # Handle objects + # TODO: CSV Objects + + # Default + parsed[key] = value + + # Filtering + for key, value in parsed.items(): + # Skip if: 1) not a list 2) empty list 3) list contains non-dict values + if ( + not isinstance(value, list) + or len(value) == 0 + or not isinstance(value[0], dict) + ): + continue + + # Remove empty objects from nested lists + parsed[key] = [item for item in value if len(item.keys()) > 0] + + return parsed + + def get_flat_fields(self) -> dict[str, FlatField | FlatListField]: + """Like ``get_fields``, returns a dict of fields with their flat type.""" + + flat_fields = {} + + for key, value in self.get_fields().items(): + if not isinstance(value, serializers.BaseSerializer): + + field = FlatField(key, value, self.get_field_types(key)) + flat_fields[key] = field + continue + + field_name = key + + if value.many: + field_name += "[n]." + else: + field_name += "." 
+
+    def get_flat_fields(self) -> dict[str, FlatField | FlatListField]:
+        """Like ``get_fields``, returns a dict of fields with their flat type."""
+
+        flat_fields = {}
+
+        for key, value in self.get_fields().items():
+            if not isinstance(value, serializers.BaseSerializer):
+                field = FlatField(key, value, self.get_field_types(key))
+                flat_fields[key] = field
+                continue
+
+            field_name = key
+
+            if value.many:
+                field_name += "[n]."
+            else:
+                field_name += "."
+
+            sub_serializer = value.child
+
+            for sub_field in sub_serializer.get_fields():
+                nested_field_name = field_name + sub_field
+                field_cls = FlatField if not value.many else FlatListField
+
+                field = field_cls(
+                    nested_field_name,
+                    sub_serializer.get_fields()[sub_field],
+                    self.get_field_types(sub_field, serializer=sub_serializer),
+                )
+
+                flat_fields[nested_field_name] = field
+
+        return flat_fields
+
+
+class CsvModelSerializer(FlatSerializer, ModelSerializerBase):
+    """Convert fields to csv columns."""
+
+    def __init__(self, instance=None, data=empty, **kwargs):
+        """Override default functionality to implement update or create."""
+
+        # Skip if data is empty
+        if data is None:
+            return super().__init__(instance=instance, **kwargs)
+
+        # Coerce Slug Many Related Field to a list before processing
+        # TODO: This should be handled entirely by flat_to_json
+        try:
+            fields = [
+                field
+                for field in self.writable_many_related_fields
+                if field in data.keys()
+            ]
+
+            for field in fields:
+                if isinstance(data[field], str):
+                    data[field] = str_to_list(data[field])
+
+        except Exception:
+            pass
+
+        # Initialize rest of serializer first, needed if data is flat
+        super().__init__(data=data, **kwargs)
+
+        # Allow create_or_update functionality
+        try:
+            if instance is None and data is not None and data is not empty:
+                ModelClass = self.model_class
+                search_fields = {}
+                search_query = None
+
+                for field in self.unique_fields:
+                    value = data.get(field, None)
+
+                    # Remove leading/trailing spaces before processing
+                    if value is None or value == "":
+                        continue
+                    elif isinstance(value, str):
+                        value = value.strip()
+
+                    search_fields[field] = value
+
+                    if search_query is None:
+                        search_query = models.Q(**{field: value})
+                    else:
+                        search_query = search_query | models.Q(**{field: value})
+
+                query = ModelClass.objects.filter(search_query)
+                if query.exists():
+                    instance = query.first()
+                else:
+                    self.instance = instance
+
+        except Exception:
+            pass
+
+        self.instance = instance
+
+
+class WritableSlugRelatedField(SlugRelatedField):
+    """Wraps slug related field and creates object if not found."""
+
+    def to_internal_value(self, data):
+        """Overrides default behavior to create if not found."""
+        queryset = self.get_queryset()
+
+        try:
+            obj, _ = queryset.get_or_create(**{self.slug_field: data})
+            return obj
+        except (TypeError, ValueError):
+            self.fail("invalid")
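The constructor logic above is what makes repeated uploads idempotent. A sketch of the behavior using the mock Buster model from this diff (values assumed):

```python
# Sketch (assumed values): matching on the unique "unique_name" column makes a
# second upload update the existing row instead of creating a duplicate.
from core.mock.models import Buster
from core.mock.serializers import BusterCsvSerializer

first = BusterCsvSerializer(data={"name": "First", "unique_name": "buster-1"})
first.is_valid(raise_exception=True)
first.save()

second = BusterCsvSerializer(data={"name": "Renamed", "unique_name": "buster-1"})
second.is_valid(raise_exception=True)
second.save()  # resolves to the same instance via unique_fields

assert Buster.objects.filter(unique_name="buster-1").count() == 1
```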
diff --git a/app/querycsv/services.py b/app/querycsv/services.py
new file mode 100644
index 0000000..96339cb
--- /dev/null
+++ b/app/querycsv/services.py
@@ -0,0 +1,209 @@
+import re
+from enum import Enum
+from typing import Literal, Optional, OrderedDict, Type, TypedDict
+
+import pandas as pd
+from django.db import models
+
+from core.abstracts.serializers import ModelSerializerBase
+from lib.spreadsheets import read_spreadsheet
+from querycsv.consts import QUERYCSV_MEDIA_SUBDIR
+from querycsv.models import QueryCsvUploadJob
+from querycsv.serializers import CsvModelSerializer
+from utils.files import get_media_path
+
+
+class FieldMappingType(TypedDict):
+    column_name: str
+    field_name: str
+
+
+class QueryCsvService:
+    """Handle uploads and downloads of models using csvs."""
+
+    class Actions(Enum):
+        SKIP = "SKIP"
+        CF = "CUSTOM_FIELD"
+
+    def __init__(self, serializer_class: Type[CsvModelSerializer]):
+        self.serializer_class = serializer_class
+        self.serializer = serializer_class()
+        self.model_name = self.serializer.model_class.__name__
+
+        self.fields: OrderedDict = self.serializer.get_fields()
+        self.readonly_fields = self.serializer.readonly_fields
+        self.writable_fields = self.serializer.writable_fields
+        self.all_fields = self.serializer.readable_fields
+        self.required_fields = self.serializer.required_fields
+        self.unique_fields = self.serializer.unique_fields
+
+        self.flat_fields = self.serializer.get_flat_fields()
+
+        self.actions = [action.value for action in self.Actions]
+
+    @classmethod
+    def upload_from_job(cls, job: QueryCsvUploadJob):
+        """Upload csv using predefined job."""
+
+        assert job.serializer is not None, "Upload job must contain a serializer."
+
+        svc = cls(serializer_class=job.serializer_class)
+        return svc.upload_csv(job.file, custom_field_maps=job.custom_fields)
+
+    @classmethod
+    def queryset_to_csv(
+        cls, queryset: models.QuerySet, serializer_class: Type[ModelSerializerBase]
+    ):
+        """Print a queryset to a csv, return file path."""
+
+        service = cls(serializer_class=serializer_class)
+        return service.download_csv(queryset)
+
+    def download_csv(self, queryset: models.QuerySet) -> str:
+        """Download: Convert queryset to csv, return path to csv."""
+
+        data = self.serializer_class(queryset, many=True).data
+        flattened = [self.serializer_class.json_to_flat(obj) for obj in data]
+
+        df = pd.json_normalize(flattened)
+        filepath = get_media_path(
+            QUERYCSV_MEDIA_SUBDIR + "downloads/",
+            fileprefix=f"{self.model_name}",
+            fileext="csv",
+        )
+        df.to_csv(filepath, index=False)
+
+        return filepath
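Either entry point above writes the csv to media storage and returns its path; for example:

```python
# Usage sketch: exporting a queryset with the service defined above.
from core.mock.models import Buster
from core.mock.serializers import BusterCsvSerializer

svc = QueryCsvService(serializer_class=BusterCsvSerializer)
filepath = svc.download_csv(Buster.objects.all())

# Equivalent one-shot class method:
filepath = QueryCsvService.queryset_to_csv(Buster.objects.all(), BusterCsvSerializer)
```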
+
+    def get_csv_template(self, field_types: Literal["all", "required", "writable"]):
+        """
+        Get path to csv file containing required fields for upload.
+
+        Parameters
+        ----------
+        - field_types (str): Which fields to include: "all", "required", or "writable".
+        """
+
+        match field_types:
+            case "required":
+                template_fields = self.required_fields
+            case "writable":
+                template_fields = self.writable_fields
+            case "all" | _:
+                template_fields = self.all_fields
+
+        filepath = get_media_path(
+            QUERYCSV_MEDIA_SUBDIR + "templates/",
+            f"{self.model_name}_template.csv",
+            create_path=True,
+        )
+        df = pd.DataFrame([], columns=template_fields)
+        df.to_csv(filepath, index=False)
+
+        return filepath
+
+    def upload_csv(
+        self, path: str, custom_field_maps: Optional[list[FieldMappingType]] = None
+    ):
+        """
+        Upload: Given path to csv, create/update models and
+        return successful and failed objects.
+        """
+
+        # Start by importing csv
+        df = read_spreadsheet(path)
+
+        # Update df values with header associations
+        if custom_field_maps:
+            generic_list_keys = []  # Used for determining index when ambiguous
+
+            for mapping in custom_field_maps:
+                map_field_name = mapping["field_name"]
+
+                if map_field_name in self.actions:
+                    continue  # Action mappings (e.g. "SKIP") leave the column unmapped
+
+                if map_field_name not in self.flat_fields.keys():
+                    continue  # Safely skip invalid mappings
+
+                field = self.flat_fields[map_field_name]
+
+                if not field.is_list_item:
+                    # Default field logic
+                    df.rename(
+                        columns={mapping["column_name"]: map_field_name},
+                        inplace=True,
+                    )
+                    continue
+
+                #######################################################
+                # Handle list items.
+                #
+                # Mappings can come in as field[n].subfield, or field[0].subfield.
+                # If the mapping uses n for the index, then the n will be the "nth" occurrence
+                # of that field, starting at 0.
+                #
+                # At this point, all "field" (FlatListField) values are index=None,
+                # n-mappings will all be assigned indexes.
+                #######################################################
+
+                # Determine type
+                numbers = re.findall(r"\d+", mapping["column_name"])
+                assert (
+                    len(numbers) <= 1
+                ), "List items can only contain 0 or 1 numbers (multi digit allowed)."
+
+                if len(numbers) == 1:
+                    # Number was provided in spreadsheet
+                    index = int(numbers[0])
+                else:
+                    # Number was not provided in spreadsheet, get index of field
+                    index = len(
+                        [key for key in generic_list_keys if key == field.generic_key]
+                    )
+
+                field.set_index(index)
+                generic_list_keys.append(field.generic_key)
+
+                df.rename(columns={mapping["column_name"]: str(field)}, inplace=True)
+
+        # Normalize & clean fields before conversion to dict
+        for field_name, field_type in self.serializer.get_flat_fields().items():
+            if field_name not in list(df.columns):
+                continue
+
+            if field_type.is_list_item:
+                df[field_name] = df[field_name].map(
+                    lambda val: [
+                        item for item in str(val).split(",") if str(item) != ""
+                    ]
+                )
+            else:
+                df[field_name] = df[field_name].map(
+                    lambda val: val if val != "" else None
+                )
+
+        # Convert df to list of dicts, drop null fields
+        upload_data = df.to_dict("records")
+        filtered_data = [
+            {k: v for k, v in record.items() if v is not None} for record in upload_data
+        ]
+
+        # Finally, save data if valid
+        success = []
+        errors = []
+
+        # Note: string stripping is done in the serializer
+        serializers = [
+            self.serializer_class(data=data, flat=True) for data in filtered_data
+        ]
+
+        for serializer in serializers:
+            if serializer.is_valid():
+                serializer.save()
+                success.append(serializer.data)
+            else:
+                report = {**serializer.data, "errors": {**serializer.errors}}
+                errors.append(report)
+
+        return success, errors
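A typical job-driven upload ties the service to the upload-job model; the file path and column name here are assumed for illustration:

```python
# Usage sketch (assumed path/column names): map a spreadsheet column onto a
# model field, then run the upload through the service.
from core.mock.serializers import BusterCsvSerializer
from querycsv.models import QueryCsvUploadJob
from querycsv.services import QueryCsvService

job = QueryCsvUploadJob.objects.create(
    serializer_class=BusterCsvSerializer,
    filepath="/tmp/busters.csv",  # hypothetical spreadsheet
)
job.add_field_mapping(column_name="Tag Names", field_name="many_tags")
success, errors = QueryCsvService.upload_from_job(job)
```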
+ """ + # Process job + job = QueryCsvUploadJob.objects.find_by_id(job_id) + success, failed = QueryCsvService.upload_from_job(job) + job.status = CsvUploadStatus.SUCCESS + + # Create report + report_file_path = get_media_path( + QUERYCSV_MEDIA_SUBDIR + f"reports/{job.model_class.__name__}/", + fileprefix=str(timezone.now().strftime("%d-%m-%Y_%H:%M:%S")), + fileext="xlsx", + ) + + success_report = pd.json_normalize(success) + failed_report = pd.json_normalize(failed) + + with pd.ExcelWriter(report_file_path) as writer: + success_report.to_excel(writer, sheet_name="Successful", index=False) + failed_report.to_excel(writer, sheet_name="Failed", index=False) + + save_file_to_model(job, report_file_path, field="report") + job.refresh_from_db() + + # Send admin email + if job.notify_email: + model_name = job.model_class._meta.verbose_name_plural + mail = EmailMessage( + subject=f"Upload {model_name} report", + to=[job.notify_email], + body=mark_safe( + f"Your {model_name} csv has finished processing.

    " + f"Objects processed successfully: {len(success)}
    " + f"Objects unsuccessfully processed: {len(failed)}" + ), + ) + mail.attach_file(report_file_path) + mail.send() diff --git a/app/querycsv/templates/admin/querycsv/upload_csv.html b/app/querycsv/templates/admin/querycsv/upload_csv.html new file mode 100644 index 0000000..b37d32c --- /dev/null +++ b/app/querycsv/templates/admin/querycsv/upload_csv.html @@ -0,0 +1,84 @@ +{% extends 'admin/base_site.html' %} {% block content %} +
diff --git a/app/querycsv/templates/admin/querycsv/upload_csv.html b/app/querycsv/templates/admin/querycsv/upload_csv.html
new file mode 100644
index 0000000..b37d32c
--- /dev/null
+++ b/app/querycsv/templates/admin/querycsv/upload_csv.html
@@ -0,0 +1,84 @@
+{% extends 'admin/base_site.html' %} {% block content %}
+<div class="container">
+  <h1>Upload CSV</h1>
+  <div class="row">
+    <div class="col">
+      <h2>CSV Templates</h2>
+      <ul>
+        <li>
+          Required fields only:
+          <a href="{% url template_url %}?fields=required">Download</a>
+        </li>
+        <li>
+          All writable fields:
+          <a href="{% url template_url %}?fields=writable">Download</a>
+        </li>
+      </ul>
+      <form method="post" enctype="multipart/form-data">
+        {% csrf_token %} {% for field in form %}
+        <div class="form-group">
+          {{ field.label_tag }} {{ field }}
+          <small>{{ field.help_text }}</small>
+          <div>{{ field.errors }}</div>
+        </div>
+        {% endfor %}
+        <button type="submit" class="btn btn-primary">Upload</button>
+      </form>
+    </div>
+    <div class="col">
+      <h2>Field Info</h2>
+      {% if unique_together_fields %}
+      <div>
+        Unique together:
+        <ul>
+          {% for fields in unique_together_fields %}
+          <li>{{ fields|join:', ' }}</li>
+          {% endfor %}
+        </ul>
+      </div>
+      {% endif %}
+      <table>
+        <thead>
+          <tr>
+            <th>Field</th>
+            <th>Info</th>
+            <th>Can View</th>
+            <th>Can Edit</th>
+            <th>Required</th>
+            <th>Unique</th>
+          </tr>
+        </thead>
+        <tbody>
+          {% for field in all_fields %}
+          <tr>
+            <td>{{ field }}</td>
+            <td>{% if field.help_text %}{{ field.help_text }}{% endif %}</td>
+            <td>✓</td>
+            <td>{% if field.is_writable %}✓{% endif %}</td>
+            <td>{% if field.is_required %}✓{% endif %}</td>
+            <td>{% if field.is_unique %}✓{% endif %}</td>
+          </tr>
+          {% endfor %}
+        </tbody>
+      </table>
+    </div>
+  </div>
+</div>
+{% endblock %}
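For reference, the view that renders the next template lives in `querycsv/views.py`, which is not part of this hunk; presumably it binds the formset with one row per spreadsheet column, along these lines:

```python
# Assumed sketch of how the header-mapping view builds the formset rendered below.
from querycsv.forms import CsvHeaderMappingFormSet

formset = CsvHeaderMappingFormSet(
    initial=[{"csv_header": header} for header in job.csv_headers],
    upload_job=job,  # supplies available_fields from the job's serializer
)
```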
diff --git a/app/querycsv/templates/admin/querycsv/upload_csv_headermapping.html b/app/querycsv/templates/admin/querycsv/upload_csv_headermapping.html
new file mode 100644
index 0000000..c57bbac
--- /dev/null
+++ b/app/querycsv/templates/admin/querycsv/upload_csv_headermapping.html
@@ -0,0 +1,80 @@
+{% extends 'admin/base_site.html' %} {% block content %}
+<div class="container">
+  <h1>Confirm Upload</h1>
+
+  <ul>
+    <li>Upload Job Id: {{ upload_job.id }}</li>
+    <li>Object Type: {{ model_class_name }}</li>
+    <li>
+      Rows Found:
+      {{ upload_job.spreadsheet.index | length }}
+    </li>
+    <li>Send Updates To: {{ upload_job.notify_email }}</li>
+  </ul>
+
+  <h2>Map CSV Headers</h2>
+
+  <p>
+    The table below shows column headers found in the uploaded CSV.
+  </p>
+
+  <h3>Instructions</h3>
+  <p>
+    The right side of the table gives all available options for fields to map to
+    new/updated {{ model_class_name }} objects.
+  </p>
+  <ul>
+    <li>
+      Select "SKIP" to skip an entire column.
+    </li>
+    <li>
+      All empty fields will automatically be skipped.
+    </li>
+    <li>
+      If there is an extra column that does not match a field on the {{ model_class_name }} object,
+      that column will automatically be skipped.
+    </li>
+  </ul>
+
+  <form method="post">
+    {% csrf_token %}{{ formset.management_form }}
+    <table>
+      <thead>
+        <tr>
+          <th>CSV Header</th>
+          <th>Object Field</th>
+        </tr>
+      </thead>
+      <tbody>
+        {% for form in formset %}
+        <tr>
+          <td>{{ form.csv_header }}</td>
+          <td>{{ form.object_field }}</td>
+        </tr>
+        {% endfor %}
+      </tbody>
+    </table>
+
+    <button type="submit" class="btn btn-primary">Confirm Upload</button>
+  </form>
+</div>
+{% endblock %}
diff --git a/app/querycsv/templates/admin/querycsv/upload_not_available.html b/app/querycsv/templates/admin/querycsv/upload_not_available.html
new file mode 100644
index 0000000..d334c44
--- /dev/null
+++ b/app/querycsv/templates/admin/querycsv/upload_not_available.html
@@ -0,0 +1,3 @@
+{% extends 'admin/base_site.html' %} {% block content %}

+<p>Cannot upload CSVs for this object.</p>
    +{% endblock %} \ No newline at end of file diff --git a/app/querycsv/tests/__init__.py b/app/querycsv/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/querycsv/tests/test_download_data.py b/app/querycsv/tests/test_download_data.py new file mode 100644 index 0000000..7eef286 --- /dev/null +++ b/app/querycsv/tests/test_download_data.py @@ -0,0 +1,120 @@ +""" +CSV Download Tests +""" + +from querycsv.tests.utils import ( + CsvDataM2MTestsBase, + CsvDataM2OTestsBase, + DownloadCsvTestsBase, +) +from utils.helpers import clean_list + + +class DownloadDataTests(DownloadCsvTestsBase): + """Unit tests for download csv data.""" + + def test_download_model_csv(self): + """Should download a csv listing objects for model.""" + + # Create csv using service + self.initialize_dataset() + qs = self.repo.all() + + filepath = self.service.download_csv(queryset=qs) + self.assertValidCsv(filepath) + + # Check csv + df = self.csv_to_df(filepath) + self.assertEqual(len(df.index), self.dataset_size) + + expected_fields = self.serializer.readable_fields + expected_fields.sort() + + actual_fields = list(df.columns) + actual_fields.sort() + + self.assertListEqual(expected_fields, actual_fields) + + # Verify contact fields in csv + self.assertCsvHasFields(df) + + +class DownloadCsvM2OFieldsTests(DownloadCsvTestsBase, CsvDataM2OTestsBase): + """Unit tests for testing downloaded csv many-to-one fields.""" + + def test_download_csv_m2o_fields(self): + """Should be able to download models with many-to-one fields.""" + + # Create csv using service + self.initialize_dataset() + qs = self.repo.all() + + filepath = self.service.download_csv(queryset=qs) + self.assertValidCsv(filepath) + + # Check csv + df = self.csv_to_df(filepath) + self.assertCsvHasFields(df) + + # For each row, check the many-to-one field + for index, row in df.iterrows(): + obj_id = row["id"] + expected_obj = self.repo.get_by_id(obj_id) + + expected_m2o_obj = getattr(expected_obj, self.m2o_selector) + + if expected_m2o_obj is None: + expected_value = None + else: + expected_value = getattr(expected_m2o_obj, self.m2o_target_field) + + actual_value = row[self.m2o_selector] + if actual_value == "": + actual_value = None + + self.assertEqual(actual_value, expected_value) + + +class DownloadCsvM2MFieldsStrTests(DownloadCsvTestsBase, CsvDataM2MTestsBase): + """Unit tests for testing downloaded csv many-to-many fields with str slug.""" + + def test_download_csv_m2m_fields(self): + """Should be able to download models with many-to-many fields.""" + + # Create csv using service + self.initialize_dataset() + qs = self.repo.all() + + filepath = self.service.download_csv(queryset=qs) + self.assertValidCsv(filepath) + + # Check csv + df = self.csv_to_df(filepath) + self.assertCsvHasFields(df) + + # For each row, check the many-to-one field + for index, row in df.iterrows(): + obj_id = row["id"] + expected_obj = self.repo.get_by_id(obj_id) + + expected_m2m_objs = getattr(expected_obj, self.m2m_model_selector) + expected_values = clean_list( + [ + str(getattr(obj, self.m2m_target_field)) + for obj in expected_m2m_objs.all() + ] + ) + + actual_value_raw = str(row[self.m2m_selector]) + actual_values = clean_list( + [str(v).strip() for v in actual_value_raw.split(",")] + ) + + self.assertListEqual(actual_values, expected_values) + + +class DownloadCsvM2MFieldsIntTests(DownloadCsvM2MFieldsStrTests): + """Unit tests for testing downloaded csv many-to-many fields with int slug.""" + + m2m_selector = "many_tags_int" + m2m_target_field = 
"id" diff --git a/app/querycsv/tests/test_upload_data.py b/app/querycsv/tests/test_upload_data.py new file mode 100644 index 0000000..954739c --- /dev/null +++ b/app/querycsv/tests/test_upload_data.py @@ -0,0 +1,277 @@ +""" +Import/upload data tests. +""" + +from django.contrib.postgres.aggregates import StringAgg +from django.db import models + +from querycsv.models import QueryCsvUploadJob +from querycsv.services import QueryCsvService +from querycsv.tests.utils import ( + CsvDataM2MTestsBase, + CsvDataM2OTestsBase, + UploadCsvTestsBase, +) + + +class UploadDataTests(UploadCsvTestsBase): + """Test uploading data from a csv.""" + + def test_create_objects_from_csv(self): + """Should be able to take csv and create models.""" + + # Initialize data + objects_before = self.initialize_csv_data() + + # Call service upload function + _, failed = self.service.upload_csv(path=self.filepath) + + # Validate database + self.assertObjectsExist(objects_before, failed) + self.assertObjectsHaveFields(objects_before) + + def test_update_objects_from_csv(self): + # Initialize data + objects_before = self.initialize_csv_data(clear_db=False) + + for obj in self.repo.all(): + self.update_mock_object(obj) + + # Call service upload function + self.service.upload_csv(path=self.filepath) + + # Validate database + self.assertObjectsExist(objects_before) + self.assertObjectsHaveFields(objects_before) + + def test_upload_csv_bad_fields(self): + """Should create objects and ignore bad fields.""" + + # Initialize csv, add invalid column + objects_before = self.initialize_csv_data() + self.df["Invalid field"] = "bad value" + self.df_to_csv(self.df) + + self.assertTrue("Invalid field" in list(self.df.columns)) + + self.service.upload_csv(path=self.filepath) + + # Validate database + self.assertObjectsExist(objects_before) + self.assertObjectsHaveFields(objects_before) + + def test_upload_csv_update_objects(self): + """Uploading a csv should update objects.""" + + # Prep data, create csv + objects_before = self.initialize_csv_data(clear_db=False) + + updated_records = [] + + for obj in objects_before: + payload = {self.unique_field: obj[self.unique_field]} + payload = self.get_update_params(obj, **payload) + updated_records.append(payload) + + self.data_to_csv(updated_records) + + # Upload CSV + self.service.upload_csv(path=self.filepath) + + # Validate data + self.assertObjectsHaveFields(updated_records) + + def test_upload_csv_spaces(self): + """Should remove pre/post spaces from fields before updating/creating.""" + + # Prep data, create csv + objects_before = self.initialize_csv_data(clear_db=False) + + updated_records = [] + + for obj in objects_before: + payload = {self.unique_field: f" {obj[self.unique_field]} "} + payload = self.get_update_params(obj, **payload) + updated_records.append(payload) + + self.data_to_csv(updated_records) + self.assertObjectsExist(objects_before) + + # Upload CSV + self.service.upload_csv(path=self.filepath) + + # Validate data + self.assertObjectsHaveFields(updated_records) + + +class UploadCsvJobTests(UploadCsvTestsBase): + """Tests for uploading with QSCsv Model.""" + + def test_upload_from_job(self): + """Should upload and process csv from model.""" + + # Initialize data + objects_before = self.initialize_csv_data(clear_db=False) + + # Update fields after create csv + for obj in self.repo.all(): + self.update_mock_object(obj) + + # Upload csv via service + job = QueryCsvUploadJob.objects.create( + filepath=self.filepath, + serializer_class=self.serializer_class, + ) + 
+
+        # Validate database
+        self.assertObjectsExist(objects_before)
+        self.assertObjectsHaveFields(objects_before)
+
+    def test_upload_custom_fields(self):
+        """Should process csv with custom field mappings."""
+
+        objects_before = self.initialize_csv_data()
+
+        # Rename csv field
+        self.df.rename(columns={"name": "Test Value"}, inplace=True)
+        self.df_to_csv(self.df, self.filepath)
+
+        # Create and upload job
+        job = QueryCsvUploadJob.objects.create(
+            serializer_class=self.serializer_class, filepath=self.filepath
+        )
+        job.add_field_mapping(column_name="Test Value", field_name="name")
+        job.refresh_from_db()
+
+        QueryCsvService.upload_from_job(job)
+
+        # Validate database
+        self.assertObjectsExist(pre_queryset=objects_before)
+        self.assertObjectsHaveFields(expected_objects=objects_before)
+
+
+class UploadCsvM2OFieldsTests(UploadCsvTestsBase, CsvDataM2OTestsBase):
+    """Test uploading csvs for models with many-to-one fields."""
+
+    def test_upload_csv_m2o_fields(self):
+        """
+        Should present Many-to-One (FK) fields according to the serializer.
+
+        Check by comparing the serialized representation before and after
+        the upload - both should have the same value for writable fields.
+        """
+
+        # Initialize data
+        objects_before = self.initialize_csv_data()
+
+        # Call upload function
+        self.service.upload_csv(path=self.filepath)
+
+        # Validate database
+        self.assertObjectsHaveFields(objects_before)
+        self.assertIn(self.m2o_selector, list(self.df.columns))
+
+        self.assertObjectsM2OValidFields(self.df)
+
+    def test_upload_csv_m2o_fields_update(self):
+        """Should update models with Many-to-One fields."""
+
+        # Initialize data
+        objects_before = self.initialize_csv_data(clear_db=False)
+
+        # Update fields after creating the csv
+        for obj in self.repo.all():
+            self.update_mock_object(obj)
+
+        # Call upload function
+        self.service.upload_csv(path=self.filepath)
+
+        # Validate database
+        self.assertObjectsHaveFields(objects_before)
+        self.assertIn(self.m2o_selector, list(self.df.columns))
+
+        self.assertObjectsM2OValidFields(self.df)
+
+
+class UploadCsvM2MFieldsTests(UploadCsvTestsBase, CsvDataM2MTestsBase):
+    """Test uploading csvs for models with many-to-many fields."""
+
+    def test_upload_csv_m2m_fields(self):
+        """When csv is uploaded, m2m fields should be handled properly."""
+
+        # Initialize data
+        objects_before = self.initialize_csv_data()
+
+        # Upload csv using service
+        self.service.upload_csv(path=self.filepath)
+
+        # Validate results
+        self.assertObjectsHaveFields(objects_before)
+        self.assertIn(self.m2m_selector, list(self.df.columns))
+
+        self.assertObjectsM2MValidFields(self.df)
+
+    def test_upload_csv_m2m_fields_spaces(self):
+        """When csv is uploaded, m2m fields should be stripped of leading/trailing spaces."""
+
+        objects_before = self.initialize_csv_data()
+
+        # Iterate through csv, manually add spacing.
+        # Write via df.at: assigning to the row returned by iterrows()
+        # only mutates a copy, not the dataframe itself.
+        for i, row in self.df.iterrows():
+            pre_value = row[self.m2m_selector]
+            pre_values = pre_value.split(",")
+            self.df.at[i, self.m2m_selector] = " , ".join(pre_values)
+
+        self.df_to_csv(self.df)
+
+        # Upload csv using service
+        self.service.upload_csv(path=self.filepath)
+
+        # Validate results
+        self.assertObjectsHaveFields(objects_before)
+        self.assertIn(self.m2m_selector, list(self.df.columns))
+
+        self.assertObjectsM2MValidFields(self.df)
+
+    def test_upload_csv_m2m_update_fields(self):
+        """When csv is uploaded, should update objects with many-to-many fields."""
+
+        # Initialize data
+        self.initialize_csv_data(clear_db=False)
+
+        # Update fields after creating the csv
+        self.update_dataset()
+
+        objects_before = list(
+            self.repo.all()
+            .annotate(
+                pre_objs_count=models.Count(self.m2m_selector),
+                pre_objs=StringAgg(
+                    models.F(f"{self.m2m_selector}__{self.m2m_target_field}"),
+                    distinct=True,
+                    delimiter=",",
+                ),
+            )
+            .values()
+        )
+
+        # Upload csv using service
+        success, failed = self.service.upload_csv(path=self.filepath)
+
+        # Validate results
+        self.assertEqual(self.repo.all().count(), self.dataset_size)
+        expected_objects = list(self.df.to_dict("records"))
+
+        self.assertObjectsHaveFields(expected_objects)
+        self.assertIn(self.m2m_selector, list(self.df.columns))
+        self.assertTrue(
+            self.m2m_repo.all().count() <= self.m2m_size + self.m2m_update_size,
+            f"Expected at most {self.m2m_size + self.m2m_update_size} M2M objects, "
+            f"but {self.m2m_repo.all().count()} were created.",
+        )
+
+        self.assertObjectsM2MValidFields(self.df, objects_before)
diff --git a/app/querycsv/tests/test_upload_views.py b/app/querycsv/tests/test_upload_views.py
new file mode 100644
index 0000000..cc44213
--- /dev/null
+++ b/app/querycsv/tests/test_upload_views.py
@@ -0,0 +1,108 @@
+import pandas as pd
+from django.template.response import TemplateResponse
+from django.test import RequestFactory
+from rest_framework import status
+
+from core.mock.models import Buster
+from core.mock.serializers import BusterCsvSerializer
+from querycsv.forms import CsvHeaderMappingFormSet, CsvUploadForm
+from querycsv.models import QueryCsvUploadJob
+from querycsv.tests.test_upload_data import UploadCsvTestsBase
+from querycsv.views import QueryCsvViewSet
+
+
+class UploadCsvViewsTests(UploadCsvTestsBase):
+    """Test functionality for upload views used in admin."""
+
+    model_class = Buster
+    serializer_class = BusterCsvSerializer
+
+    def get_reverse(self, name="fallback"):
+        return "core:index"
+
+    def setUp(self):
+        self.views = QueryCsvViewSet(
+            self.serializer_class, get_reverse=self.get_reverse
+        )
+        self.req_factory = RequestFactory()
+
+        return super().setUp()
+
+    ####################
+    # == Unit Tests == #
+    ####################
+
+    def test_upload_csv(self):
+        """Should show form for uploading csv."""
+
+        req = self.req_factory.get("/")
+        res: TemplateResponse = self.views.upload_csv(request=req)
+
+        self.assertIsInstance(res, TemplateResponse)
+        self.assertEqual(res.status_code, status.HTTP_200_OK)
+
+        # Check context
+        self.assertIsInstance(res.context_data["form"], CsvUploadForm)
+        self.assertEqual(
+            res.context_data["template_url"], self.get_reverse("csv_template")
+        )
+        # FIXME: Checking csv fields in context fails
+        # self.assertEqual(
+        #     res.context_data["all_fields"], self.service.flat_fields.values()
+        # )
+        self.assertEqual(
+            res.context_data["unique_together_fields"],
+            self.serializer.unique_together_fields,
+        )
+
+    def test_map_upload_csv_headers_get(self):
+        """Should show form for header associations."""
+
+        self.initialize_csv_data()
+        job = QueryCsvUploadJob.objects.create(
+            serializer_class=self.serializer_class, filepath=self.filepath
+        )
+
+        req = self.req_factory.get("/")
+        res: TemplateResponse = self.views.map_upload_csv_headers(
+            request=req, id=job.id
+        )
+        self.assertIsInstance(res, TemplateResponse)
+
+        # Check context
+        context = res.context_data
+        self.assertEqual(context["upload_job"], job)
+        self.assertEqual(context["model_class_name"], job.model_class.__name__)
+        self.assertIsInstance(context["formset"], CsvHeaderMappingFormSet)
self.assertIsInstance(context["formset"], CsvHeaderMappingFormSet) + + def test_map_upload_csv_headers_post(self): + """Should add custom header associations for upload job.""" + + # Initialize data + self.initialize_csv_data() + df = pd.read_csv(self.filepath) + df.rename(columns={"name": "Test Name"}, inplace=True) + self.df_to_csv(df, self.filepath) + + job = QueryCsvUploadJob.objects.create( + serializer_class=self.serializer_class, filepath=self.filepath + ) + data = { + "form-TOTAL_FORMS": "1", + "form-INITIAL_FORMS": "0", + "form-0-csv_header": "Test Name", + "form-0-object_field": "name", + } + + # Send request + req = self.req_factory.post("/", data=data) + res = self.views.map_upload_csv_headers(request=req, id=job.id) + + self.assertEqual(res.status_code, status.HTTP_302_FOUND) + + # Check mappings + job.refresh_from_db() + self.assertEqual(len(job.custom_fields), 1) + + self.assertEqual(job.custom_fields[0]["column_name"], "Test Name") + self.assertEqual(job.custom_fields[0]["field_name"], "name") diff --git a/app/querycsv/tests/utils.py b/app/querycsv/tests/utils.py new file mode 100644 index 0000000..4c08fe0 --- /dev/null +++ b/app/querycsv/tests/utils.py @@ -0,0 +1,579 @@ +""" +CSV Data Tests Utilities +""" + +import random +import uuid +from typing import Optional + +import numpy as np +import pandas as pd +from django.db import models + +from app.settings import MEDIA_ROOT +from core.abstracts.tests import TestsBase +from core.mock.models import Buster, BusterTag +from core.mock.serializers import BusterCsvSerializer +from lib.faker import fake +from querycsv.services import QueryCsvService +from utils.files import get_media_path +from utils.helpers import clean_list + + +class CsvDataTestsBase(TestsBase): + """ + Base tests for Csv data services. 
+
+    Overrides
+    ---------
+    Required:
+    - model_class
+    - serializer_class
+    - def get_create_params
+    - def get_update_params
+
+    Optional:
+    - dataset_size
+
+    Terms
+    -----
+    - repo: alias for Model.objects
+    - objects: all instances of Model in database
+    """
+
+    model_class = Buster
+    serializer_class = BusterCsvSerializer
+    dataset_size = 5
+    update_size = 3
+
+    unique_field = "unique_name"
+    """The field to test updates against."""
+
+    def setUp(self) -> None:
+        self.repo = self.model_class.objects
+        self.serializer = self.serializer_class()
+        self.service = QueryCsvService(serializer_class=self.serializer_class)
+
+        return super().setUp()
+
+    # Overrides
+    #####################
+    def get_create_params(self, **kwargs):
+        return {"name": fake.title(), **kwargs}
+
+    def get_update_params(self, obj: model_class, **kwargs):
+        return {"name": fake.title(), **kwargs}
+
+    # Initialization
+    #####################
+    def initialize_dataset(self):
+        """Create mock objects, and any other setup tasks."""
+        return self.create_mock_objects()
+
+    def update_dataset(self):
+        objects = list(self.repo.all())
+
+        for i in range(self.update_size):
+            obj = random.choice(objects)
+            objects.remove(obj)
+
+            self.update_mock_object(obj=obj)
+
+    # Utilities
+    #####################
+    def create_mock_object(self, **kwargs):
+        return self.repo.create(**self.get_create_params(**kwargs))
+
+    def create_mock_objects(self, amount: Optional[int] = None):
+        """Create a set number of models."""
+
+        if not amount:
+            amount = self.dataset_size
+
+        for _ in range(amount):
+            self.create_mock_object()
+
+    def update_mock_object(self, obj: model_class, **kwargs):
+        """Update the object to differ from csv."""
+
+        for key, value in self.get_update_params(obj=obj, **kwargs).items():
+            setattr(obj, key, value)
+
+        obj.save()
+
+        return obj
+
+    def get_unique_filepath(self):
+        return get_media_path(
+            nested_path="tests/csv-data/uploads/",
+            filename=f"{uuid.uuid4()}.csv",
+            create_path=True,
+        )
+
+    def df_to_csv(self, df: pd.DataFrame, filepath: Optional[str] = None):
+        """Dump a dataframe to a csv, return filepath."""
+
+        if filepath is None:
+            filepath = self.filepath
+
+        df.to_csv(filepath, index=False, mode="w")
+
+        return filepath
+
+    def data_to_df(self, data: list[dict]):
+        """Convert output of serializer to dataframe."""
+
+        for model in data:
+            for key, value in model.items():
+                if isinstance(value, list):
+                    model[key] = ",".join([str(v) for v in value])
+
+        return pd.DataFrame.from_records(data)
+
+    def data_to_csv(self, data: list[dict]):
+        """Convert list of dicts to a csv, return filepath."""
+
+        df = self.data_to_df(data)
+        return self.df_to_csv(df)
+
+    def csv_to_df(self, path: str):
+        """Read the csv (or excel file) at path into a dataframe."""
+
+        if path.endswith(".xlsx") or path.endswith(".xls"):
+            df = pd.read_excel(path, dtype=str)
+        else:
+            df = pd.read_csv(path, dtype=str)
+
+        df.replace(np.nan, "", inplace=True)
+
+        return df
+
+    # Custom assertions
+    #####################
+    def assertObjectsCount(self, count: int, msg=None):
+        """Objects count in db should match given count."""
+        self.assertEqual(self.repo.count(), count, msg=msg)
+
+    def assertNoObjects(self):
+        """Database should be empty."""
+
+        self.assertObjectsCount(0)
+
+
+class CsvDataM2OTestsBase(CsvDataTestsBase):
+    """
+    Test csv data with many-to-one fields.
+
+    Overrides
+    ---------
+    Required:
+    - model_class
+    - serializer_class
+    - m2o_model_class
+    - m2o_selector
+    - m2o_target_field
+    - def get_create_params
+    - def get_update_params
+    - def get_m2o_create_params
+
+    Optional:
+    - dataset_size
+    - m2o_size
+    - def create_mock_objects
+    """
+
+    model_class = Buster
+    serializer_class = BusterCsvSerializer
+    m2o_model_class = BusterTag
+    m2o_size = 2
+
+    m2o_selector = "one_tag"
+    """Field on the main object that points to child object."""
+
+    m2o_target_field = "name"
+    """Field on child object whose value is used in serializer."""
+
+    def setUp(self) -> None:
+        super().setUp()
+
+        self.m2o_repo = self.m2o_model_class.objects
+
+    def get_m2o_create_params(self, **kwargs):
+        return {"name": fake.title(), **kwargs}
+
+    def create_mock_m2o_object(self, **kwargs):
+        """Create Many to One object for testing."""
+
+        return self.m2o_repo.create(**self.get_m2o_create_params(**kwargs))
+
+    def initialize_dataset(self):
+        super().initialize_dataset()
+
+        m2o_objects = []
+        for i in range(self.m2o_size):
+            m2o_objects.append(self.create_mock_m2o_object())
+
+        for obj in self.repo.all():
+            setattr(obj, self.m2o_selector, random.choice(m2o_objects))
+            # Persist the assignment; without save() the relation is lost
+            obj.save()
+
+    def update_dataset(self):
+        objects = list(self.repo.all())
+        m2os = list(self.m2o_repo.all())
+
+        for _ in range(self.update_size):
+            obj = random.choice(objects)
+            objects.remove(obj)
+
+            m2o = random.choice(m2os)
+            self.update_mock_object(obj=obj, **{self.m2o_selector: m2o})
+
+    def clear_db(self) -> None:
+        self.m2o_repo.all().delete()
+
+        return super().clear_db()
+
+    def assertObjectsM2OValidFields(self, df: pd.DataFrame):
+        """Compare actual objects in the database with expected values in csv."""
+
+        # Compare csv value with actual value
+        for index, row in df.iterrows():
+            # Raw values in csv
+            expected_value = row[self.m2o_selector]
+
+            if expected_value is None:
+                continue
+
+            self.assertIsInstance(expected_value, str)
+            query = row.to_dict()
+            obj = self.repo.get(
+                **{
+                    k: v
+                    for k, v in query.items()
+                    if k != self.m2o_selector
+                    and k not in self.serializer.readonly_fields
+                    and k not in self.serializer.any_related_fields
+                }
+            )
+
+            m2o_obj = getattr(obj, self.m2o_selector)
+            actual_value = getattr(m2o_obj, self.m2o_target_field)
+
+            self.assertEqual(expected_value, actual_value)
+
+
+class CsvDataM2MTestsBase(CsvDataTestsBase):
+    """
+    Base utilities for testing many-to-many fields.
+
+    Overrides
+    ---------
+    Required:
+    - model_class
+    - serializer_class
+    - m2m_model_class
+    - m2m_selector
+    - m2m_target_field
+    - def get_create_params
+    - def get_update_params
+    - def get_m2m_create_params
+
+    Optional:
+    - dataset_size
+    - m2m_size
+    - m2m_update_size
+    """
+
+    model_class = Buster
+    serializer_class = BusterCsvSerializer
+
+    m2m_model_class = BusterTag
+    m2m_size = 10
+    m2m_update_size = 4
+    m2m_assignment_max = 3
+
+    m2m_selector = "many_tags"
+    """Field on the main object that points to child object."""
+
+    m2m_target_field = "name"
+    """Field on child object whose value is used in serializer."""
+
+    def setUp(self) -> None:
+        super().setUp()
+
+        self.m2m_repo = self.m2m_model_class.objects
+
+        if self.m2m_selector not in self.model_class.get_fields_list():
+            self.m2m_model_selector = self.serializer.get_fields()[
+                self.m2m_selector
+            ].source
+        else:
+            self.m2m_model_selector = self.m2m_selector
+
+    def get_m2m_create_params(self, **kwargs):
+        return {"name": fake.title(), **kwargs}
+
+    def initialize_dataset(self):
+        super().initialize_dataset()
+
+        m2m_objects = []
+        for i in range(self.m2m_size):
+            m2m_objects.append(self.create_mock_m2m_object())
+
+        for obj in self.repo.all():
+            m2m_repo = getattr(obj, self.m2m_model_selector)
+
+            assignment_count = random.randint(0, self.m2m_assignment_max)
+            selected_m2m_objects = random.sample(m2m_objects, assignment_count)
+
+            for m2m_obj in selected_m2m_objects:
+                m2m_repo.add(m2m_obj)
+
+            obj.save()
+
+    def create_mock_m2m_object(self, **kwargs):
+        return self.m2m_model_class.objects.create(
+            **self.get_m2m_create_params(**kwargs)
+        )
+
+    def assertObjectsM2MValidFields(
+        self, df: pd.DataFrame, objects_before: list[dict] = None
+    ):
+        """Compare expected objects in the csv with actual objects from database."""
+
+        # Compare csv value with actual value
+        for index, row in df.iterrows():
+            # Raw value in csv
+            expected_value = row[self.m2m_selector]
+
+            if expected_value is None:
+                continue
+
+            csv_values = row.to_dict()
+            query = None
+
+            for key, value in csv_values.items():
+                # Skip fields that represent the m2m relation, are none, or are for the serializer only
+                if (
+                    key == self.m2m_selector
+                    or key in self.serializer.readonly_fields
+                    or value is None
+                    or key not in self.model_class.get_fields_list()
+                ):
+                    continue
+
+                query_filter = models.Q(**{key: value})
+                query = query & query_filter if query is not None else query_filter
+
+            actual_obj = self.repo.get(query)
+            actual_related_objs = getattr(actual_obj, self.m2m_selector).all()
+
+            # Check database against csv
+            expected_values = [str(v).strip() for v in str(expected_value).split(",")]
+            expected_values = clean_list(expected_values)
+
+            actual_values = [
+                getattr(obj, self.m2m_target_field) for obj in actual_related_objs
+            ]
+            actual_values = clean_list(actual_values)
+
+            self.assertListEqual(expected_values, actual_values)
+
+
+class DownloadCsvTestsBase(CsvDataTestsBase):
+    """
+    Base utilities for download csv tests.
+
+    Overrides
+    ---------
+    Required:
+    - model_class
+    - serializer_class
+    - def get_create_params
+    - def get_update_params
+
+    Optional:
+    - dataset_size
+
+    Terms
+    -----
+    - repo: alias for Model.objects
+    - objects: all instances of Model in database
+    """
+
+    def assertValidCsv(self, filepath: str):
+        """File at filepath should be a valid csv."""
+
+        self.assertFileExists(filepath)
+        self.assertStartsWith(filepath, MEDIA_ROOT)
+        self.assertEndsWith(filepath, ".csv")
+
+    def assertCsvHasFields(self, df: pd.DataFrame):
+        """Iterate over csv data and verify with DB."""
+
+        records = df.to_dict("records")
+
+        for record in records:
+            id = record.get("id")
+            actual_object = self.repo.get_by_id(id)
+
+            actual_serializer = self.serializer_class(actual_object)
+
+            for field, expected_value in actual_serializer.data.items():
+                self.assertIn(field, record.keys())
+
+                actual_value = record[field]
+
+                if field in self.serializer.many_related_fields:
+                    actual_values = [
+                        val.strip() for val in str(actual_value).split(",")
+                    ]
+                    actual_values.sort()
+
+                    expected_values = [str(val) for val in expected_value]
+                    expected_values.sort()
+
+                    self.assertListEqual(
+                        clean_list(actual_values), clean_list(expected_values)
+                    )
+                else:
+                    self.assertEqual(str(actual_value or ""), str(expected_value or ""))
+
+
+class UploadCsvTestsBase(CsvDataTestsBase):
+    """
+    Base utilities for upload data service.
+
+    Overrides
+    ---------
+    Required:
+    - model_class
+    - serializer_class
+    - def get_create_params
+    - def get_update_params
+
+    Optional:
+    - dataset_size
+    - def create_mock_object  # Calls get_create_params by default
+    - def create_mock_objects
+    - def update_mock_object  # Calls get_update_params by default
+    """
+
+    def setUp(self) -> None:
+        super().setUp()
+
+        self.filepath = self.get_unique_filepath()
+
+    def create_objects(self):
+        # Create test models
+        self.assertNoObjects()
+        self.create_mock_objects()
+
+        objects = self.repo.all()
+        self.assertEqual(objects.count(), self.dataset_size)
+
+        return objects
+
+    def dump_csv(self, query: models.QuerySet):
+        """Manually dump query to csv, independent of services."""
+
+        data = self.serializer_class(query, many=True).data
+        self.df = self.data_to_df(data)
+        self.df_to_csv(self.df)
+
+    def initialize_csv_data(self, clear_db=True):
+        """Create csv with data, then optionally clear the database."""
+
+        # Initialize data. Evaluate values() eagerly, otherwise the lazy
+        # queryset would come back empty once the database is cleared.
+        self.initialize_dataset()
+        objects = self.repo.all()
+        objects_before = list(objects.values())
+        self.dump_csv(objects)
+
+        # Clear database
+        if clear_db:
+            self.clear_db()
+
+        return objects_before
+
+    def clear_db(self) -> None:
+        """Clear the database and assert it is empty."""
+
+        self.repo.all().delete()
+        self.assertNoObjects()
+
+    def assertObjectsExist(self, pre_queryset: list, msg=None):
+        """Objects represented in queryset should exist in the database."""
+        self.assertObjectsCount(self.dataset_size, msg=msg)
+
+        for expected_obj in pre_queryset:
+            query = self.repo.filter(**expected_obj)
+            self.assertTrue(query.exists(), msg=msg)
+
+    def assertObjectsHaveFields(self, expected_objects: list[dict]):
+        """
+        Check if the actual object has expected fields.
+
+        Verify by comparing the serialized representation for before and after
+        the upload - both should have the same value for writable fields.
+        """
+
+        for expected_obj in expected_objects:
+            expected_serializer = self.serializer_class(data=expected_obj)
+            self.assertValidSerializer(expected_serializer)
+
+            # Search for object matching query
+            query = {
+                k: v
+                for k, v in expected_obj.items()
+                if k in self.serializer.writable_fields
+                and k not in self.serializer.any_related_fields
+                and k in self.model_class.get_fields_list()
+                and v is not None
+            }
+
+            # Extra parsing for query
+            for k, v in query.items():
+                if isinstance(v, str):
+                    query[k] = v.strip()
+
+            # Validate object fields
+            actual_object = self.repo.get(**query)
+            actual_serializer = self.serializer_class(actual_object)
+
+            for field in self.serializer.writable_fields:
+                if (
+                    field not in expected_serializer.data.keys()
+                    and field not in actual_serializer.data.keys()
+                ):
+                    continue
+
+                expected_value = expected_serializer.data[field]
+                actual_value = actual_serializer.data[field]
+
+                if isinstance(expected_value, str):
+                    expected_value = expected_value.strip()
+
+                self.assertFalse(str(expected_value).startswith(" "))
+                self.assertFalse(str(expected_value).endswith(" "))
+
+                self.assertFalse(str(actual_value).startswith(" "))
+                self.assertFalse(str(actual_value).endswith(" "))
+
+                self.assertEqual(expected_value, actual_value)
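The `Overrides` lists in the docstrings above are the whole extension contract for these bases. As a rough sketch of what a concrete suite might look like (`Widget` and `WidgetCsvSerializer` are hypothetical stand-ins for a real model/serializer pair, not part of this diff):

    from lib.faker import fake
    from querycsv.tests.utils import UploadCsvTestsBase

    from myapp.models import Widget                    # hypothetical model
    from myapp.serializers import WidgetCsvSerializer  # hypothetical CsvModelSerializer


    class WidgetUploadCsvTests(UploadCsvTestsBase):
        """Exercise csv uploads against the hypothetical Widget model."""

        model_class = Widget
        serializer_class = WidgetCsvSerializer
        dataset_size = 10

        def get_create_params(self, **kwargs):
            # Values used to seed each mock row
            return {"name": fake.title(), **kwargs}

        def get_update_params(self, obj, **kwargs):
            # Values changed between the csv dump and the upload,
            # forcing the upload to perform an update
            return {"name": fake.title(), **kwargs}

All the download/upload test methods in test_download_data.py and test_upload_data.py then run against the new model for free.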
"admin/querycsv/upload_csv.html", context=context + ) + + def map_upload_csv_headers(self, request: HttpRequest, id: int, extra_context=None): + """Given a csv upload job, define custom mappings between csv headers and object fields.""" + + job = get_object_or_404(QueryCsvUploadJob, id=id) + # TODO: What to do if job is completed, or url is visited for a previous job + + context = { + **(extra_context or {}), + "upload_job": job, + "model_class_name": job.model_class.__name__, + } + + if request.POST: + formset = CsvHeaderMappingFormSet(request.POST, upload_job=job) + + if formset.is_valid(): + custom_mappings = [ + mapping + for mapping in formset.cleaned_data + if mapping["csv_header"] != mapping["object_field"] + ] + + for mapping in custom_mappings: + job.add_field_mapping( + column_name=mapping["csv_header"], + field_name=mapping["object_field"], + commit=False, + ) + + job.save() + + send_process_csv_job_signal(job) + self.message_user(request, "Successfully uploaded csv.", logging.INFO) + + return redirect(self.get_reverse()) + + else: + initial_data = [] + + for header in job.csv_headers: + cleaned_header = header.strip().lower().replace(" ", "_") + # if cleaned_header in self.serializer.all_field_names: + if cleaned_header in self.service.flat_fields.keys(): + initial_mapping = { + "csv_header": header, + "object_field": cleaned_header, + } + else: + initial_mapping = {"csv_header": header, "object_field": "pass"} + + initial_data.append(initial_mapping) + + formset = CsvHeaderMappingFormSet(initial=initial_data, upload_job=job) + + context["formset"] = formset + + return TemplateResponse( + request, "admin/querycsv/upload_csv_headermapping.html", context=context + ) diff --git a/app/utils/admin.py b/app/utils/admin.py index acf3409..4226dfc 100644 --- a/app/utils/admin.py +++ b/app/utils/admin.py @@ -1,3 +1,4 @@ +from django.contrib import admin from django.utils.translation import gettext_lazy as _ other_info_fields = ( @@ -18,3 +19,22 @@ ) __all__ = ("other_info_fields",) + + +def get_admin_context(request, extra_context=None): + """Get default context dict for the admin site.""" + + return {**admin.site.each_context(request), **(extra_context or {})} + + +def get_model_admin_reverse(admin_name, model, url_context): + """Format info to proper reversable format.""" + + info = ( + admin_name, + model._meta.app_label, + model._meta.model_name, + url_context, + ) + + return "%s:%s_%s_%s" % info diff --git a/app/utils/files.py b/app/utils/files.py index 63ba590..7e8f9d9 100644 --- a/app/utils/files.py +++ b/app/utils/files.py @@ -2,7 +2,9 @@ from pathlib import Path from typing import Optional -from app.settings import MEDIA_ROOT +from django.db import models + +from app.settings import MEDIA_ROOT, S3_STORAGE_BACKEND # def get_media_dir(nested_path=""): # return Path(MEDIA_ROOT, nested_path) @@ -54,3 +56,18 @@ def get_media_path( path = Path(path, filename) return str(path) + + +def get_file_path(file: models.FileField): + """ + Returns the appropriate path for a file. + + In production, this returns file.url, and in development + mode it returns file.path. This is because boto3 will + raise an error if file.path is called in production. 
diff --git a/app/utils/helpers.py b/app/utils/helpers.py
index 2b31385..ca8e91d 100644
--- a/app/utils/helpers.py
+++ b/app/utils/helpers.py
@@ -75,3 +75,20 @@ def import_from_path(path: str):
     """
 
     return import_string(path)
+
+
+def clean_list(target: list):
+    """Remove None values and empty strings from list."""
+
+    return [item for item in target if item is not None and item != ""]
+
+
+def str_to_list(target: str | None):
+    """Split string into list using comma as a separator."""
+    if not isinstance(target, str):
+        return []
+
+    items = target.split(",")
+    items = clean_list([item.strip() for item in items])
+
+    return items
diff --git a/app/utils/models.py b/app/utils/models.py
index e59021c..bfe2cfd 100644
--- a/app/utils/models.py
+++ b/app/utils/models.py
@@ -1,15 +1,17 @@
 import os
 import uuid
+from pathlib import Path
 
+from django.core.exceptions import ValidationError
+from django.core.files import File
 from django.db import models
 from django.db.models.fields.related_descriptors import ReverseOneToOneDescriptor
 from django.utils.deconstruct import deconstructible
 from rest_framework.fields import ObjectDoesNotExist
 
+from utils.helpers import import_from_path
 from utils.types import T
 
-# from utils.types import T
-
 
 @deconstructible
 class UploadFilepathFactory(object):
@@ -22,17 +23,45 @@ class UploadFilepathFactory(object):
     Ex: "/user/profile/" -> "/media/uploads/user/profile/"
     """
 
-    def __init__(self, path: str):
+    def __init__(self, path: str, default_extension=None):
         self.path = path
+        self.default_extension = default_extension
 
     def __call__(self, instance, filename):
-        extension = filename.split(".")[-1]
-        filename = "{}.{}".format(uuid.uuid4().hex, extension)
+        if "." in filename:
+            extension = filename.split(".")[-1]
+        else:
+            extension = self.default_extension
+
+        if extension:
+            filename = "{}.{}".format(uuid.uuid4().hex, extension)
+        else:
+            # Avoid a trailing dot when no extension is available
+            filename = uuid.uuid4().hex
 
         nested_dirs = [dirname for dirname in self.path.split("/") if dirname]
         return os.path.join("uploads", *nested_dirs, filename)
 
 
+@deconstructible
+class ValidateImportString(object):
+    """
+    Validate that a given string can be imported using the `import_from_path` function.
+    """
+
+    def __init__(self, target_type=None) -> None:
+        self.target_type = target_type
+
+    def __call__(self, text: str):
+        symbol = import_from_path(text)
+
+        # Raise ValidationError instead of asserting, so the check survives
+        # python -O and integrates with Django's validation machinery
+        if not issubclass(symbol, self.target_type):
+            raise ValidationError(
+                f"Imported object needs to be of type {self.target_type}, "
+                f"but got {type(symbol)}."
+            )
+
+
 class ReverseOneToOneOrNoneDescriptor(ReverseOneToOneDescriptor):
     def __get__(self, instance, cls=None):
         try:
@@ -51,3 +80,18 @@ class OneToOneOrNoneField(models.OneToOneField[T]):
     """  # noqa: E501
 
     related_accessor_class = ReverseOneToOneOrNoneDescriptor
+
+
+def save_file_to_model(model: models.Model, filepath, field="file"):
+    """
+    Given file path, save a file to a given model.
+
+    This abstracts the process of opening the file and
+    copying over the file data.
+    """
+
+    path = Path(filepath)
+
+    with path.open(mode="rb") as f:
+        file = File(f, name=path.name)
+        setattr(model, field, file)
+        model.save()
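The two file helpers compose naturally: `UploadFilepathFactory` names uploads, and `save_file_to_model` copies a local file into a `FileField`. A minimal sketch, assuming a hypothetical `Report` model:

    from django.db import models

    from utils.models import UploadFilepathFactory, save_file_to_model


    class Report(models.Model):
        # Uploads land under media/uploads/reports/<uuid>.<ext>
        file = models.FileField(upload_to=UploadFilepathFactory("reports/"))


    def attach_summary(report: Report, filepath: str):
        # Opens the file, wraps it in django.core.files.File, assigns, and saves
        save_file_to_model(report, filepath, field="file")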
diff --git a/docker-compose.yml b/docker-compose.yml
index c907650..5fa87af 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -9,6 +9,7 @@ services:
       - '8000:8000'
     volumes:
       - ./app:/app
+      - static-clubs-dev:/vol/static
     command: >
       sh -c "python manage.py wait_for_db &&
              python manage.py migrate &&
@@ -25,7 +26,6 @@
       - DJANGO_SUPERUSER_EMAIL=${DJANGO_SUPERUSER_EMAIL:-admin@example.com}
      - DJANGO_SUPERUSER_PASS=${DJANGO_SUPERUSER_PASS:-changeme}
       - DJANGO_BASE_URL=${DJANGO_BASE_URL:-http://localhost:8000}
-      - S3_STORAGE_BACKEND=0
       - CREATE_SUPERUSER=1
       - EMAIL_HOST_USER=${EMAIL_HOST_USER:-""}
       - EMAIL_HOST_PASS=${EMAIL_HOST_PASS:-""}
@@ -42,8 +42,12 @@
       - CELERY_ACKS_LATE=True
       - DJANGO_DB=postgresql
       - DJANGO_REDIS_URL=redis://clubs-dev-redis:6379/1
+
+      - AWS_EXECUTION_ENV=0
+      - S3_STORAGE_BACKEND=0
     depends_on:
       - postgres
+      - redis
 
   postgres:
     image: postgres:13-alpine
@@ -71,7 +75,7 @@
     command: ['celery', '-A', 'app', 'worker', '--loglevel=info']
     volumes:
       - ./app:/app
-      - static-clubs-dev:/vol/web
+      - static-clubs-dev:/vol/static
    depends_on:
      - redis
      - postgres
@@ -80,15 +84,31 @@
       - DEBUG=1
       - CELERY_BROKER_URL=redis://clubs-dev-redis:6379/0
       - CELERY_RESULT_BACKEND=redis://clubs-dev-redis:6379/0
-      - DJANGO_DB=postgresql
+      - CELERY_ACKS_LATE=True
       - POSTGRES_HOST=clubs-dev-db
       - POSTGRES_PORT=5432
-      - POSTGRES_NAME=devdatabase
       - POSTGRES_DB=devdatabase
+      - POSTGRES_NAME=devdatabase
       - POSTGRES_USER=devuser
       - POSTGRES_PASSWORD=devpass
+      - DJANGO_DB=postgresql
+      - DJANGO_REDIS_URL=redis://clubs-dev-redis:6379/1
+      - DB_HOST=clubs-dev-db
+      - DB_NAME=devdatabase
+      - DB_USER=devuser
+      - DB_PASS=devpass
+
+      - EMAIL_HOST_USER=${EMAIL_HOST_USER:-""}
+      - EMAIL_HOST_PASS=${EMAIL_HOST_PASS:-""}
+      - CONSOLE_EMAIL_BACKEND=${CONSOLE_EMAIL_BACKEND:-1}
+      - SENDGRID_API_KEY=${SENDGRID_API_KEY:-""}
+      - DEFAULT_FROM_EMAIL=${DEFAULT_FROM_EMAIL:-""}
+
+      - AWS_EXECUTION_ENV=0
+      - S3_STORAGE_BACKEND=0
+
   celerybeat:
     build:
       context: .
@@ -108,7 +128,7 @@
     ]
     volumes:
       - ./app:/app
-      - static-clubs-dev:/vol/web
+      - static-clubs-dev:/vol/static
     depends_on:
       - redis
       - postgres
@@ -122,6 +142,7 @@
       - POSTGRES_HOST=clubs-dev-db
       - POSTGRES_PORT=5432
+      - POSTGRES_DB=devdatabase
       - POSTGRES_NAME=devdatabase
       - POSTGRES_USER=devuser
       - POSTGRES_PASSWORD=devpass
@@ -131,6 +152,15 @@
       - DB_HOST=clubs-dev-db
       - DB_NAME=devdatabase
       - DB_USER=devuser
       - DB_PASS=devpass
 
+      - EMAIL_HOST_USER=${EMAIL_HOST_USER:-""}
+      - EMAIL_HOST_PASS=${EMAIL_HOST_PASS:-""}
+      - CONSOLE_EMAIL_BACKEND=${CONSOLE_EMAIL_BACKEND:-1}
+      - SENDGRID_API_KEY=${SENDGRID_API_KEY:-""}
+      - DEFAULT_FROM_EMAIL=${DEFAULT_FROM_EMAIL:-""}
+
+      - AWS_EXECUTION_ENV=0
+      - S3_STORAGE_BACKEND=0
+
   coverage:
     image: nginx
     ports:
diff --git a/requirements.txt b/requirements.txt
index 432ad4e..069a670 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,11 @@ django-admin-tools>=0.9.3,<1.0
 coverage>=7.4.1,<7.5
 Pillow>=10.2.0,<11.2
 sendgrid>=6.11.0,<6.12
+typing_extensions>=4.12.2,<4.13
+
+# csv files
+pandas>=2.2.3,<2.3
+xlsxwriter>=3.2.0,<3.3
 
 # QRCodes
 segno>=1.6.1,<1.7
@@ -16,6 +22,10 @@ celery>=5.4.0,<5.5
 redis>=5.0.4,<5.1
 django-celery-beat>=2.7.0,<2.8
 
+# AWS S3
+boto3>=1.34.0,<1.35.0
+django-storages>=1.14.3,<1.15.0
+
 # Not required for local dev
 psycopg2>=2.9.3,<2.9.4 # 2.9.10 raises pip error, 2/13/25
 uwsgi>=2.0.26,<2.0.27 # 2.0.28 raises pip error, 2/13/25
\ No newline at end of file
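Putting the pieces together, the service layer added in this diff can also be driven outside the admin. A hedged sketch of the round trip, assuming `upload_csv` returns the `(success, failed)` pair the tests unpack:

    from clubs.models import ClubMembership
    from clubs.serializers import ClubMembershipCsvSerializer
    from querycsv.services import QueryCsvService

    service = QueryCsvService(serializer_class=ClubMembershipCsvSerializer)

    # Dump the current memberships to a csv under MEDIA_ROOT...
    filepath = service.download_csv(queryset=ClubMembership.objects.all())

    # ...and re-import the (possibly hand-edited) file
    success, failed = service.upload_csv(path=filepath)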