Skip to content

Commit

Permalink
Merge pull request #23 from mlexchange/unit_tests
Browse files Browse the repository at this point in the history
Add unit test and refactor TiledDataset
  • Loading branch information
TibbersHao authored Apr 15, 2024
2 parents f459c96 + b50ec92 commit e1d57e6
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ jobs:
- name: Test formatting with black
run: |
black . --check
- name: pytest
run: |
pytest
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

# vscode
.vscode/
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
[tool.isort]
profile = "black"

[tool.pytest.ini_options]
pythonpath = [
"src"
]
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ black==24.3.0
flake8==7.0.0
isort==5.13.2
pre-commit==3.6.2
tiled[all]==0.1.0a114
pytest
Empty file added src/__init__.py
Empty file.
Empty file added src/_tests/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions src/_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import pytest
from tiled.catalog import from_uri
from tiled.client import Context, from_context
from tiled.server.app import build_app


@pytest.fixture
def catalog(tmpdir):
    """Yield a Tiled catalog adapter whose database and storage live in ``tmpdir``.

    The catalog is a SQLite database file named ``catalog.db`` inside the
    per-test temporary directory, and the same directory is handed over as
    the writable storage location.
    """
    database_uri = f"sqlite+aiosqlite:///{tmpdir}/catalog.db"
    storage_dir = str(tmpdir)
    yield from_uri(
        database_uri,
        writable_storage=storage_dir,
        init_if_not_exists=True,
    )


@pytest.fixture
def app(catalog):
    """Yield the Tiled server application built on top of the ``catalog`` fixture."""
    server_app = build_app(catalog)
    yield server_app


@pytest.fixture
def context(app):
    """Yield a Tiled client ``Context`` created from ``app``.

    The context manager is left only after the test that uses the fixture
    has finished, which is why this fixture yields inside the ``with`` block.
    """
    with Context.from_app(app) as app_context:
        yield app_context


@pytest.fixture
def client(context):
    """Fixture for tests which only read data.

    Before the client is yielded, two containers are created and filled:

    * ``reconstructions`` holds ``recon1``, an all-zero ``int8`` array of
      shape ``(2, 3, 3)``.
    * ``uid0001`` carries the metadata ``{"mask_idx": ["0"]}`` and holds
      ``mask``, an all-zero ``int8`` array of shape ``(1, 3, 3)``.
    """
    tiled_client = from_context(context)

    # Image stack that the tests read back through TiledDataset.
    recon_stack = np.zeros((2, 3, 3), dtype=np.int8)
    tiled_client.create_container("reconstructions").write_array(
        recon_stack, key="recon1"
    )

    # Mask container: "mask_idx" metadata on the container, array under "mask".
    mask_stack = np.zeros((1, 3, 3), dtype=np.int8)
    tiled_client.create_container(
        "uid0001", metadata={"mask_idx": ["0"]}
    ).write_array(mask_stack, key="mask")

    yield tiled_client
16 changes: 16 additions & 0 deletions src/_tests/test_tiled_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from ..tiled_dataset import TiledDataset


def test_tiled_dataset(client):
    """A dataset built from the data client alone is truthy and yields (3, 3) items."""
    recon_client = client["reconstructions"]["recon1"]
    dataset = TiledDataset(recon_client)
    assert dataset
    first_item = dataset[0]
    assert first_item.shape == (3, 3)


def test_tiled_dataset_with_masks(client):
    """A dataset that is also given the mask container still yields (3, 3) items."""
    recon_client = client["reconstructions"]["recon1"]
    mask_container = client["uid0001"]
    dataset = TiledDataset(recon_client, mask_tiled_client=mask_container)
    first_item = dataset[0]
    assert first_item.shape == (3, 3)
15 changes: 11 additions & 4 deletions src/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import yaml
from qlty.qlty2D import NCYXQuilt
from tiled.client import from_uri
from torchvision import transforms

from network import baggin_smsnet_ensemble, load_network
Expand Down Expand Up @@ -52,11 +53,17 @@

print("Parameters loaded successfully.")

