diff --git a/src/sio3pack/__init__.py b/src/sio3pack/__init__.py index 51fe899..9761228 100644 --- a/src/sio3pack/__init__.py +++ b/src/sio3pack/__init__.py @@ -1,10 +1,10 @@ __version__ = "0.0.1" -from sio3pack.files.file import File +from sio3pack.files import LocalFile from sio3pack.packages.package import Package -def from_file(file: str | File, django_settings=None) -> Package: +def from_file(file: str | LocalFile, django_settings=None) -> Package: """ Initialize a package object from a file (archive or directory). :param file: The file path or File object. @@ -12,5 +12,5 @@ def from_file(file: str | File, django_settings=None) -> Package: :return: The package object. """ if isinstance(file, str): - file = File(file) - return Package.from_file(file, django_settings) + file = LocalFile(file) + return Package.from_file(file, django_settings=django_settings) diff --git a/src/sio3pack/files/__init__.py b/src/sio3pack/files/__init__.py index bedfbe6..ca063b6 100644 --- a/src/sio3pack/files/__init__.py +++ b/src/sio3pack/files/__init__.py @@ -1,2 +1,3 @@ from sio3pack.files.filetracker_file import FiletrackerFile from sio3pack.files.local_file import LocalFile +from sio3pack.files.file import File diff --git a/src/sio3pack/files/file.py b/src/sio3pack/files/file.py index 779ec3a..1ae0908 100644 --- a/src/sio3pack/files/file.py +++ b/src/sio3pack/files/file.py @@ -1,36 +1,16 @@ -import os.path - - class File: """ Base class for all files in a package. """ - @classmethod - def get_file_matching_extension(cls, dir: str, filename: str, extensions: list[str]) -> "File": - """ - Get the file with the given filename and one of the given extensions. - :param dir: The directory to search in. - :param filename: The filename. - :param extensions: The extensions. - :return: The file object. - """ - for ext in extensions: - path = os.path.join(dir, filename + ext) - if os.path.exists(path): - return cls(path) - raise FileNotFoundError - def __init__(self, path: str): - if not os.path.exists(path): - raise FileNotFoundError self.path = path - self.filename = os.path.basename(path) + + def __str__(self): + return f"<{self.__class__.__name__} {self.path}>" def read(self) -> str: - with open(self.path, "r") as f: - return f.read() + raise NotImplementedError() def write(self, text: str): - with open(self.path, "w") as f: - f.write(text) + raise NotImplementedError() diff --git a/src/sio3pack/files/filetracker_file.py b/src/sio3pack/files/filetracker_file.py index 3698e2d..9ee022e 100644 --- a/src/sio3pack/files/filetracker_file.py +++ b/src/sio3pack/files/filetracker_file.py @@ -8,3 +8,5 @@ class FiletrackerFile(File): def __init__(self, path: str): super().__init__(path) + # TODO: should raise FileNotFoundError if file is not tracked + raise NotImplementedError() diff --git a/src/sio3pack/files/local_file.py b/src/sio3pack/files/local_file.py index 95f4009..43b1c15 100644 --- a/src/sio3pack/files/local_file.py +++ b/src/sio3pack/files/local_file.py @@ -1,3 +1,5 @@ +import os + from sio3pack.files.file import File @@ -6,5 +8,31 @@ class LocalFile(File): Base class for all files in a package that are stored locally. """ + @classmethod + def get_file_matching_extension(cls, dir: str, filename: str, extensions: list[str]) -> "LocalFile": + """ + Get the file with the given filename and one of the given extensions. + :param dir: The directory to search in. + :param filename: The filename. + :param extensions: The extensions. + :return: The file object. + """ + for ext in extensions: + path = os.path.join(dir, filename + ext) + if os.path.exists(path): + return cls(path) + raise FileNotFoundError + def __init__(self, path: str): + if not os.path.exists(path): + raise FileNotFoundError super().__init__(path) + self.filename = os.path.basename(path) + + def read(self) -> str: + with open(self.path, "r") as f: + return f.read() + + def write(self, text: str): + with open(self.path, "w") as f: + f.write(text) diff --git a/src/sio3pack/graph/__init__.py b/src/sio3pack/graph/__init__.py index bf9957e..79efe46 100644 --- a/src/sio3pack/graph/__init__.py +++ b/src/sio3pack/graph/__init__.py @@ -1 +1,3 @@ -from sio3pack.graph import Graph +from sio3pack.graph.graph import Graph +from sio3pack.graph.graph_manager import GraphManager +from sio3pack.graph.graph_op import GraphOperation diff --git a/src/sio3pack/graph/graph.py b/src/sio3pack/graph/graph.py index cdd50b2..370cbb8 100644 --- a/src/sio3pack/graph/graph.py +++ b/src/sio3pack/graph/graph.py @@ -5,7 +5,7 @@ class Graph: @classmethod def from_dict(cls, data: dict): - raise NotImplemented + raise NotImplementedError() def __init__(self, name: str): self.name = name @@ -14,4 +14,4 @@ def get_prog_files(self) -> list[str]: """ Get all program files in the graph. """ - raise NotImplemented + raise NotImplementedError() diff --git a/src/sio3pack/graph/graph_manager.py b/src/sio3pack/graph/graph_manager.py index 585ba8b..4a8e4f8 100644 --- a/src/sio3pack/graph/graph_manager.py +++ b/src/sio3pack/graph/graph_manager.py @@ -1,6 +1,6 @@ import json -from sio3pack import File +from sio3pack.files import File from sio3pack.graph.graph import Graph diff --git a/src/sio3pack/packages/exceptions.py b/src/sio3pack/packages/exceptions.py index ca91db2..307b736 100644 --- a/src/sio3pack/packages/exceptions.py +++ b/src/sio3pack/packages/exceptions.py @@ -1,2 +1,4 @@ class UnknownPackageType(Exception): - pass + def __init__(self, path: str) -> None: + self.path = path + super().__init__(f"Unknown package type for file {path}.") diff --git a/src/sio3pack/packages/package/model.py b/src/sio3pack/packages/package/model.py index 85972e1..0a863ec 100644 --- a/src/sio3pack/packages/package/model.py +++ b/src/sio3pack/packages/package/model.py @@ -1,9 +1,10 @@ from typing import Any -from sio3pack.files.file import File -from sio3pack.graph.graph import Graph +from sio3pack import LocalFile +from sio3pack.files import File +from sio3pack.graph import Graph from sio3pack.packages.exceptions import UnknownPackageType -from sio3pack.test.test import Test +from sio3pack.test import Test from sio3pack.utils.archive import Archive from sio3pack.utils.classinit import RegisteredSubclassesBase @@ -18,17 +19,18 @@ class Package(RegisteredSubclassesBase): def __init__(self, file: File): super().__init__() self.file = file - if Archive.is_archive(file.path): - self.is_archive = True - else: - self.is_archive = False + if isinstance(file, LocalFile): + if Archive.is_archive(file.path): + self.is_archive = True + else: + self.is_archive = False @classmethod - def from_file(cls, file: File, django_settings=None): + def from_file(cls, file: LocalFile, django_settings=None): for subclass in cls.subclasses: if subclass.identify(file): return subclass(file, django_settings) - raise UnknownPackageType + raise UnknownPackageType(file.path) def get_task_id(self) -> str: pass diff --git a/src/sio3pack/packages/sinolpack/model.py b/src/sio3pack/packages/sinolpack/model.py index 65fb3b8..09fbea4 100644 --- a/src/sio3pack/packages/sinolpack/model.py +++ b/src/sio3pack/packages/sinolpack/model.py @@ -4,10 +4,11 @@ import yaml -from sio3pack.files.file import File -from sio3pack.graph.graph import Graph -from sio3pack.graph.graph_manager import GraphManager -from sio3pack.graph.graph_op import GraphOperation +from sio3pack import LocalFile +from sio3pack.files import File +from sio3pack.graph import Graph +from sio3pack.graph import GraphManager +from sio3pack.graph import GraphOperation from sio3pack.packages.package import Package from sio3pack.packages.sinolpack.enums import ModelSolutionKind from sio3pack.util import naturalsort_key @@ -37,7 +38,7 @@ def _find_main_dir(cls, archive: Archive) -> str | None: return None @classmethod - def identify(cls, file: File) -> bool: + def identify(cls, file: LocalFile) -> bool: """ Identifies whether file is a Sinolpack. @@ -57,21 +58,25 @@ def __del__(self): def __init__(self, file: File, django_settings=None): super().__init__(file) - if self.is_archive: - archive = Archive(file.path) - self.short_name = self._find_main_dir(archive) - self.tmpdir = tempfile.TemporaryDirectory() - archive.extract(to_path=self.tmpdir.name) - self.rootdir = os.path.join(self.tmpdir.name, self.short_name) - else: - self.short_name = os.path.basename(file.path) - self.rootdir = file.path - try: - graph_file = self.get_in_root("graph.json") - self.graph_manager = GraphManager.from_file(graph_file) - except FileNotFoundError: - self.has_custom_graph = False + if isinstance(file, LocalFile): + if self.is_archive: + archive = Archive(file.path) + self.short_name = self._find_main_dir(archive) + self.tmpdir = tempfile.TemporaryDirectory() + archive.extract(to_path=self.tmpdir.name) + self.rootdir = os.path.join(self.tmpdir.name, self.short_name) + else: + self.short_name = os.path.basename(file.path) + self.rootdir = file.path + + try: + graph_file = self.get_in_root("graph.json") + self.graph_manager = GraphManager.from_file(graph_file) + except FileNotFoundError: + self.has_custom_graph = False + else: + raise NotImplementedError() self.django_settings = django_settings @@ -98,17 +103,17 @@ def get_doc_dir(self) -> str: """ return os.path.join(self.rootdir, "doc") - def get_in_doc_dir(self, filename: str) -> File: + def get_in_doc_dir(self, filename: str) -> LocalFile: """ Returns the path to the input file in the documents' directory. """ - return File(os.path.join(self.get_doc_dir(), filename)) + return LocalFile(os.path.join(self.get_doc_dir(), filename)) - def get_in_root(self, filename: str) -> File: + def get_in_root(self, filename: str) -> LocalFile: """ Returns the path to the input file in the root directory. """ - return File(os.path.join(self.rootdir, filename)) + return LocalFile(os.path.join(self.rootdir, filename)) def get_prog_dir(self) -> str: """ @@ -116,11 +121,11 @@ def get_prog_dir(self) -> str: """ return os.path.join(self.rootdir, "prog") - def get_in_prog_dir(self, filename: str) -> File: + def get_in_prog_dir(self, filename: str) -> LocalFile: """ Returns the path to the input file in the program directory. """ - return File(os.path.join(self.get_prog_dir(), filename)) + return LocalFile(os.path.join(self.get_prog_dir(), filename)) def get_attachments_dir(self) -> str: """ @@ -247,7 +252,7 @@ def _process_prog_files(self): for file in ("ingen", "inwer", "soc", "chk"): try: self.additional_files.append( - File.get_file_matching_extension( + LocalFile.get_file_matching_extension( self.get_prog_dir(), self.short_name + file, extensions ).filename ) diff --git a/src/sio3pack/test/__init__.py b/src/sio3pack/test/__init__.py index 0478c51..9c36d80 100644 --- a/src/sio3pack/test/__init__.py +++ b/src/sio3pack/test/__init__.py @@ -1 +1 @@ -from sio3pack.test import Test +from sio3pack.test.test import Test diff --git a/src/sio3pack/test/simple_test.py b/src/sio3pack/test/simple_test.py index 57c375c..3532aa0 100644 --- a/src/sio3pack/test/simple_test.py +++ b/src/sio3pack/test/simple_test.py @@ -1,5 +1,5 @@ -from sio3pack.files.file import File -from sio3pack.test.test import Test +from sio3pack.files import File +from sio3pack.test import Test class SimpleTest(Test): diff --git a/src/sio3pack/utils/archive.py b/src/sio3pack/utils/archive.py index 847eb81..533dd1d 100644 --- a/src/sio3pack/utils/archive.py +++ b/src/sio3pack/utils/archive.py @@ -65,6 +65,9 @@ def __init__(self, file, ext=""): self.filename = file self._archive = self._archive_cls(self.filename, ext=ext)(self.filename) + def __str__(self): + return f'' + @staticmethod def _archive_cls(file, ext=""): """ @@ -220,7 +223,13 @@ def filenames(self): return [zipinfo.filename for zipinfo in self._archive.infolist() if not zipinfo.is_dir()] def dirnames(self): - return [zipinfo.filename for zipinfo in self._archive.infolist() if zipinfo.is_dir()] + dirs = set() + for zipinfo in self._archive.infolist(): + if zipinfo.is_dir(): + dirs.add(zipinfo.filename) + else: + dirs.add(os.path.dirname(zipinfo.filename)) + return list(dirs) extension_map = { diff --git a/tests/fixtures.py b/tests/fixtures.py index 7e31c57..1e2a433 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -19,6 +19,17 @@ class Compression(Enum): all_compressions = [c.value for c in Compression if c != Compression.NONE] +class PackageInfo: + def __init__(self, path, type, task_id, compression): + self.path = path + self.type = type + self.task_id = task_id + self.compression = compression + + def is_archive(self): + return self.compression != Compression.NONE + + def _tar_archive(dir, dest, compression=None): """ Create a tar archive of the specified directory. @@ -38,10 +49,12 @@ def _zip_archive(dir, dest): with zipfile.ZipFile(dest, "w") as zip: for root, dirs, files in os.walk(dir): for file in files: - zip.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), dir)) + file_path = os.path.join(root, file) + arcname = os.path.join(os.path.basename(dir), os.path.relpath(file_path, dir)) + zip.write(file_path, arcname) -def _create_package(package_name, tmpdir, archive=False, extension="zip"): +def _create_package(package_name, tmpdir, archive=False, extension=Compression.ZIP): packages = os.path.join(os.path.dirname(__file__), "test_packages") if not os.path.exists(os.path.join(packages, package_name)): raise FileNotFoundError(f"Package {package_name} does not exist") @@ -58,43 +71,46 @@ def _create_package(package_name, tmpdir, archive=False, extension="zip"): os.unlink(os.path.join(package_path, "__init__.py")) if archive: - if extension == "zip": + if extension == Compression.ZIP: _zip_archive(package_path, os.path.join(tmpdir.name, f"{task_id}.zip")) - elif extension == "tar": - _tar_archive(package_path, os.path.join(tmpdir.name, f"{task_id}.tar")) - elif extension == "tar.gz" or extension == "tgz": - _tar_archive(package_path, os.path.join(tmpdir.name, f"{task_id}.{extension}"), "gz") + elif extension == Compression.TAR_GZ or extension == Compression.TGZ: + _tar_archive(package_path, os.path.join(tmpdir.name, f"{task_id}.{extension.value}"), "gz") else: raise ValueError(f"Unknown extension {extension}") - package_path = os.path.join(tmpdir.name, f"{task_id}.{extension}") + package_path = os.path.join(tmpdir.name, f"{task_id}.{extension.value}") - return package_path, type + return PackageInfo( + path=package_path, + type=type, + task_id=task_id, + compression=extension, + ) @pytest.fixture -def package(request): +def get_package(request): """ Fixture to create a temporary directory with specified package. """ package_name = request.param tmpdir = tempfile.TemporaryDirectory() - package_path, type = _create_package(package_name, tmpdir) + package_info = _create_package(package_name, tmpdir) - yield package_path, type + yield lambda: package_info tmpdir.cleanup() @pytest.fixture -def package_archived(request): +def get_archived_package(request): """ Fixture to create a temporary directory with specified package, but archived. """ package_name, extension = request.param archive = extension != Compression.NONE tmpdir = tempfile.TemporaryDirectory() - package_path, type = _create_package(package_name, tmpdir, archive, extension) + package_info = _create_package(package_name, tmpdir, archive, extension) - yield package_path, type + yield lambda: package_info tmpdir.cleanup() diff --git a/tests/packages/sinolpack/test_sinolpack.py b/tests/packages/sinolpack/test_sinolpack.py index c591fad..d4a6d13 100644 --- a/tests/packages/sinolpack/test_sinolpack.py +++ b/tests/packages/sinolpack/test_sinolpack.py @@ -2,20 +2,19 @@ import pytest -from tests.fixtures import Compression, all_compressions, package, package_archived +import sio3pack +from sio3pack.packages import Sinolpack +from tests.fixtures import Compression, all_compressions, get_package, get_archived_package, PackageInfo -@pytest.mark.parametrize("package", ["simple"], indirect=True) -def test_simple(package): - package_path, type = package - assert type == "sinolpack" - print(os.listdir(package_path)) - assert os.path.isdir(package_path) - - -@pytest.mark.parametrize("package_archived", [("simple", c) for c in all_compressions], indirect=True) -def test_archive(package_archived): - package_path, type = package_archived - assert type == "sinolpack" - print(package_path) - assert os.path.isfile(package_path) +@pytest.mark.parametrize("get_archived_package", [("simple", c) for c in Compression], indirect=True) +def test_from_file(get_archived_package): + package_info: PackageInfo = get_archived_package() + assert package_info.type == "sinolpack" + package = sio3pack.from_file(package_info.path) + assert isinstance(package, Sinolpack) + assert package.short_name == package_info.task_id + if package_info.is_archive(): + assert package.is_archive + else: + assert package.rootdir == package_info.path diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..e69de29