diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a884998..6b846ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: [ "--profile", "black", "--filter-files" ] - repo: https://github.com/psf/black - rev: 22.6.0 # Replace by any tag/version: https://github.com/psf/black/tags + rev: 23.3.0 # Replace by any tag/version: https://github.com/psf/black/tags hooks: - id: black language_version: python3 # Should be a command that runs python3.6.2+ diff --git a/drf_attachments/models/fields.py b/drf_attachments/models/fields.py index c6f7089..8a5b503 100644 --- a/drf_attachments/models/fields.py +++ b/drf_attachments/models/fields.py @@ -1,7 +1,6 @@ +from django.conf import settings from django.contrib.contenttypes.fields import GenericRelation from django.db.models import FileField -from django.conf import settings - __all__ = [ "AttachmentRelation", @@ -19,8 +18,6 @@ def __init__(self, *args, **kwargs): class DynamicStorageFileField(FileField): def pre_save(self, model_instance, add): meta = getattr(model_instance.content_object, "AttachmentMeta", None) - storage_location = getattr(meta, "storage_location", None) - if storage_location is None: - storage_location = settings.PRIVATE_ROOT + storage_location = getattr(meta, "storage_location", settings.PRIVATE_ROOT) self.storage.location = storage_location return super().pre_save(model_instance, add) diff --git a/drf_attachments/rest/views.py b/drf_attachments/rest/views.py index d9933f6..57e8526 100644 --- a/drf_attachments/rest/views.py +++ b/drf_attachments/rest/views.py @@ -37,6 +37,15 @@ def get_serializer(self, *args, **kwargs): def get_queryset(self): return Attachment.objects.viewable() + def get_storage_path(self): + attachment = self.get_object() + meta = getattr(attachment.content_object, "AttachmentMeta", None) + storage_location = getattr(meta, "storage_location", None) + if storage_location: + return f"{storage_location}/{attachment.file.name}" + else: + return attachment.file.path + @action( detail=True, methods=["GET"], @@ -46,6 +55,7 @@ def download(self, request, format=None, *args, **kwargs): """Downloads the uploaded attachment file.""" attachment = self.get_object() extension = attachment.get_extension() + storage_path = self.get_storage_path() if attachment.name: download_file_name = f"{attachment.name}{extension}" @@ -53,7 +63,7 @@ def download(self, request, format=None, *args, **kwargs): download_file_name = f"attachment_{attachment.pk}{extension}" return FileResponse( - open(attachment.file.path, "rb"), + open(storage_path, "rb"), as_attachment=True, filename=download_file_name, )