+Field | Info | Can View | Can Edit | Required | Unique |
+---|---|---|---|---|---|
+{{ field }} | {% if field.help_text %}{{ field.help_text }}{% endif %} | ✓ | {% if field.is_writable %}✓{% endif %} | {% if field.is_required %}✓{% endif %} | {% if field.is_unique %}✓{% endif %} |
+The table below shows column headers found in the uploaded CSV.
+The right side of the table gives all available options for fields to map to
+new/updated {{ model_class_name }} objects.
+Cannot upload CSVs for this object.
+{% endblock %}
\ No newline at end of file
diff --git a/app/querycsv/tests/__init__.py b/app/querycsv/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/querycsv/tests/test_download_data.py b/app/querycsv/tests/test_download_data.py
new file mode 100644
index 0000000..7eef286
--- /dev/null
+++ b/app/querycsv/tests/test_download_data.py
@@ -0,0 +1,120 @@
+"""
+CSV Download Tests
+"""
+
+from querycsv.tests.utils import (
+    CsvDataM2MTestsBase,
+    CsvDataM2OTestsBase,
+    DownloadCsvTestsBase,
+)
+from utils.helpers import clean_list
+
+
+class DownloadDataTests(DownloadCsvTestsBase):
+    """Unit tests for download csv data."""
+
+    def test_download_model_csv(self):
+        """Should download a csv listing objects for model."""
+
+        # Create csv using service
+        self.initialize_dataset()
+        qs = self.repo.all()
+
+        filepath = self.service.download_csv(queryset=qs)
+        self.assertValidCsv(filepath)
+
+        # Check csv
+        df = self.csv_to_df(filepath)
+        self.assertEqual(len(df.index), self.dataset_size)
+
+        expected_fields = self.serializer.readable_fields
+        expected_fields.sort()
+
+        actual_fields = list(df.columns)
+        actual_fields.sort()
+
+        self.assertListEqual(expected_fields, actual_fields)
+
+        # Verify object fields in csv
+        self.assertCsvHasFields(df)
+
+
+class DownloadCsvM2OFieldsTests(DownloadCsvTestsBase, CsvDataM2OTestsBase):
+    """Unit tests for testing downloaded csv many-to-one fields."""
+
+    def test_download_csv_m2o_fields(self):
+        """Should be able to download models with many-to-one fields."""
+
+        # Create csv using service
+        self.initialize_dataset()
+        qs = self.repo.all()
+
+        filepath = self.service.download_csv(queryset=qs)
+        self.assertValidCsv(filepath)
+
+        # Check csv
+        df = self.csv_to_df(filepath)
+        self.assertCsvHasFields(df)
+
+        # For each row, check the many-to-one field
+        for index, row in df.iterrows():
+            obj_id = row["id"]
+            expected_obj = self.repo.get_by_id(obj_id)
+
+            expected_m2o_obj = getattr(expected_obj, self.m2o_selector)
+
+            if expected_m2o_obj is None:
+                expected_value = None
+            else:
+                expected_value = getattr(expected_m2o_obj, self.m2o_target_field)
+
+            actual_value = row[self.m2o_selector]
+            if actual_value == "":
+                actual_value = None
+
+            self.assertEqual(actual_value, expected_value)
+
+
+class DownloadCsvM2MFieldsStrTests(DownloadCsvTestsBase, CsvDataM2MTestsBase):
+    """Unit tests for testing downloaded csv many-to-many fields with str slug."""
+
+    def test_download_csv_m2m_fields(self):
+        """Should be able to download models with many-to-many fields."""
+
+        # Create csv using service
+        self.initialize_dataset()
+        qs = self.repo.all()
+
+        filepath = self.service.download_csv(queryset=qs)
+        self.assertValidCsv(filepath)
+
+        # Check csv
+        df = self.csv_to_df(filepath)
+        self.assertCsvHasFields(df)
+
+        # For each row, check the many-to-many field
+        for index, row in df.iterrows():
+            obj_id = row["id"]
+            expected_obj = self.repo.get_by_id(obj_id)
+
+            expected_m2m_objs = getattr(expected_obj, self.m2m_model_selector)
+            expected_values = clean_list(
+                [
+                    str(getattr(obj, self.m2m_target_field))
+                    for obj in expected_m2m_objs.all()
+                ]
+            )
+
+            actual_value_raw = str(row[self.m2m_selector])
+            actual_values = clean_list(
+                [str(v).strip() for v in actual_value_raw.split(",")]
+            )
+
+            self.assertListEqual(actual_values, expected_values)
+
+
+class DownloadCsvM2MFieldsIntTests(DownloadCsvM2MFieldsStrTests):
+    """Unit tests for testing downloaded csv many-to-many fields with int slug."""
+
+    m2m_selector = "many_tags_int"
+    m2m_target_field = "id"
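[Editor's note: for orientation, the export path these tests drive reduces to a couple of calls on `QueryCsvService`. A minimal sketch, assuming the mock `Buster` model and `BusterCsvSerializer` the test bases default to; the `fillna`/`print` lines are illustrative only.]

```python
import pandas as pd

from core.mock.models import Buster
from core.mock.serializers import BusterCsvSerializer
from querycsv.services import QueryCsvService

# The service is constructed per serializer class, as in the test setUp.
service = QueryCsvService(serializer_class=BusterCsvSerializer)

# download_csv serializes the queryset and returns the path of the csv
# it wrote under MEDIA_ROOT.
filepath = service.download_csv(queryset=Buster.objects.all())

# Many-to-many columns come back as comma-separated strings, so read
# everything as str and treat empty cells as "" before splitting.
df = pd.read_csv(filepath, dtype=str).fillna("")
print(df.columns.tolist())
```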
diff --git a/app/querycsv/tests/test_upload_data.py b/app/querycsv/tests/test_upload_data.py new file mode 100644 index 0000000..954739c --- /dev/null +++ b/app/querycsv/tests/test_upload_data.py @@ -0,0 +1,277 @@ +""" +Import/upload data tests. +""" + +from django.contrib.postgres.aggregates import StringAgg +from django.db import models + +from querycsv.models import QueryCsvUploadJob +from querycsv.services import QueryCsvService +from querycsv.tests.utils import ( + CsvDataM2MTestsBase, + CsvDataM2OTestsBase, + UploadCsvTestsBase, +) + + +class UploadDataTests(UploadCsvTestsBase): + """Test uploading data from a csv.""" + + def test_create_objects_from_csv(self): + """Should be able to take csv and create models.""" + + # Initialize data + objects_before = self.initialize_csv_data() + + # Call service upload function + _, failed = self.service.upload_csv(path=self.filepath) + + # Validate database + self.assertObjectsExist(objects_before, failed) + self.assertObjectsHaveFields(objects_before) + + def test_update_objects_from_csv(self): + # Initialize data + objects_before = self.initialize_csv_data(clear_db=False) + + for obj in self.repo.all(): + self.update_mock_object(obj) + + # Call service upload function + self.service.upload_csv(path=self.filepath) + + # Validate database + self.assertObjectsExist(objects_before) + self.assertObjectsHaveFields(objects_before) + + def test_upload_csv_bad_fields(self): + """Should create objects and ignore bad fields.""" + + # Initialize csv, add invalid column + objects_before = self.initialize_csv_data() + self.df["Invalid field"] = "bad value" + self.df_to_csv(self.df) + + self.assertTrue("Invalid field" in list(self.df.columns)) + + self.service.upload_csv(path=self.filepath) + + # Validate database + self.assertObjectsExist(objects_before) + self.assertObjectsHaveFields(objects_before) + + def test_upload_csv_update_objects(self): + """Uploading a csv should update objects.""" + + # Prep data, create csv + objects_before = self.initialize_csv_data(clear_db=False) + + updated_records = [] + + for obj in objects_before: + payload = {self.unique_field: obj[self.unique_field]} + payload = self.get_update_params(obj, **payload) + updated_records.append(payload) + + self.data_to_csv(updated_records) + + # Upload CSV + self.service.upload_csv(path=self.filepath) + + # Validate data + self.assertObjectsHaveFields(updated_records) + + def test_upload_csv_spaces(self): + """Should remove pre/post spaces from fields before updating/creating.""" + + # Prep data, create csv + objects_before = self.initialize_csv_data(clear_db=False) + + updated_records = [] + + for obj in objects_before: + payload = {self.unique_field: f" {obj[self.unique_field]} "} + payload = self.get_update_params(obj, **payload) + updated_records.append(payload) + + self.data_to_csv(updated_records) + self.assertObjectsExist(objects_before) + + # Upload CSV + self.service.upload_csv(path=self.filepath) + + # Validate data + self.assertObjectsHaveFields(updated_records) + + +class UploadCsvJobTests(UploadCsvTestsBase): + """Tests for uploading with QSCsv Model.""" + + def test_upload_from_job(self): + """Should upload and process csv from model.""" + + # Initialize data + objects_before = self.initialize_csv_data(clear_db=False) + + # Update fields after create csv + for obj in self.repo.all(): + self.update_mock_object(obj) + + # Upload csv via service + job = QueryCsvUploadJob.objects.create( + filepath=self.filepath, + serializer_class=self.serializer_class, + ) + 
QueryCsvService.upload_from_job(job) + + # Validate database + self.assertObjectsExist(objects_before) + self.assertObjectsHaveFields(objects_before) + + def test_upload_custom_fields(self): + """Should process csv with custom field mappings.""" + + objects_before = self.initialize_csv_data() + + # Rename csv field + self.df.rename(columns={"name": "Test Value"}, inplace=True) + self.df_to_csv(self.df, self.filepath) + + # Create and upload job + job = QueryCsvUploadJob.objects.create( + serializer_class=self.serializer_class, filepath=self.filepath + ) + job.add_field_mapping(column_name="Test Value", field_name="name") + job.refresh_from_db() + + QueryCsvService.upload_from_job(job) + + # Validate database + self.assertObjectsExist(pre_queryset=objects_before) + self.assertObjectsHaveFields(expected_objects=objects_before) + + +class UploadCsvM2OFieldsTests(UploadCsvTestsBase, CsvDataM2OTestsBase): + """Test uploading csvs for models with many-to-one fields.""" + + def test_upload_csv_m2o_fields(self): + """ + Should present Many-to-One (FK) fields according to serializer. + + Check by comparing the serialized representation before and after + the upload - both should have the save value for writable fields. + """ + + # Initialize data + objects_before = self.initialize_csv_data() + + # Call upload function + self.service.upload_csv(path=self.filepath) + + # Validate database + self.assertObjectsHaveFields(objects_before) + self.assertIn(self.m2o_selector, list(self.df.columns)) + + self.assertObjectsM2OValidFields(self.df) + + def test_upload_csv_m2o_fields_update(self): + """Should update models with Many-to-One fields.""" + + # Initialize data + objects_before = self.initialize_csv_data(clear_db=False) + + # Update fields after create csv + for obj in self.repo.all(): + self.update_mock_object(obj) + + # Call upload function + self.service.upload_csv(path=self.filepath) + + # Validate database + self.assertObjectsHaveFields(objects_before) + self.assertIn(self.m2o_selector, list(self.df.columns)) + + self.assertObjectsM2OValidFields(self.df) + + +class UploadCsvM2MFieldsTests(UploadCsvTestsBase, CsvDataM2MTestsBase): + """Test uploading csvs for models with many-to-many fields.""" + + def test_upload_csv_m2m_fields(self): + """When csv is uploaded, m2m fields should be handled properly.""" + + # Initialize data + objects_before = self.initialize_csv_data() + + # Upload csv using service + self.service.upload_csv(path=self.filepath) + + # Validate results + self.assertObjectsHaveFields(objects_before) + self.assertIn(self.m2m_selector, list(self.df.columns)) + + self.assertObjectsM2MValidFields(self.df) + + def test_upload_csv_m2m_fields_spaces(self): + """When csv is uploaded, m2m fields should be stripped of leading/trailing spaces.""" + + objects_before = self.initialize_csv_data() + + # Iterate through csv, manually add spacing + for i, row in self.df.iterrows(): + pre_value = row[self.m2m_selector] + pre_values = pre_value.split(",") + modified_value = " , ".join(pre_values) + row[self.m2m_selector] = modified_value + + self.df_to_csv(self.df) + + # Upload csv using service + self.service.upload_csv(path=self.filepath) + + # Validate results + self.assertObjectsHaveFields(objects_before) + self.assertIn(self.m2m_selector, list(self.df.columns)) + + self.assertObjectsM2MValidFields(self.df) + + def test_upload_csv_m2m_update_fields(self): + """When csv is uploaded, should update objects with many-to-many fields.""" + + # Initialize data + objects_before = 
self.initialize_csv_data(clear_db=False) + + # Update fields after create csv + self.update_dataset() + # for obj in self.repo.all().prefetch_related(self.m2m_selector): + # self.update_mock_object(obj) + + objects_before = list( + self.repo.all() + .annotate( + pre_objs_count=models.Count(self.m2m_selector), + pre_objs=StringAgg( + models.F(f"{self.m2m_selector}__{self.m2m_target_field}"), + distinct=True, + delimiter=",", + ), + ) + .values() + ) + + # Upload csv using service + success, failed = self.service.upload_csv(path=self.filepath) + + # Validate results + self.assertEqual(self.repo.all().count(), self.dataset_size) + expected_objects = list(self.df.to_dict("records")) + + self.assertObjectsHaveFields(expected_objects) + self.assertIn(self.m2m_selector, list(self.df.columns)) + self.assertTrue( + self.m2m_repo.all().count() <= self.m2m_size + self.m2m_update_size, + f"Expected at most {self.m2m_size + self.m2m_update_size} M2M objects, " + f"but {self.m2m_repo.all().count()} were created.", + ) + + self.assertObjectsM2MValidFields(self.df, objects_before) diff --git a/app/querycsv/tests/test_upload_views.py b/app/querycsv/tests/test_upload_views.py new file mode 100644 index 0000000..cc44213 --- /dev/null +++ b/app/querycsv/tests/test_upload_views.py @@ -0,0 +1,108 @@ +import pandas as pd +from django.template.response import TemplateResponse +from django.test import RequestFactory +from rest_framework import status + +from core.mock.models import Buster +from core.mock.serializers import BusterCsvSerializer +from querycsv.forms import CsvHeaderMappingFormSet, CsvUploadForm +from querycsv.models import QueryCsvUploadJob +from querycsv.tests.test_upload_data import UploadCsvTestsBase +from querycsv.views import QueryCsvViewSet + + +class UploadCsvViewsTests(UploadCsvTestsBase): + """Test functionality for upload views used in admin.""" + + model_class = Buster + serializer_class = BusterCsvSerializer + + def get_reverse(self, name="fallback"): + return "core:index" + + def setUp(self): + self.views = QueryCsvViewSet( + self.serializer_class, get_reverse=self.get_reverse + ) + self.req_factory = RequestFactory() + + return super().setUp() + + #################### + # == Unit Tests == # + #################### + + def test_upload_csv(self): + """Should show form for uploading csv.""" + + req = self.req_factory.get("/") + res: TemplateResponse = self.views.upload_csv(request=req) + + self.assertIsInstance(res, TemplateResponse) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + # Check context + self.assertIsInstance(res.context_data["form"], CsvUploadForm) + self.assertEqual( + res.context_data["template_url"], self.get_reverse("csv_template") + ) + # FIXME: Checking csv fields in context failes + # self.assertEqual( + # res.context_data["all_fields"], self.service.flat_fields.values() + # ) + self.assertEqual( + res.context_data["unique_together_fields"], + self.serializer.unique_together_fields, + ) + + def test_map_upload_csv_headers_get(self): + """Should show form for header associations.""" + + self.initialize_csv_data() + job = QueryCsvUploadJob.objects.create( + serializer_class=self.serializer_class, filepath=self.filepath + ) + + req = self.req_factory.get("/") + res: TemplateResponse = self.views.map_upload_csv_headers( + request=req, id=job.id + ) + self.assertIsInstance(res, TemplateResponse) + + # Check context + context = res.context_data + self.assertEqual(context["upload_job"], job) + self.assertEqual(context["model_class_name"], job.model_class.__name__) + 
self.assertIsInstance(context["formset"], CsvHeaderMappingFormSet) + + def test_map_upload_csv_headers_post(self): + """Should add custom header associations for upload job.""" + + # Initialize data + self.initialize_csv_data() + df = pd.read_csv(self.filepath) + df.rename(columns={"name": "Test Name"}, inplace=True) + self.df_to_csv(df, self.filepath) + + job = QueryCsvUploadJob.objects.create( + serializer_class=self.serializer_class, filepath=self.filepath + ) + data = { + "form-TOTAL_FORMS": "1", + "form-INITIAL_FORMS": "0", + "form-0-csv_header": "Test Name", + "form-0-object_field": "name", + } + + # Send request + req = self.req_factory.post("/", data=data) + res = self.views.map_upload_csv_headers(request=req, id=job.id) + + self.assertEqual(res.status_code, status.HTTP_302_FOUND) + + # Check mappings + job.refresh_from_db() + self.assertEqual(len(job.custom_fields), 1) + + self.assertEqual(job.custom_fields[0]["column_name"], "Test Name") + self.assertEqual(job.custom_fields[0]["field_name"], "name") diff --git a/app/querycsv/tests/utils.py b/app/querycsv/tests/utils.py new file mode 100644 index 0000000..4c08fe0 --- /dev/null +++ b/app/querycsv/tests/utils.py @@ -0,0 +1,579 @@ +""" +CSV Data Tests Utilities +""" + +import random +import uuid +from typing import Optional + +import numpy as np +import pandas as pd +from django.db import models + +from app.settings import MEDIA_ROOT +from core.abstracts.tests import TestsBase +from core.mock.models import Buster, BusterTag +from core.mock.serializers import BusterCsvSerializer +from lib.faker import fake +from querycsv.services import QueryCsvService +from utils.files import get_media_path +from utils.helpers import clean_list + + +class CsvDataTestsBase(TestsBase): + """ + Base tests for Csv data services. 
+ + Overrides + --------- + Required: + - model_class + - serializer_class + - def get_create_params + - def get_update_params + + Optional: + - dataset_size + + Terms + ----- + - repo: alias for Model.objects + - objects: all instances of Model in database + """ + + model_class = Buster + serializer_class = BusterCsvSerializer + dataset_size = 5 + update_size = 3 + + unique_field = "unique_name" + """The field to test updates against.""" + + def setUp(self) -> None: + self.repo = self.model_class.objects + self.serializer = self.serializer_class() + self.service = QueryCsvService(serializer_class=self.serializer_class) + + return super().setUp() + + # Overrides + ##################### + def get_create_params(self, **kwargs): + return {"name": fake.title(), **kwargs} + + def get_update_params(self, obj: model_class, **kwargs): + return {"name": fake.title(), **kwargs} + + # Initialization + ##################### + def initialize_dataset(self): + """Create mock objects, and any other setup tasks.""" + return self.create_mock_objects() + + def update_dataset(self): + objects = list(self.repo.all()) + + for i in range(self.update_size): + obj = random.choice(objects) + objects.remove(obj) + + self.update_mock_object(obj=obj) + + # Utilities + ##################### + def create_mock_object(self, **kwargs): + return self.repo.create(**self.get_create_params(**kwargs)) + + def create_mock_objects(self, amount: Optional[int] = None): + """Create a set number of models.""" + + if not amount: + amount = self.dataset_size + + for _ in range(amount): + self.create_mock_object() + + def update_mock_object(self, obj: model_class, **kwargs): + """Update the object to differ from csv.""" + + for key, value in self.get_update_params(obj=obj, **kwargs).items(): + setattr(obj, key, value) + + obj.save() + + return obj + + def get_unique_filepath(self): + return get_media_path( + nested_path="tests/csv-data/uploads/", + filename=f"{uuid.uuid4()}.csv", + create_path=True, + ) + + def df_to_csv(self, df: pd.DataFrame, filepath: Optional[str] = None): + """ + Dump a dataframe to a csv, return filepath. + """ + + if filepath is None: + filepath = self.filepath + + df.to_csv(filepath, index=False, mode="w") + + return filepath + + def data_to_df(self, data: list[dict]): + """Convert output of serializer to dataframe.""" + + for model in data: + for key, value in model.items(): + if isinstance(value, list): + model[key] = ",".join([str(v) for v in value]) + + return pd.DataFrame.from_records(data) + + def data_to_csv(self, data: list[dict]): + """Convert list of dicts to a csv, return filepath.""" + + df = self.data_to_df(data) + return self.df_to_csv(df) + + def csv_to_df(self, path: str): + """Convert csv at path to list of objects.""" + + # Start by importing csv + if path.endswith(".xlsx") or path.endswith(".xls"): + df = pd.read_excel(path, dtype=str) + else: + df = pd.read_csv(path, dtype=str) + + df.replace(np.nan, "", inplace=True) + + return df + + # Custom assertions + ##################### + def assertObjectsCount(self, count: int, msg=None): + """Objects count in db should match given count.""" + self.assertEqual(self.repo.count(), count, msg=msg) + + def assertNoObjects(self): + """Database should be empty.""" + + self.assertObjectsCount(0) + + +class CsvDataM2OTestsBase(CsvDataTestsBase): + """ + Test csv data with many-to-one fields. 
+ + Overrides + --------- + Required: + - model_class + - serializer_class + - m2o_model_class + - m2o_selector + - m2o_target_field + - def get_create_params + - def get_update_params + - def get_m2o_create_params + + Optional: + - dataset_size + - m2o_size + - def create_mock_objects + """ + + model_class = Buster + serializer_class = BusterCsvSerializer + m2o_model_class = BusterTag + m2o_size = 2 + + m2o_selector = "one_tag" + """Field on the main object that points to child object.""" + + m2o_target_field = "name" + """Field on child object whose value is used in serializer.""" + + def setUp(self) -> None: + super().setUp() + + self.m2o_repo = self.m2o_model_class.objects + + def get_m2o_create_params(self, **kwargs): + return {"name": fake.title()} + + def create_mock_m2o_object(self, **kwargs): + """Create Many to One object for testing.""" + + return self.m2o_repo.create(**self.get_m2o_create_params(**kwargs)) + + def initialize_dataset(self): + super().initialize_dataset() + + m2o_objects = [] + for i in range(self.m2o_size): + m2o_objects.append(self.create_mock_m2o_object()) + + for obj in self.repo.all(): + setattr(obj, self.m2o_selector, random.choice(m2o_objects)) + + # return self.repo.all() + + def update_dataset(self): + objects = list(self.repo.all()) + m2os = list(self.m2o_repo.all()) + + for _ in range(self.update_size): + obj = random.choice(objects) + objects.remove(obj) + + m2o = random.choice(m2os) + self.update_mock_object(obj=obj, **{self.m2o_selector: m2o}) + + def clear_db(self) -> list: + self.m2o_repo.all().delete() + + return super().clear_db() + + def assertObjectsM2OValidFields(self, df: pd.DataFrame): + """Compare actual objects in the database with expected values in csv.""" + + # Compare csv value with actual value + for index, row in df.iterrows(): + # Raw values in csv + expected_value = row[self.m2o_selector] + + if expected_value is None: + continue + + self.assertIsInstance(expected_value, str) + query = row.to_dict() + obj = self.repo.get( + **{ + k: v + for k, v in query.items() + if k != self.m2o_selector + and k not in self.serializer.readonly_fields + and k not in self.serializer.any_related_fields + } + ) + + m2o_obj = getattr(obj, self.m2o_selector) + actual_value = getattr(m2o_obj, self.m2o_target_field) + + self.assertEqual(expected_value, actual_value) + + +class CsvDataM2MTestsBase(CsvDataTestsBase): + """ + Base utilities for testing many-to-many fields. 
+ + Overrides + --------- + Required: + - model_class + - serializer_class + - m2m_model_class + - m2m_selector + - m2m_target_field + - def get_create_params + - def get_update_params + - def get_m2m_create_params + + Optional: + - dataset_size + - m2m_size + - m2m_update_size + """ + + model_class = Buster + serializer_class = BusterCsvSerializer + + m2m_model_class = BusterTag + m2m_size = 10 + m2m_update_size = 4 + m2m_assignment_max = 3 + + m2m_selector = "many_tags" + """Field on the main object that points to child object.""" + + m2m_target_field = "name" + """Field on child object whose value is used in serializer.""" + + def setUp(self) -> None: + super().setUp() + + self.m2m_repo = self.m2m_model_class.objects + + if self.m2m_selector not in self.model_class.get_fields_list(): + self.m2m_model_selector = self.serializer.get_fields()[ + self.m2m_selector + ].source + else: + self.m2m_model_selector = self.m2m_selector + + def get_m2m_create_params(self, **kwargs): + return {"name": fake.title(), **kwargs} + + def initialize_dataset(self): + super().initialize_dataset() + + m2m_objects = [] + for i in range(self.m2m_size): + m2m_objects.append(self.create_mock_m2m_object()) + + for obj in self.repo.all(): + m2m_repo = getattr(obj, self.m2m_model_selector) + + assignment_count = random.randint(0, self.m2m_assignment_max) + selected_m2m_objects = random.sample(m2m_objects, assignment_count) + + for m2m_obj in selected_m2m_objects: + m2m_repo.add(m2m_obj) + + obj.save() + + def update_dataset(self): + return super().update_dataset() + + def create_mock_m2m_object(self, **kwargs): + return self.m2m_model_class.objects.create( + **self.get_m2m_create_params(**kwargs) + ) + + def assertObjectsM2MValidFields( + self, df: pd.DataFrame, objects_before: list[dict] = None + ): + """Compare expected objects in the csv with actual objects from database.""" + + # Compare csv value with actual value + for index, row in df.iterrows(): + # Raw value in csv + expected_value = row[self.m2m_selector] + + if expected_value is None: + continue + + # self.assertIsInstance(expected_value, str) + csv_values = row.to_dict() + query = None + + for key, value in csv_values.items(): + # Skip fields if they represent object, are none, or are for the serializer only + if ( + key == self.m2m_selector + or key in self.serializer.readonly_fields + or value is None + or key not in self.model_class.get_fields_list() + ): + continue + + query_filter = models.Q(**{key: value}) + query = query & query_filter if query is not None else query_filter + + actual_obj = self.repo.get(query) + actual_related_objs = getattr(actual_obj, self.m2m_selector).all() + + # Check database against csv + expected_values = [str(v).strip() for v in str(expected_value).split(",")] + expected_values = clean_list(expected_values) + + actual_values = [ + getattr(obj, self.m2m_target_field) for obj in actual_related_objs + ] + actual_values = clean_list(actual_values) + + self.assertListEqual(expected_values, actual_values) + + +class DownloadCsvTestsBase(CsvDataTestsBase): + """ + Base utilities for download csv tests. 
+ + Overrides + --------- + Required: + - model_class + - serializer_class + - def get_create_params + - def get_update_params + + Optional: + - dataset_size + + Terms + ----- + - repo: alias for Model.objects + - objects: all instances of Model in database + """ + + # def initialize_dataset(self): + # """Create database objects, return queryset.""" + + # self.create_mock_objects() + # self.assertObjectsCount(self.dataset_size) + + # return self.repo.all() + + def assertValidCsv(self, filepath: str): + """File at filepath should be a valid csv.""" + + self.assertFileExists(filepath) + self.assertStartsWith(filepath, MEDIA_ROOT) + self.assertEndsWith(filepath, ".csv") + + def assertCsvHasFields(self, df: pd.DataFrame): + """Iterate over csv data and verify with DB.""" + + records = df.to_dict("records") + + for record in records: + id = record.get("id") + actual_object = self.repo.get_by_id(id) + + actual_serializer = self.serializer_class(actual_object) + + for field, expected_value in actual_serializer.data.items(): + self.assertIn(field, record.keys()) + + actual_value = record[field] + + if field in self.serializer.many_related_fields: + actual_values = [ + val.strip() for val in str(actual_value).split(",") + ] + actual_values.sort() + + expected_values = [str(val) for val in expected_value] + expected_values.sort() + + self.assertListEqual( + clean_list(actual_values), clean_list(expected_values) + ) + else: + self.assertEqual(str(actual_value or ""), str(expected_value or "")) + + +class UploadCsvTestsBase(CsvDataTestsBase): + """ + Base utilities for upload data service. + + Overrides + --------- + Required: + - model_class + - serializer_class + - def get_create_params + - def get_update_params + + Optional: + - dataset_size + - def create_mock_object # Calls get_create_params by default + - def create_mock_objects + - def update_mock_object # Calls get_update_params by default + """ + + def setUp(self) -> None: + super().setUp() + + self.filepath = self.get_unique_filepath() + + def create_objects(self): + # Create test models + self.assertNoObjects() + self.create_mock_objects() + + objects = self.repo.all() + self.assertEqual(objects.count(), self.dataset_size) + + return objects + + def dump_csv(self, query: models.QuerySet): + """Manually Print query to csv, independent of services.""" + + data = self.serializer_class(query, many=True).data + self.df = self.data_to_df(data) + self.df_to_csv(self.df) + + def initialize_csv_data(self, clear_db=True): + """Create csv with data, then clear the database.""" + + # Initialize data + self.initialize_dataset() + objects = self.repo.all() + objects_before = objects.values() + self.dump_csv(objects) + + # Clear database + if clear_db: + self.clear_db() + + return objects_before + + def clear_db(self) -> list: + """Save list of current objects and clear the database.""" + + self.repo.all().delete() + self.assertNoObjects() + + def assertObjectsExist(self, pre_queryset: list, msg=None): + """Objects represented in queryset should exist in the database.""" + self.assertObjectsCount(self.dataset_size, msg=msg) + + for expected_obj in pre_queryset: + query = self.repo.filter(**expected_obj) + self.assertTrue(query.exists(), msg=msg) + + def assertObjectsHaveFields(self, expected_objects: list[dict]): + """ + Check if the actual object has expected fields. + + Verify by comparing the serialized representation for before and after + the upload - both should have the save value for writable fields. 
+ """ + + for expected_obj in expected_objects: + expected_serializer = self.serializer_class(data=expected_obj) + self.assertValidSerializer(expected_serializer) + + # Search for object matching query + query = { + k: v + for k, v in expected_obj.items() + if k in self.serializer.writable_fields + and k not in self.serializer.any_related_fields + and k in self.model_class.get_fields_list() + and v is not None + } + + # Extra parsing for query + for k, v in query.items(): + if isinstance(v, str): + query[k] = v.strip() + + # Validate object fields + actual_object = self.repo.get(**query) + actual_serializer = self.serializer_class(actual_object) + + for field in self.serializer.writable_fields: + if ( + field not in expected_serializer.data.keys() + and field not in actual_serializer.data.keys() + ): + continue + + expected_value = expected_serializer.data[field] + actual_value = actual_serializer.data[field] + + if isinstance(expected_value, str): + expected_value.strip() + + self.assertFalse(str(expected_value).startswith(" ")) + self.assertFalse(str(expected_value).endswith(" ")) + + self.assertFalse(str(actual_value).startswith(" ")) + self.assertFalse(str(actual_value).endswith(" ")) + + self.assertEqual(expected_value, actual_value) diff --git a/app/querycsv/views.py b/app/querycsv/views.py new file mode 100644 index 0000000..3f7b197 --- /dev/null +++ b/app/querycsv/views.py @@ -0,0 +1,135 @@ +import logging +from typing import Type + +from django.http import HttpRequest +from django.shortcuts import get_object_or_404, redirect +from django.template.response import TemplateResponse + +from core.abstracts.serializers import ModelSerializerBase +from querycsv.forms import CsvHeaderMappingFormSet, CsvUploadForm +from querycsv.models import QueryCsvUploadJob +from querycsv.services import QueryCsvService +from querycsv.signals import send_process_csv_job_signal + + +class QueryCsvViewSet: + serializer_class: Type[ModelSerializerBase] + + def __init__( + self, + serializer_class: Type[ModelSerializerBase], + get_reverse: callable, + message_user_fn=None, + ): + self.serializer_class = serializer_class + self.serializer = serializer_class() + self.service = QueryCsvService(serializer_class) + self.get_reverse = get_reverse + self.message_user = message_user_fn + + def message_user_fallback(*args, **kwargs): + pass + + if not self.message_user: + self.message_user = message_user_fallback + + def upload_csv(self, request: HttpRequest, extra_context=None): + """Upload csv for processing.""" + context = extra_context if extra_context else {} + + context["template_url"] = self.get_reverse("csv_template") + context["all_fields"] = self.service.flat_fields.values() + context["unique_together_fields"] = ( + self.serializer_class().unique_together_fields + ) + + # Not able to upload csv if no serializer is set + if self.serializer_class is None: + return TemplateResponse( + request, "admin/querycsv/upload_not_available.html", context + ) + + if request.POST: + form = CsvUploadForm(data=request.POST, files=request.FILES) + + if form.is_valid(): + # Process new csv + job = QueryCsvUploadJob.objects.create( + serializer_class=self.serializer_class, + notify_email=request.user.email, + file=request.FILES["file"], + ) + + return redirect(self.get_reverse("upload_headermapping"), id=job.id) + else: + context["form"] = form + + return TemplateResponse( + request, "admin/querycsv/upload_csv.html", context=context + ) + else: + context["form"] = CsvUploadForm() + + return TemplateResponse( + request, 
"admin/querycsv/upload_csv.html", context=context + ) + + def map_upload_csv_headers(self, request: HttpRequest, id: int, extra_context=None): + """Given a csv upload job, define custom mappings between csv headers and object fields.""" + + job = get_object_or_404(QueryCsvUploadJob, id=id) + # TODO: What to do if job is completed, or url is visited for a previous job + + context = { + **(extra_context or {}), + "upload_job": job, + "model_class_name": job.model_class.__name__, + } + + if request.POST: + formset = CsvHeaderMappingFormSet(request.POST, upload_job=job) + + if formset.is_valid(): + custom_mappings = [ + mapping + for mapping in formset.cleaned_data + if mapping["csv_header"] != mapping["object_field"] + ] + + for mapping in custom_mappings: + job.add_field_mapping( + column_name=mapping["csv_header"], + field_name=mapping["object_field"], + commit=False, + ) + + job.save() + + send_process_csv_job_signal(job) + self.message_user(request, "Successfully uploaded csv.", logging.INFO) + + return redirect(self.get_reverse()) + + else: + initial_data = [] + + for header in job.csv_headers: + cleaned_header = header.strip().lower().replace(" ", "_") + # if cleaned_header in self.serializer.all_field_names: + if cleaned_header in self.service.flat_fields.keys(): + initial_mapping = { + "csv_header": header, + "object_field": cleaned_header, + } + else: + initial_mapping = {"csv_header": header, "object_field": "pass"} + + initial_data.append(initial_mapping) + + formset = CsvHeaderMappingFormSet(initial=initial_data, upload_job=job) + + context["formset"] = formset + + return TemplateResponse( + request, "admin/querycsv/upload_csv_headermapping.html", context=context + ) diff --git a/app/utils/admin.py b/app/utils/admin.py index acf3409..4226dfc 100644 --- a/app/utils/admin.py +++ b/app/utils/admin.py @@ -1,3 +1,4 @@ +from django.contrib import admin from django.utils.translation import gettext_lazy as _ other_info_fields = ( @@ -18,3 +19,22 @@ ) __all__ = ("other_info_fields",) + + +def get_admin_context(request, extra_context=None): + """Get default context dict for the admin site.""" + + return {**admin.site.each_context(request), **(extra_context or {})} + + +def get_model_admin_reverse(admin_name, model, url_context): + """Format info to proper reversable format.""" + + info = ( + admin_name, + model._meta.app_label, + model._meta.model_name, + url_context, + ) + + return "%s:%s_%s_%s" % info diff --git a/app/utils/files.py b/app/utils/files.py index 63ba590..7e8f9d9 100644 --- a/app/utils/files.py +++ b/app/utils/files.py @@ -2,7 +2,9 @@ from pathlib import Path from typing import Optional -from app.settings import MEDIA_ROOT +from django.db import models + +from app.settings import MEDIA_ROOT, S3_STORAGE_BACKEND # def get_media_dir(nested_path=""): # return Path(MEDIA_ROOT, nested_path) @@ -54,3 +56,18 @@ def get_media_path( path = Path(path, filename) return str(path) + + +def get_file_path(file: models.FileField): + """ + Returns the appropriate path for a file. + + In production, this returns file.url, and in development + mode it returns file.path. This is because boto3 will + raise an error if file.path is called in production. 
+ """ + + if S3_STORAGE_BACKEND is True: + return file.url + else: + return file.path diff --git a/app/utils/helpers.py b/app/utils/helpers.py index 2b31385..ca8e91d 100644 --- a/app/utils/helpers.py +++ b/app/utils/helpers.py @@ -75,3 +75,20 @@ def import_from_path(path: str): """ return import_string(path) + + +def clean_list(target: list): + """Remove None values and empty strings from list.""" + + return [item for item in target if item is not None and item != ""] + + +def str_to_list(target: str | None): + """Split string into list using comma as a separator.""" + if not isinstance(target, str): + return [] + + items = target.split(",") + items = clean_list([item.strip() for item in items]) + + return items diff --git a/app/utils/models.py b/app/utils/models.py index e59021c..bfe2cfd 100644 --- a/app/utils/models.py +++ b/app/utils/models.py @@ -1,15 +1,16 @@ import os import uuid +from pathlib import Path +from django.core.files import File from django.db import models from django.db.models.fields.related_descriptors import ReverseOneToOneDescriptor from django.utils.deconstruct import deconstructible from rest_framework.fields import ObjectDoesNotExist +from utils.helpers import import_from_path from utils.types import T -# from utils.types import T - @deconstructible class UploadFilepathFactory(object): @@ -22,17 +23,45 @@ class UploadFilepathFactory(object): Ex: "/user/profile/" -> "/media/uploads/user/profile/" """ - def __init__(self, path: str): + def __init__(self, path: str, default_extension=None): self.path = path + self.default_extension = default_extension def __call__(self, instance, filename): - extension = filename.split(".")[-1] - filename = "{}.{}".format(uuid.uuid4().hex, extension) + if "." in filename: + extension = filename.split(".")[-1] + else: + extension = self.default_extension or "" + filename = "{}.{}".format(uuid.uuid4().hex, extension) nested_dirs = [dirname for dirname in self.path.split("/") if dirname] return os.path.join("uploads", *nested_dirs, filename) +@deconstructible +class ValidateImportString(object): + """ + Validate that a given string can be imported using the `import_from_path` function. + """ + + def __init__(self, target_type=None) -> None: + self.target_type = target_type + + def __call__(self, text: str): + symbol = import_from_path(text) + # print( + # "symbol:", + # symbol, + # " target type:", + # self.target_type, + # " is instance:", + # isinstance(symbol, self.target_type), + # ) + assert issubclass( + symbol, self.target_type + ), f"Imported object needs to be of type {self.target_type}, but got {type(symbol)}." + + class ReverseOneToOneOrNoneDescriptor(ReverseOneToOneDescriptor): def __get__(self, instance, cls=None): try: @@ -51,3 +80,18 @@ class OneToOneOrNoneField(models.OneToOneField[T]): """ # noqa: E501 related_accessor_class = ReverseOneToOneOrNoneDescriptor + + +def save_file_to_model(model: models.Model, filepath, field="file"): + """ + Given file path, save a file to a given model. + + This abstracts the process of opening the file and + copying over the file data. 
+ """ + path = Path(filepath) + + with path.open(mode="rb") as f: + file = File(f, name=path.name) + setattr(model, field, file) + model.save() diff --git a/docker-compose.yml b/docker-compose.yml index c907650..5fa87af 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,6 +9,7 @@ services: - '8000:8000' volumes: - ./app:/app + - static-clubs-dev:/vol/static command: > sh -c "python manage.py wait_for_db && python manage.py migrate && @@ -25,7 +26,6 @@ services: - DJANGO_SUPERUSER_EMAIL=${DJANGO_SUPERUSER_EMAIL:-admin@example.com} - DJANGO_SUPERUSER_PASS=${DJANGO_SUPERUSER_PASS:-changeme} - DJANGO_BASE_URL=${DJANGO_BASE_URL:-http://localhost:8000} - - S3_STORAGE_BACKEND=0 - CREATE_SUPERUSER=1 - EMAIL_HOST_USER=${EMAIL_HOST_USER:-""} - EMAIL_HOST_PASS=${EMAIL_HOST_PASS:-""} @@ -42,8 +42,12 @@ services: - CELERY_ACKS_LATE=True - DJANGO_DB=postgresql - DJANGO_REDIS_URL=redis://clubs-dev-redis:6379/1 + + - AWS_EXECUTION_ENV=0 + - S3_STORAGE_BACKEND=0 depends_on: - postgres + - redis postgres: image: postgres:13-alpine @@ -71,7 +75,7 @@ services: command: ['celery', '-A', 'app', 'worker', '--loglevel=info'] volumes: - ./app:/app - - static-clubs-dev:/vol/web + - static-clubs-dev:/vol/static depends_on: - redis - postgres @@ -80,15 +84,31 @@ services: - DEBUG=1 - CELERY_BROKER_URL=redis://clubs-dev-redis:6379/0 - CELERY_RESULT_BACKEND=redis://clubs-dev-redis:6379/0 - - DJANGO_DB=postgresql + - CELERY_ACKS_LATE=True - POSTGRES_HOST=clubs-dev-db - POSTGRES_PORT=5432 - - POSTGRES_NAME=devdatabase - POSTGRES_DB=devdatabase + - POSTGRES_NAME=devdatabase - POSTGRES_USER=devuser - POSTGRES_PASSWORD=devpass + - DJANGO_DB=postgresql + - DJANGO_REDIS_URL=redis://clubs-dev-redis:6379/1 + - DB_HOST=clubs-dev-db + - DB_NAME=devdatabase + - DB_USER=devuser + - DB_PASS=devpass + + - EMAIL_HOST_USER=${EMAIL_HOST_USER:-""} + - EMAIL_HOST_PASS=${EMAIL_HOST_PASS:-""} + - CONSOLE_EMAIL_BACKEND=${CONSOLE_EMAIL_BACKEND:-1} + - SENDGRID_API_KEY=${SENDGRID_API_KEY:-""} + - DEFAULT_FROM_EMAIL=${DEFAULT_FROM_EMAIL:-""} + + - AWS_EXECUTION_ENV=0 + - S3_STORAGE_BACKEND=0 + celerybeat: build: context: . @@ -108,7 +128,7 @@ services: ] volumes: - ./app:/app - - static-clubs-dev:/vol/web + - static-clubs-dev:/vol/static depends_on: - redis - postgres @@ -122,6 +142,7 @@ services: - POSTGRES_HOST=clubs-dev-db - POSTGRES_PORT=5432 + - POSTGRES_DB=devdatabase - POSTGRES_NAME=devdatabase - POSTGRES_USER=devuser - POSTGRES_PASSWORD=devpass @@ -131,6 +152,15 @@ services: - DB_USER=devuser - DB_PASS=devpass + - EMAIL_HOST_USER=${EMAIL_HOST_USER:-""} + - EMAIL_HOST_PASS=${EMAIL_HOST_PASS:-""} + - CONSOLE_EMAIL_BACKEND=${CONSOLE_EMAIL_BACKEND:-1} + - SENDGRID_API_KEY=${SENDGRID_API_KEY:-""} + - DEFAULT_FROM_EMAIL=${DEFAULT_FROM_EMAIL:-""} + + - AWS_EXECUTION_ENV=0 + - S3_STORAGE_BACKEND=0 + coverage: image: nginx ports: diff --git a/requirements.txt b/requirements.txt index 432ad4e..069a670 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,12 @@ django-admin-tools>=0.9.3,<1.0 coverage>=7.4.1,<7.5 Pillow>=10.2.0,<11.2 sendgrid>=6.11.0,<6.12 +typing_extensions>=4.12.2,<4.13 + +# csv files +pandas>=2.2.3,<2.3 +xlsxwriter>=3.2.0,<3.3 +pathlib>=1.0.1,<1.1 # QRCodes segno>=1.6.1,<1.7 @@ -16,6 +22,10 @@ celery>=5.4.0,<5.5 redis>=5.0.4,<5.1 django-celery-beat>=2.7.0,<2.8 +# AWS S3 +boto3>=1.34.0,<1.35.0 +django-storages>=1.14.3,<1.15.0 + # Not required for local dev psycopg2>=2.9.3,<2.9.4 # 2.9.10 raises pip error, 2/13/25 uwsgi>=2.0.26,<2.0.27 # 2.0.28 raises pip error, 2/13/25 \ No newline at end of file