Skip to content

Commit

Permalink
Fix bugs, refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
MasloMaslane committed Nov 30, 2024
1 parent c577483 commit 8e92a50
Show file tree
Hide file tree
Showing 17 changed files with 149 additions and 103 deletions.
8 changes: 4 additions & 4 deletions src/sio3pack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
__version__ = "0.0.1"

from sio3pack.files.file import File
from sio3pack.files import LocalFile
from sio3pack.packages.package import Package


def from_file(file: str | File, django_settings=None) -> Package:
def from_file(file: str | LocalFile, django_settings=None) -> Package:
"""
Initialize a package object from a file (archive or directory).
:param file: The file path or File object.
:param django_settings: Django settings object.
:return: The package object.
"""
if isinstance(file, str):
file = File(file)
return Package.from_file(file, django_settings)
file = LocalFile(file)
return Package.from_file(file, django_settings=django_settings)
1 change: 1 addition & 0 deletions src/sio3pack/files/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from sio3pack.files.filetracker_file import FiletrackerFile
from sio3pack.files.local_file import LocalFile
from sio3pack.files.file import File
30 changes: 5 additions & 25 deletions src/sio3pack/files/file.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,16 @@
import os.path


class File:
"""
Base class for all files in a package.
"""

@classmethod
def get_file_matching_extension(cls, dir: str, filename: str, extensions: list[str]) -> "File":
"""
Get the file with the given filename and one of the given extensions.
:param dir: The directory to search in.
:param filename: The filename.
:param extensions: The extensions.
:return: The file object.
"""
for ext in extensions:
path = os.path.join(dir, filename + ext)
if os.path.exists(path):
return cls(path)
raise FileNotFoundError

def __init__(self, path: str):
if not os.path.exists(path):
raise FileNotFoundError
self.path = path
self.filename = os.path.basename(path)

def __str__(self):
return f"<{self.__class__.__name__} {self.path}>"

def read(self) -> str:
with open(self.path, "r") as f:
return f.read()
raise NotImplementedError()

def write(self, text: str):
with open(self.path, "w") as f:
f.write(text)
raise NotImplementedError()
2 changes: 2 additions & 0 deletions src/sio3pack/files/filetracker_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ class FiletrackerFile(File):

def __init__(self, path: str):
super().__init__(path)
# TODO: should raise FileNotFoundError if file is not tracked
raise NotImplementedError()
28 changes: 28 additions & 0 deletions src/sio3pack/files/local_file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from sio3pack.files.file import File


Expand All @@ -6,5 +8,31 @@ class LocalFile(File):
Base class for all files in a package that are stored locally.
"""

@classmethod
def get_file_matching_extension(cls, dir: str, filename: str, extensions: list[str]) -> "LocalFile":
"""
Get the file with the given filename and one of the given extensions.
:param dir: The directory to search in.
:param filename: The filename.
:param extensions: The extensions.
:return: The file object.
"""
for ext in extensions:
path = os.path.join(dir, filename + ext)
if os.path.exists(path):
return cls(path)
raise FileNotFoundError

def __init__(self, path: str):
if not os.path.exists(path):
raise FileNotFoundError
super().__init__(path)
self.filename = os.path.basename(path)

def read(self) -> str:
with open(self.path, "r") as f:
return f.read()

def write(self, text: str):
with open(self.path, "w") as f:
f.write(text)
4 changes: 3 additions & 1 deletion src/sio3pack/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from sio3pack.graph import Graph
from sio3pack.graph.graph import Graph
from sio3pack.graph.graph_manager import GraphManager
from sio3pack.graph.graph_op import GraphOperation
4 changes: 2 additions & 2 deletions src/sio3pack/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class Graph:

@classmethod
def from_dict(cls, data: dict):
raise NotImplemented
raise NotImplementedError()

def __init__(self, name: str):
self.name = name
Expand All @@ -14,4 +14,4 @@ def get_prog_files(self) -> list[str]:
"""
Get all program files in the graph.
"""
raise NotImplemented
raise NotImplementedError()
2 changes: 1 addition & 1 deletion src/sio3pack/graph/graph_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from sio3pack import File
from sio3pack.files import File
from sio3pack.graph.graph import Graph


Expand Down
4 changes: 3 additions & 1 deletion src/sio3pack/packages/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
class UnknownPackageType(Exception):
pass
def __init__(self, path: str) -> None:
self.path = path
super().__init__(f"Unknown package type for file {path}.")
20 changes: 11 additions & 9 deletions src/sio3pack/packages/package/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any

from sio3pack.files.file import File
from sio3pack.graph.graph import Graph
from sio3pack import LocalFile
from sio3pack.files import File
from sio3pack.graph import Graph
from sio3pack.packages.exceptions import UnknownPackageType
from sio3pack.test.test import Test
from sio3pack.test import Test
from sio3pack.utils.archive import Archive
from sio3pack.utils.classinit import RegisteredSubclassesBase

Expand All @@ -18,17 +19,18 @@ class Package(RegisteredSubclassesBase):
def __init__(self, file: File):
super().__init__()
self.file = file
if Archive.is_archive(file.path):
self.is_archive = True
else:
self.is_archive = False
if isinstance(file, LocalFile):
if Archive.is_archive(file.path):
self.is_archive = True
else:
self.is_archive = False

@classmethod
def from_file(cls, file: File, django_settings=None):
def from_file(cls, file: LocalFile, django_settings=None):
for subclass in cls.subclasses:
if subclass.identify(file):
return subclass(file, django_settings)
raise UnknownPackageType
raise UnknownPackageType(file.path)

def get_task_id(self) -> str:
pass
Expand Down
57 changes: 31 additions & 26 deletions src/sio3pack/packages/sinolpack/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import yaml

from sio3pack.files.file import File
from sio3pack.graph.graph import Graph
from sio3pack.graph.graph_manager import GraphManager
from sio3pack.graph.graph_op import GraphOperation
from sio3pack import LocalFile
from sio3pack.files import File
from sio3pack.graph import Graph
from sio3pack.graph import GraphManager
from sio3pack.graph import GraphOperation
from sio3pack.packages.package import Package
from sio3pack.packages.sinolpack.enums import ModelSolutionKind
from sio3pack.util import naturalsort_key
Expand Down Expand Up @@ -37,7 +38,7 @@ def _find_main_dir(cls, archive: Archive) -> str | None:
return None

@classmethod
def identify(cls, file: File) -> bool:
def identify(cls, file: LocalFile) -> bool:
"""
Identifies whether file is a Sinolpack.
Expand All @@ -57,21 +58,25 @@ def __del__(self):

