Skip to content

Commit

Permalink
api cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 15, 2024
1 parent 7182c49 commit 8728317
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 55 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ authors = [{ email = "[email protected]", name = "Allen Goodman" }]
dependencies = [
"pooch",
"torch",
"tqdm",
]
dynamic = ["version"]
license = { file = "LICENSE" }
Expand Down
39 changes: 18 additions & 21 deletions src/beignet/datasets/__uni_ref_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def __init__(
root: str | PathLike | None = None,
known_hash: str | None = None,
*,
index: bool = True,
transform: Callable | Transform | None = None,
target_transform: Callable | Transform | None = None,
) -> None:
Expand All @@ -34,10 +33,6 @@ def __init__(
`download` is `True`, the directory where the dataset subdirectory
will be created and the dataset downloaded.
index : bool, optional
If `True`, caches the sequence indexes to disk for faster
re-initialization (default: `True`).
transform : Callable | Transform, optional
A `Callable` or `Transform` that maps a sequence to a
transformed sequence (default: `None`).
Expand All @@ -56,32 +51,34 @@ def __init__(

name = self.__class__.__name__.replace("Dataset", "")

path = pooch.retrieve(
url,
known_hash,
f"{name}.fasta.gz",
root / name,
processor=Decompress(),
progressbar=True,
)

self._pattern = re.compile(r"^UniRef.+_([A-Z0-9]+)\s.+$")

super().__init__(path, index=index)
super().__init__(
pooch.retrieve(
url,
known_hash,
f"{name}.fasta.gz",
root / name,
processor=Decompress(
name=f"{name}.fasta",
),
progressbar=True,
),
)

self._transform = transform
self.transform = transform

self._target_transform = target_transform
self.target_transform = target_transform

def __getitem__(self, index: int) -> (str, str):
target, sequence = self.get(index)

(target,) = re.search(self._pattern, target).groups()

if self._transform:
sequence = self._transform(sequence)
if self.transform:
sequence = self.transform(sequence)

if self._target_transform:
target = self._target_transform(target)
if self.target_transform:
target = self.target_transform(target)

return sequence, target
42 changes: 23 additions & 19 deletions src/beignet/datasets/_fasta_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import subprocess
from os import PathLike
from pathlib import Path
from typing import Callable, Tuple, TypeVar

import numpy

from beignet.io import ThreadSafeFile

from ..transforms import Transform
from ._sized_sequence_dataset import SizedSequenceDataset

T = TypeVar("T")
Expand All @@ -14,52 +16,54 @@
class FASTADataset(SizedSequenceDataset):
def __init__(
self,
root: str | Path,
root: str | PathLike,
*,
index: bool = True,
transform: Callable[[T], T] | None = None,
transform: Callable | Transform | None = None,
) -> None:
self.root = Path(root)
if isinstance(root, str):
self.root = Path(root)

self.root = self.root.resolve()

if not self.root.exists():
raise FileNotFoundError

self._thread_safe_file = ThreadSafeFile(root, open)

self._index = Path(f"{self.root}.index.npy")
self.data = ThreadSafeFile(self.root, open)

if index:
if self._index.exists():
self.offsets, sizes = numpy.load(str(self._index))
else:
self.offsets, sizes = self._build_index()
offsets = Path(f"{self.root}.index.npy")

numpy.save(str(self._index), numpy.stack([self.offsets, sizes]))
if offsets.exists():
self.offsets, sizes = numpy.load(f"{offsets}")
else:
self.offsets, sizes = self._build_index()

self._transform = transform
numpy.save(
f"{offsets}",
numpy.stack([self.offsets, sizes]),
)

self.transform = transform

super().__init__(self.root, sizes)

def __getitem__(self, index: int) -> Tuple[str, str]:
x = self.get(index)

if self._transform:
x = self._transform(x)
if self.transform:
x = self.transform(x)

return x

def __len__(self) -> int:
return self.offsets.size

def get(self, index: int) -> str:
self._thread_safe_file.seek(self.offsets[index])
self.data.seek(self.offsets[index])

if index == len(self) - 1:
data = self._thread_safe_file.read()
data = self.data.read()
else:
data = self._thread_safe_file.read(
data = self.data.read(
self.offsets[index + 1] - self.offsets[index],
)

Expand Down
4 changes: 2 additions & 2 deletions src/beignet/datasets/_sized_sequence_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pathlib import Path
from os import PathLike

import numpy

Expand All @@ -8,7 +8,7 @@
class SizedSequenceDataset(SequenceDataset):
def __init__(
self,
root: str | Path,
root: str | PathLike,
sizes: numpy.ndarray,
*args,
**kwargs,
Expand Down
6 changes: 0 additions & 6 deletions src/beignet/datasets/_uniref100_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def __init__(
self,
root: str | Path,
*,
index: bool = True,
transform: Callable | Transform | None = None,
target_transform: Callable | Transform | None = None,
) -> None:
Expand All @@ -22,10 +21,6 @@ def __init__(
`download` is `True`, the directory where the dataset subdirectory
will be created and the dataset downloaded.
index : bool, optional
If `True`, caches the sequence indices to disk for faster
re-initialization (default: `True`).
transform : Callable, optional
A `Callable` or `Transform` that maps a sequence to a
transformed sequence (default: `None`).
Expand All @@ -38,7 +33,6 @@ def __init__(
"http://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref100/uniref100.fasta.gz",
root,
"md5:0354240a56f4ca91ff426f8241cfeb7d",
index=index,
transform=transform,
target_transform=target_transform,
)
6 changes: 0 additions & 6 deletions src/beignet/datasets/_uniref50_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def __init__(
self,
root: str | PathLike | None = None,
*,
index: bool = True,
transform: Callable | Transform | None = None,
target_transform: Callable | Transform | None = None,
) -> None:
Expand All @@ -23,10 +22,6 @@ def __init__(
`download` is `True`, the directory where the dataset subdirectory
will be created and the dataset downloaded.
index : bool, optional
If `True`, caches the sequence indexes to disk for faster
re-initialization (default: `True`).
transform : Callable, optional
A `Callable` or `Transform` that maps a sequence to a
transformed sequence (default: `None`).
Expand All @@ -39,7 +34,6 @@ def __init__(
"http://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref50/uniref50.fasta.gz",
root,
"md5:e638c63230d13ad5e2098115b9cb5d8f",
index=index,
transform=transform,
target_transform=target_transform,
)
1 change: 0 additions & 1 deletion src/beignet/datasets/_uniref90_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def __init__(
"http://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/uniref90.fasta.gz",
root,
"md5:6161bad4d7506365aee882fd5ff9c833",
index=index,
transform=transform,
target_transform=target_transform,
)

0 comments on commit 8728317

Please sign in to comment.