data_tiled_client = from_uri(
io_parameters.data_tiled_uri, api_key=io_parameters.data_tiled_api_key
)
mask_tiled_client = None
if io_parameters.mask_tiled_uri:
mask_tiled_client = from_uri(
io_parameters.mask_tiled_uri, api_key=io_parameters.mask_tiled_api_key
)
dataset = TiledDataset(
data_tiled_uri=io_parameters.data_tiled_uri,
data_tiled_api_key=io_parameters.data_tiled_api_key,
mask_tiled_uri=io_parameters.mask_tiled_uri,
mask_tiled_api_key=io_parameters.mask_tiled_api_key,
data_tiled_client,
mask_tiled_client=mask_tiled_client,
is_training=False,
using_qlty=False,
qlty_window=model_parameters.qlty_window,
Expand Down
26 changes: 10 additions & 16 deletions src/tiled_dataset.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import torch
from qlty import cleanup
from qlty.qlty2D import NCYXQuilt
from tiled.client import from_uri


class TiledDataset(torch.utils.data.Dataset):

def __init__(
self,
data_tiled_uri,
data_tiled_api_key=None,
mask_tiled_uri=None,
mask_tiled_api_key=None,
data_tiled_client,
mask_tiled_client=None,
is_training=None,
using_qlty=False,
qlty_window=50,
Expand All @@ -33,20 +31,16 @@ def __init__(
Return:
ml_data: tuple, (data_tensor, mask_tensor)
"""
self.data_tiled_uri = data_tiled_uri
self.data_client = from_uri(data_tiled_uri, api_key=data_tiled_api_key)
self.mask_tiled_uri = mask_tiled_uri
if mask_tiled_uri:
self.mask_client_one_up = from_uri(
mask_tiled_uri, api_key=mask_tiled_api_key
)
self.mask_client = self.mask_client_one_up["mask"]
self.mask_idx = [
int(idx) for idx in self.mask_client_one_up.metadata["mask_idx"]
]

self.data_client = data_tiled_client
self.mask_client = None
if mask_tiled_client:
self.mask_client = mask_tiled_client["mask"]
self.mask_idx = [int(idx) for idx in mask_tiled_client.metadata["mask_idx"]]
else:
self.mask_client = None
self.mask_idx = None

self.transform = transform
if using_qlty:
# this object handles unstitching and stitching
Expand Down
29 changes: 20 additions & 9 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import yaml
from dlsia.core.train_scripts import Trainer
from dvclive import Live
from tiled.client import from_uri
from torchvision import transforms

from network import build_network
Expand All @@ -17,15 +18,12 @@
TUNet3PlusParameters,
TUNetParameters,
)
from seg_utils import crop_split_load, train_segmentation
from seg_utils import crop_split_load
from tiled_dataset import TiledDataset
from utils import create_directory

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("yaml_path", type=str, help="path of yaml file for parameters")
args = parser.parse_args()

def train(args):
# Open the YAML file for all parameters
with open(args.yaml_path, "r") as file:
# Load parameters
Expand Down Expand Up @@ -59,11 +57,17 @@
# Create Result Directory if not existed
create_directory(model_dir)

data_tiled_client = from_uri(
io_parameters.data_tiled_uri, api_key=io_parameters.data_tiled_api_key
)
mask_tiled_client = None
if io_parameters.mask_tiled_uri:
mask_tiled_client = from_uri(
io_parameters.mask_tiled_uri, api_key=io_parameters.mask_tiled_api_key
)
dataset = TiledDataset(
data_tiled_uri=io_parameters.data_tiled_uri,
data_tiled_api_key=io_parameters.data_tiled_api_key,
mask_tiled_uri=io_parameters.mask_tiled_uri,
mask_tiled_api_key=io_parameters.mask_tiled_api_key,
data_tiled_client=data_tiled_client,
mask_tiled_client=mask_tiled_client,
is_training=True,
using_qlty=False,
qlty_window=model_parameters.qlty_window,
Expand Down Expand Up @@ -146,3 +150,10 @@
torch.cuda.empty_cache()

print(f"{network} trained successfully.")


if __name__ == "__main__":
    # Script entry point: read the YAML parameter file path from the command
    # line and hand the parsed arguments to train().
    cli = argparse.ArgumentParser()
    cli.add_argument("yaml_path", type=str, help="path of yaml file for parameters")
    train(cli.parse_args())
8 changes: 6 additions & 2 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,13 @@ def allocate_array_space(
# For now, only save image 1 by 1 regardless of the batch_size_inference.
structure.chunks = ((1,) * array_shape[0], (array_shape[1],), (array_shape[2],))

mask_uri = None
if tiled_dataset.mask_client is not None:
mask_uri = tiled_dataset.mask_client.uri

metadata = {
"data_uri": tiled_dataset.data_tiled_uri,
"mask_uri": tiled_dataset.mask_tiled_uri,
"data_uri": tiled_dataset.data_client.uri,
"mask_uri": mask_uri,
"mask_idx": tiled_dataset.mask_idx,
"uid": uid,
"model": model,
Expand Down

0 comments on commit e1d57e6

Please sign in to comment.