def __init__(self, file: File, django_settings=None):
super().__init__(file)
if self.is_archive:
archive = Archive(file.path)
self.short_name = self._find_main_dir(archive)
self.tmpdir = tempfile.TemporaryDirectory()
archive.extract(to_path=self.tmpdir.name)
self.rootdir = os.path.join(self.tmpdir.name, self.short_name)
else:
self.short_name = os.path.basename(file.path)
self.rootdir = file.path

try:
graph_file = self.get_in_root("graph.json")
self.graph_manager = GraphManager.from_file(graph_file)
except FileNotFoundError:
self.has_custom_graph = False
if isinstance(file, LocalFile):
if self.is_archive:
archive = Archive(file.path)
self.short_name = self._find_main_dir(archive)
self.tmpdir = tempfile.TemporaryDirectory()
archive.extract(to_path=self.tmpdir.name)
self.rootdir = os.path.join(self.tmpdir.name, self.short_name)
else:
self.short_name = os.path.basename(file.path)
self.rootdir = file.path

try:
graph_file = self.get_in_root("graph.json")
self.graph_manager = GraphManager.from_file(graph_file)
except FileNotFoundError:
self.has_custom_graph = False
else:
raise NotImplementedError()

self.django_settings = django_settings

Expand All @@ -98,29 +103,29 @@ def get_doc_dir(self) -> str:
"""
return os.path.join(self.rootdir, "doc")

def get_in_doc_dir(self, filename: str) -> File:
def get_in_doc_dir(self, filename: str) -> LocalFile:
"""
Returns the path to the input file in the documents' directory.
"""
return File(os.path.join(self.get_doc_dir(), filename))
return LocalFile(os.path.join(self.get_doc_dir(), filename))

def get_in_root(self, filename: str) -> File:
def get_in_root(self, filename: str) -> LocalFile:
"""
Returns the path to the input file in the root directory.
"""
return File(os.path.join(self.rootdir, filename))
return LocalFile(os.path.join(self.rootdir, filename))

def get_prog_dir(self) -> str:
"""
Returns the path to the directory containing the problem's program files.
"""
return os.path.join(self.rootdir, "prog")

def get_in_prog_dir(self, filename: str) -> File:
def get_in_prog_dir(self, filename: str) -> LocalFile:
"""
Returns the path to the input file in the program directory.
"""
return File(os.path.join(self.get_prog_dir(), filename))
return LocalFile(os.path.join(self.get_prog_dir(), filename))

def get_attachments_dir(self) -> str:
"""
Expand Down Expand Up @@ -247,7 +252,7 @@ def _process_prog_files(self):
for file in ("ingen", "inwer", "soc", "chk"):
try:
self.additional_files.append(
File.get_file_matching_extension(
LocalFile.get_file_matching_extension(
self.get_prog_dir(), self.short_name + file, extensions
).filename
)
Expand Down
2 changes: 1 addition & 1 deletion src/sio3pack/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from sio3pack.test import Test
from sio3pack.test.test import Test
4 changes: 2 additions & 2 deletions src/sio3pack/test/simple_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sio3pack.files.file import File
from sio3pack.test.test import Test
from sio3pack.files import File
from sio3pack.test import Test


class SimpleTest(Test):
Expand Down
11 changes: 10 additions & 1 deletion src/sio3pack/utils/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __init__(self, file, ext=""):
self.filename = file
self._archive = self._archive_cls(self.filename, ext=ext)(self.filename)

def __str__(self):
return f'<Archive({self._archive.__class__.__name__}) {self.filename}>'

@staticmethod
def _archive_cls(file, ext=""):
"""
Expand Down Expand Up @@ -220,7 +223,13 @@ def filenames(self):
return [zipinfo.filename for zipinfo in self._archive.infolist() if not zipinfo.is_dir()]

def dirnames(self):
return [zipinfo.filename for zipinfo in self._archive.infolist() if zipinfo.is_dir()]
dirs = set()
for zipinfo in self._archive.infolist():
if zipinfo.is_dir():
dirs.add(zipinfo.filename)
else:
dirs.add(os.path.dirname(zipinfo.filename))
return list(dirs)


extension_map = {
Expand Down
Loading

0 comments on commit 8e92a50

Please sign in to comment.