Skip to content

Commit

Permalink
Improve archive utils
Browse files Browse the repository at this point in the history
  • Loading branch information
geoff128 committed Jul 1, 2024
1 parent 29a5153 commit 1a3d77f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 20 deletions.
41 changes: 22 additions & 19 deletions sio/archive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import tarfile
import zipfile

from sio.workers.util import RegisteredSubclassesBase


class ArchiveException(Exception):
"""Base exception class for all archive errors."""
Expand All @@ -50,23 +52,24 @@ def extract(path, member, to_path='', ext='', **kwargs):
Archive(path, ext=ext).extract(member, to_path, **kwargs)


class Archive(object):
class Archive(RegisteredSubclassesBase):
"""
The external API class that encapsulates an archive implementation.
"""

def __init__(self, file, ext=''):
"""
Arguments:
* 'file' can be a string path to a file or a file-like object.
* Optional 'ext' argument can be given to override the file-type
guess that is normally performed using the file extension of the
given 'file'. Should start with a dot, e.g. '.tar.gz'.
"""
self._archive = self._archive_cls(file, ext=ext)(file)
@classmethod
def __classinit__(cls):
this_cls = globals().get('Archive', cls)
super(this_cls, cls).__classinit__()
cls.handled_archives = set()

@staticmethod
def _archive_cls(file, ext=''):
@classmethod
def register_subclass(cls, subcls):
if cls is not subcls:
cls.handled_archives.add(subcls)

@classmethod
def get(cls, file):
"""
Return the proper Archive implementation class, based on the file type.
"""
Expand All @@ -79,10 +82,10 @@ def _archive_cls(file, ext=''):
except AttributeError:
raise UnrecognizedArchiveFormat(
"File object not a recognized archive format.")
for cls in HANDLED_ARCHIVES:
if cls.is_archive(filename):
return cls

for subcls in cls.handled_archives:
if subcls.is_archive(filename):
return subcls(filename)
raise UnrecognizedArchiveFormat(
"Path not a recognized archive format: %s" % filename)

Expand All @@ -96,10 +99,12 @@ def filenames(self):
return self._archive.filenames()


class BaseArchive(object):
class BaseArchive(Archive):
"""
Base Archive class. Implementations should inherit this class.
"""
abstract = True

def __del__(self):
if hasattr(self, "_archive"):
self._archive.close()
Expand Down Expand Up @@ -187,5 +192,3 @@ def filenames(self):
@staticmethod
def is_archive(filename):
return zipfile.is_zipfile(filename)

HANDLED_ARCHIVES = (ZipArchive, TarArchive)
2 changes: 1 addition & 1 deletion sio/executors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _fake_run_as_exe_is_output_file(environ):
ft.download(environ, 'exe_file', tempcwd('outs_archive'))
problem_short_name = environ['problem_short_name']
test_name = f'{problem_short_name}{environ["name"]}.out'
archive = Archive(tempcwd('outs_archive'))
archive = Archive.get(tempcwd('outs_archive'))
logger.info('Archive with outs provided')
if test_name in archive.filenames():
archive.extract(test_name, to_path=tempcwd())
Expand Down

0 comments on commit 1a3d77f

Please sign in to comment.