Commit 4113e72: Add 'hash_pathing_fallback' option, remove File.get_path() method (#1017)
psrok1 authored Jan 13, 2025
1 parent 390bcce commit 4113e72
Showing 2 changed files with 41 additions and 64 deletions.
4 changes: 4 additions & 0 deletions mwdb/core/config.py
@@ -79,6 +79,10 @@ class MWDBConfig(Config):
     # Should we break up the uploads into different folders for example:
     # uploads/9/f/8/6/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08
     hash_pathing = key(cast=intbool, required=False, default=True)
+    # Try to open file using opposite hash_pathing setting when MWDB
+    # fails to open file using current one. It's useful when you want to
+    # migrate from one scheme to another.
+    hash_pathing_fallback = key(cast=intbool, required=False, default=True)
     # S3 compatible storage endpoint
     s3_storage_endpoint = key(cast=str, required=False)
     # Use TLS with S3 storage
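For illustration (not part of the commit), here is a minimal sketch of how the two hash_pathing layouts map a SHA256 digest to a storage location; the helper name is hypothetical:

```python
import os

def storage_key(uploads_folder: str, sha256: str, hash_pathing: bool) -> str:
    # Hypothetical helper mirroring the layouts described in the config comment.
    sha256 = sha256.lower()
    if hash_pathing:
        # Nested scheme: the first four hex digits become directory levels,
        # e.g. uploads/9/f/8/6/9f86d081...
        return os.path.join(uploads_folder, *sha256[:4], sha256)
    # Flat scheme: everything lands directly in the uploads folder.
    return os.path.join(uploads_folder, sha256)

digest = "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
print(storage_key("uploads", digest, hash_pathing=True))
# uploads/9/f/8/6/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08
print(storage_key("uploads", digest, hash_pathing=False))
# uploads/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08
```

The nested scheme keeps individual directories small when millions of samples sit on a plain filesystem; the flat scheme is simpler and typically fine for S3-style object storage.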
101 changes: 37 additions & 64 deletions mwdb/model/file.py
@@ -15,14 +15,7 @@
 from mwdb.core.auth import AuthScope, generate_token, verify_token
 from mwdb.core.config import StorageProviderType, app_config
 from mwdb.core.karton import send_file_to_karton
-from mwdb.core.util import (
-    calc_crc32,
-    calc_hash,
-    calc_magic,
-    calc_ssdeep,
-    get_fd_path,
-    get_s3_client,
-)
+from mwdb.core.util import calc_crc32, calc_hash, calc_magic, calc_ssdeep, get_s3_client
 
 from . import db
 from .object import Object
@@ -147,7 +140,9 @@ def get_or_create(
                 Body=file_stream,
             )
         elif app_config.mwdb.storage_provider == StorageProviderType.DISK:
-            with open(file_obj._calculate_path(), "wb") as f:
+            upload_path = file_obj._calculate_path()
+            os.makedirs(os.path.dirname(upload_path), mode=0o755, exist_ok=True)
+            with open(upload_path, "wb") as f:
                 shutil.copyfileobj(file_stream, f)
         else:
             raise RuntimeError(
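A design note on the hunk above: directory creation used to happen inside _calculate_path() and now happens only at disk-upload time, so computing an S3 key no longer touches the local filesystem. A standalone sketch of the idiom, with an assumed path:

```python
import os

upload_path = "uploads/9/f/8/6/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"

# exist_ok=True makes this safe to repeat for every upload, even concurrently;
# mode=0o755 applies only to directories created by this call (subject to umask).
os.makedirs(os.path.dirname(upload_path), mode=0o755, exist_ok=True)
with open(upload_path, "wb") as f:
    f.write(b"sample contents")
```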
@@ -158,7 +153,7 @@ def get_or_create(
         file_obj.upload_stream = file_stream
         return file_obj, is_new
 
-    def _calculate_path(self):
+    def _calculate_path(self, fallback_path=False):
         if app_config.mwdb.storage_provider == StorageProviderType.DISK:
             upload_path = app_config.mwdb.uploads_folder
         elif app_config.mwdb.storage_provider == StorageProviderType.S3:
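The new fallback_path parameter flips the effective layout, as the next hunk shows. Distilled into a standalone sketch (function name assumed):

```python
def effective_hash_pathing(configured: bool, fallback_path: bool) -> bool:
    # The fallback probes whichever layout is NOT currently configured,
    # which is exactly a boolean flip of the configured setting.
    return configured != fallback_path

assert effective_hash_pathing(configured=True, fallback_path=False) is True
assert effective_hash_pathing(configured=True, fallback_path=True) is False
assert effective_hash_pathing(configured=False, fallback_path=True) is True
```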
@@ -170,65 +165,19 @@
 
         sample_sha256 = self.sha256.lower()
 
-        if app_config.mwdb.hash_pathing:
+        hash_pathing = app_config.mwdb.hash_pathing
+        if fallback_path:
+            hash_pathing = not hash_pathing
+        if hash_pathing:
             # example: uploads/9/f/8/6/9f86d0818...
             upload_path = os.path.join(upload_path, *list(sample_sha256)[0:4])
 
-        if app_config.mwdb.storage_provider == StorageProviderType.DISK:
-            upload_path = os.path.abspath(upload_path)
-            os.makedirs(upload_path, mode=0o755, exist_ok=True)
-        return os.path.join(upload_path, sample_sha256)
-
-    def get_path(self):
-        """
-        Legacy method used to retrieve the path to the file contents.
-        Creates NamedTemporaryFile if mwdb-core uses different type of
-        storage than DISK and file size is too small to be written to
-        disk by Werkzeug.
-        Deprecated, use File.open() to get the stream with contents.
-        """
-        if app_config.mwdb.storage_provider == StorageProviderType.DISK:
-            # Just return path of file stored in local file-system
-            return self._calculate_path()
-
-        if not self.upload_stream:
-            raise ValueError("Can't retrieve local path for this file")
-
-        if isinstance(self.upload_stream.name, str) or isinstance(
-            self.upload_stream, bytes
-        ):
-            return self.upload_stream.name
-
-        fd_path = get_fd_path(self.upload_stream)
-        if fd_path:
-            return fd_path
-
-        # If not a file (BytesIO), copy contents to the named temporary file
-        tmpfile = tempfile.NamedTemporaryFile()
-        self.upload_stream.seek(0, os.SEEK_SET)
-        shutil.copyfileobj(self.upload_stream, tmpfile)
-        self.upload_stream.close()
-        self.upload_stream = tmpfile
-        return self.upload_stream.name
-
-    def open(self):
-        """
-        Opens the file stream with contents.
+        return os.path.join(upload_path, sample_sha256)
 
-        File stream must be closed using File.close.
-        """
-        if self.upload_stream is not None:
-            # If file contents are uploaded in this request,
-            # try to reuse the existing file instead of downloading it from S3.
-            if isinstance(self.upload_stream, io.BytesIO):
-                return io.BytesIO(self.upload_stream.getbuffer())
-            else:
-                dupfd = os.dup(self.upload_stream.fileno())
-                stream = os.fdopen(dupfd, "rb")
-                stream.seek(0, os.SEEK_SET)
-                return stream
+    def _open_from_storage(self, fallback_path=False):
         if app_config.mwdb.storage_provider == StorageProviderType.S3:
             # Stream coming from Boto3 get_object is not buffered and not seekable.
             # We need to download it to the temporary file first.
Expand All @@ -243,7 +192,7 @@ def open(self):
app_config.mwdb.s3_storage_iam_auth,
).download_fileobj(
Bucket=app_config.mwdb.s3_storage_bucket_name,
Key=self._calculate_path(),
Key=self._calculate_path(fallback_path=fallback_path),
Fileobj=stream,
)
stream.seek(0, io.SEEK_SET)
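For context, a self-contained sketch of the boto3 download pattern used above, with placeholder endpoint, bucket, and key (the real values come from app_config):

```python
import io
import tempfile

import boto3

s3 = boto3.client("s3", endpoint_url="http://localhost:9000")  # placeholder endpoint

# A spooled temporary file gives a buffered, seekable stream, which the raw
# body returned by get_object is not.
stream = tempfile.SpooledTemporaryFile()
s3.download_fileobj(
    Bucket="mwdb-uploads",  # placeholder bucket
    Key="9/f/8/6/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08",
    Fileobj=stream,
)
stream.seek(0, io.SEEK_SET)
data = stream.read()
```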
Expand All @@ -252,12 +201,36 @@ def open(self):
stream.close()
raise
elif app_config.mwdb.storage_provider == StorageProviderType.DISK:
return open(self._calculate_path(), "rb")
return open(self._calculate_path(fallback_path=fallback_path), "rb")
else:
raise RuntimeError(
f"StorageProvider {app_config.mwdb.storage_provider} is not supported"
)

def open(self):
"""
Opens the file stream with contents.
File stream must be closed using File.close.
"""
if self.upload_stream is not None:
# If file contents are uploaded in this request,
# try to reuse the existing file instead of downloading it from S3.
if isinstance(self.upload_stream, io.BytesIO):
return io.BytesIO(self.upload_stream.getbuffer())
else:
dupfd = os.dup(self.upload_stream.fileno())
stream = os.fdopen(dupfd, "rb")
stream.seek(0, os.SEEK_SET)
return stream
try:
return self._open_from_storage()
except Exception:
if app_config.mwdb.hash_pathing_fallback:
return self._open_from_storage(fallback_path=True)
else:
raise

def read(self):
"""
Reads all bytes from the file
Expand Down
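Putting it together: open() tries the configured layout first and, when hash_pathing_fallback is enabled, retries once with the opposite layout. During a migration between the nested and flat schemes, files that have not been moved yet stay readable, at the cost of one extra lookup on a miss. A minimal standalone sketch under assumed names:

```python
def open_with_fallback(primary_open, fallback_open, fallback_enabled: bool):
    # primary_open / fallback_open stand in for _open_from_storage() and
    # _open_from_storage(fallback_path=True) in the method above.
    try:
        return primary_open()
    except Exception:
        if not fallback_enabled:
            raise
        return fallback_open()
```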
