-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from Forest-Recovery-Digital-Companion/0.0.3
0.0.3
- Loading branch information
Showing
51 changed files
with
3,512 additions
and
987 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# See https://pre-commit.com for more information | ||
# See https://pre-commit.com/hooks.html for more hooks | ||
repos: | ||
- repo: https://github.com/psf/black | ||
rev: 23.10.0 | ||
hooks: | ||
- id: black | ||
- repo: https://github.com/PyCQA/flake8 | ||
rev: 6.1.0 | ||
hooks: | ||
- id: flake8 | ||
args: [ --max-line-length=79 ] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Model Tests | ||
|
||
This directory contains full tests model architectures. | ||
|
||
## `chestnut_dec_may` | ||
|
||
This test is the classic FRDC tests used in research papers. | ||
It uses December's data to train and May's data to test. | ||
|
||
The current baseline is 40% accuracy. | ||
|
||
### Confusion Matrix | ||
|
||
![chestnut_dec_may](chestnut_dec_may/confusion_matrix.png) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import torch | ||
from torchvision.transforms.v2 import RandomHorizontalFlip, RandomVerticalFlip | ||
|
||
|
||
def augmentation(t: torch.Tensor) -> torch.Tensor: | ||
"""Runs out augmentation on a tensor.""" | ||
t = RandomHorizontalFlip()(t) | ||
t = RandomVerticalFlip()(t) | ||
return t |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import lightning as pl | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import torch | ||
from seaborn import heatmap | ||
from sklearn.metrics import confusion_matrix | ||
|
||
from frdc.train import FRDCDataModule | ||
from frdc.train import FRDCModule | ||
from pipeline.model_tests.chestnut_dec_may.preprocess import preprocess | ||
from pipeline.model_tests.utils import get_dataset | ||
|
||
# Get our Test | ||
# TODO: Ideally, we should have a separate dataset for testing. | ||
segments, labels = get_dataset( | ||
"chestnut_nature_park", "20210510", "90deg43m85pct255deg/map" | ||
) | ||
|
||
# Prepare the datamodule and trainer | ||
dm = FRDCDataModule(segments=segments, preprocess=preprocess, batch_size=5) | ||
|
||
# TODO: Hacky way to load our LabelEncoder | ||
dm.le.classes_ = np.load("le.npy", allow_pickle=True) | ||
|
||
# Load the model | ||
m = FRDCModule.load_from_checkpoint( | ||
"lightning_logs/version_88/checkpoints/epoch=99-step=700.ckpt" | ||
) | ||
|
||
# Make predictions | ||
trainer = pl.Trainer(logger=False) | ||
pred = trainer.predict(m, datamodule=dm) | ||
y_pred = torch.concat(pred, dim=0).argmax(dim=1) | ||
y_true = dm.le.transform(labels) | ||
|
||
# Plot the confusion matrix | ||
cm = confusion_matrix(y_true, y_pred) | ||
|
||
plt.figure(figsize=(10, 10)) | ||
|
||
heatmap( | ||
cm, | ||
annot=True, | ||
xticklabels=dm.le.classes_, | ||
yticklabels=dm.le.classes_, | ||
cbar=False, | ||
) | ||
|
||
plt.tight_layout(pad=3) | ||
plt.title("Confusion Matrix") | ||
plt.xlabel("Predicted Label") | ||
plt.ylabel("True Label") | ||
plt.savefig("confusion_matrix.png") |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" Tests for the FaceNet model. | ||
This test is done by training a model on the 20201218 dataset, then testing on | ||
the 20210510 dataset. | ||
""" | ||
|
||
import lightning as pl | ||
import numpy as np | ||
import torch | ||
from lightning.pytorch.callbacks import ( | ||
LearningRateMonitor, | ||
ModelCheckpoint, | ||
EarlyStopping, | ||
) | ||
from torch.utils.data import TensorDataset, Dataset, Subset | ||
|
||
from frdc.models import FaceNet | ||
from frdc.train import FRDCDataModule, FRDCModule | ||
from pipeline.model_tests.chestnut_dec_may.augmentation import augmentation | ||
from pipeline.model_tests.chestnut_dec_may.preprocess import preprocess | ||
from pipeline.model_tests.utils import get_dataset | ||
|
||
|
||
def train_val_test_split(x: TensorDataset) -> list[Dataset, Dataset, Dataset]: | ||
# Defines how to split the dataset into train, val, test subsets. | ||
# TODO: Quite ugly as it uses the global variables segments_0 and | ||
# segments_1. Will need to refactor this. | ||
return [ | ||
Subset(x, list(range(len(segments_0)))), | ||
Subset( | ||
x, list(range(len(segments_0), len(segments_0) + len(segments_1))) | ||
), | ||
[], | ||
] | ||
|
||
|
||
# Prepare the dataset | ||
segments_0, labels_0 = get_dataset("chestnut_nature_park", "20201218", None) | ||
segments_1, labels_1 = get_dataset( | ||
"chestnut_nature_park", "20210510", "90deg43m85pct255deg/map" | ||
) | ||
|
||
# Concatenate the datasets | ||
segments = [*segments_0, *segments_1] | ||
labels = [*labels_0, *labels_1] | ||
|
||
BATCH_SIZE = 5 | ||
EPOCHS = 100 | ||
LR = 1e-3 | ||
|
||
# Prepare the datamodule and trainer | ||
dm = FRDCDataModule( | ||
# Input to the model | ||
segments=segments, | ||
# Output of the model | ||
labels=labels, | ||
# Preprocessing function | ||
preprocess=preprocess, | ||
# Augmentation function (Only on train) | ||
augmentation=augmentation, | ||
# Splitting function | ||
train_val_test_split=train_val_test_split, | ||
# Batch size | ||
batch_size=BATCH_SIZE, | ||
) | ||
|
||
trainer = pl.Trainer( | ||
max_epochs=EPOCHS, | ||
# Set the seed for reproducibility | ||
# TODO: Though this is set, the results are still not reproducible. | ||
deterministic=True, | ||
log_every_n_steps=4, | ||
callbacks=[ | ||
# Stop training if the validation loss doesn't improve for 4 epochs | ||
EarlyStopping(monitor="val_loss", patience=4, mode="min"), | ||
# Log the learning rate on TensorBoard | ||
LearningRateMonitor(logging_interval="epoch"), | ||
# Save the best model | ||
ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1), | ||
], | ||
) | ||
|
||
m = FRDCModule( | ||
# Our model is the "FaceNet" model | ||
# TODO: It's not really the FaceNet model, but a modified version of it. | ||
model_cls=FaceNet, | ||
model_kwargs=dict(n_out_classes=len(set(labels))), | ||
# We use the Adam optimizer | ||
optim_cls=torch.optim.Adam, | ||
# TODO: This is not fine-tuned. | ||
optim_kwargs=dict(lr=LR, weight_decay=1e-4, amsgrad=True), | ||
) | ||
|
||
trainer.fit(m, datamodule=dm) | ||
# TODO: Quite hacky, but we need to save the label encoder for prediction. | ||
np.save("le.npy", dm.le.classes_) |
Oops, something went wrong.