diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8832766d..c5d3676f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -15,7 +15,7 @@
     f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
 )
 author = "Adrien Lafage and Olivier Laurent"
-release = "0.2.2.post1"
+release = "0.2.2.post2"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/pyproject.toml b/pyproject.toml
index ab2fa77f..8ef51b13 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
 
 [project]
 name = "torch_uncertainty"
-version = "0.2.2.post1"
+version = "0.2.2.post2"
 authors = [
     { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" },
     { name = "Adrien Lafage", email = "adrienlafage@outlook.com" },
@@ -18,7 +18,6 @@ keywords = [
     "ensembles",
     "neural-networks",
     "predictive-uncertainty",
-    "pytorch",
     "reliable-ai",
     "trustworthy-machine-learning",
     "uncertainty",
@@ -44,6 +43,7 @@ dependencies = [
     "numpy<2",
     "opencv-python",
     "glest==0.0.1a0",
+    "rich>=10.2.2",
 ]
 
 [project.optional-dependencies]
diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py
index 89180b4e..92d299e8 100644
--- a/torch_uncertainty/post_processing/laplace.py
+++ b/torch_uncertainty/post_processing/laplace.py
@@ -64,6 +64,7 @@ def __init__(
         self.weight_subset = weight_subset
         self.hessian_struct = hessian_struct
         self.batch_size = batch_size
+        self.optimize_prior_precision = optimize_prior_precision
 
         if model is not None:
             self.set_model(model)
@@ -80,7 +81,8 @@ def set_model(self, model: nn.Module) -> None:
     def fit(self, dataset: Dataset) -> None:
         dl = DataLoader(dataset, batch_size=self.batch_size)
         self.la.fit(train_loader=dl)
-        self.la.optimize_prior_precision(method="marglik")
+        if self.optimize_prior_precision:
+            self.la.optimize_prior_precision(method="marglik")
 
     def forward(
         self,
diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py
index c6e92755..c2ae660c 100644
--- a/torch_uncertainty/routines/classification.py
+++ b/torch_uncertainty/routines/classification.py
@@ -196,15 +196,15 @@ def _init_metrics(self) -> None:
                 ),
                 "sc/AURC": AURC(),
                 "sc/AUGRC": AUGRC(),
-                "sc/CovAt5Risk": CovAt5Risk(),
-                "sc/RiskAt80Cov": RiskAt80Cov(),
+                "sc/Cov@5Risk": CovAt5Risk(),
+                "sc/Risk@80Cov": RiskAt80Cov(),
             },
             compute_groups=[
                 ["cls/Acc"],
                 ["cls/Brier"],
                 ["cls/NLL"],
                 ["cal/ECE", "cal/aECE"],
-                ["sc/AURC", "sc/AUGRC", "sc/CovAt5Risk", "sc/RiskAt80Cov"],
+                ["sc/AURC", "sc/AUGRC", "sc/Cov@5Risk", "sc/Risk@80Cov"],
             ],
         )
 
@@ -212,7 +212,7 @@ def _init_metrics(self) -> None:
         self.test_cls_metrics = cls_metrics.clone(prefix="test/")
 
         if self.post_processing is not None:
-            self.ts_cls_metrics = cls_metrics.clone(prefix="test/ts_")
+            self.post_cls_metrics = cls_metrics.clone(prefix="test/post/")
 
         self.test_id_entropy = Entropy()
 
@@ -463,7 +463,7 @@ def test_step(
             )
             self.test_id_entropy(probs)
             self.log(
-                "test/cls/entropy",
+                "test/cls/Entropy",
                 self.test_id_entropy,
                 on_epoch=True,
                 add_dataloader_idx=False,
@@ -486,7 +486,7 @@ def test_step(
                    pp_probs = F.softmax(pp_logits, dim=-1)
                 else:
                     pp_probs = pp_logits
-                self.ts_cls_metrics.update(pp_probs, targets)
+                self.post_cls_metrics.update(pp_probs, targets)
         elif self.eval_ood and dataloader_idx == 1:
             self.test_ood_metrics.update(ood_scores, torch.ones_like(targets))
 
@@ -529,7 +529,7 @@ def on_test_epoch_end(self) -> None:
             )
 
             if self.post_processing is not None:
-                tmp_metrics = self.ts_cls_metrics.compute()
+                tmp_metrics = self.post_cls_metrics.compute()
                 self.log_dict(tmp_metrics, sync_dist=True)
                 result_dict.update(tmp_metrics)
 
@@ -573,7 +573,7 @@ def on_test_epoch_end(self) -> None:
             if self.post_processing is not None:
                 self.logger.experiment.add_figure(
                     "Reliabity diagram after calibration",
-                    self.ts_cls_metrics["cal/ECE"].plot()[0],
+                    self.post_cls_metrics["cal/ECE"].plot()[0],
                 )
 
         # plot histograms of logits and likelihoods
diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py
index 2c36ce52..bd5cd6eb 100644
--- a/torch_uncertainty/routines/pixel_regression.py
+++ b/torch_uncertainty/routines/pixel_regression.py
@@ -99,17 +99,17 @@ def __init__(
 
         depth_metrics = MetricCollection(
             {
-                "SILog": SILog(),
-                "log10": Log10(),
-                "ARE": MeanGTRelativeAbsoluteError(),
-                "RSRE": MeanGTRelativeSquaredError(squared=False),
-                "RMSE": MeanSquaredError(squared=False),
-                "RMSELog": MeanSquaredLogError(squared=False),
-                "iMAE": MeanAbsoluteErrorInverse(),
-                "iRMSE": MeanSquaredErrorInverse(squared=False),
-                "d1": ThresholdAccuracy(power=1),
-                "d2": ThresholdAccuracy(power=2),
-                "d3": ThresholdAccuracy(power=3),
+                "reg/SILog": SILog(),
+                "reg/log10": Log10(),
+                "reg/ARE": MeanGTRelativeAbsoluteError(),
+                "reg/RSRE": MeanGTRelativeSquaredError(squared=False),
+                "reg/RMSE": MeanSquaredError(squared=False),
+                "reg/RMSELog": MeanSquaredLogError(squared=False),
+                "reg/iMAE": MeanAbsoluteErrorInverse(),
+                "reg/iRMSE": MeanSquaredErrorInverse(squared=False),
+                "reg/d1": ThresholdAccuracy(power=1),
+                "reg/d2": ThresholdAccuracy(power=2),
+                "reg/d3": ThresholdAccuracy(power=3),
             },
             compute_groups=False,
         )
@@ -119,7 +119,7 @@ def __init__(
 
         if self.probabilistic:
             depth_prob_metrics = MetricCollection(
-                {"NLL": DistributionNLL(reduction="mean")}
+                {"reg/NLL": DistributionNLL(reduction="mean")}
             )
             self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/")
             self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/")
diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py
index cccf712d..12538db5 100644
--- a/torch_uncertainty/routines/regression.py
+++ b/torch_uncertainty/routines/regression.py
@@ -84,9 +84,9 @@ def __init__(
 
         reg_metrics = MetricCollection(
             {
-                "MAE": MeanAbsoluteError(),
-                "MSE": MeanSquaredError(squared=True),
-                "RMSE": MeanSquaredError(squared=False),
+                "reg/MAE": MeanAbsoluteError(),
+                "reg/MSE": MeanSquaredError(squared=True),
+                "reg/RMSE": MeanSquaredError(squared=False),
             },
             compute_groups=True,
         )
@@ -96,7 +96,7 @@ def __init__(
 
         if self.probabilistic:
             reg_prob_metrics = MetricCollection(
-                {"NLL": DistributionNLL(reduction="mean")}
+                {"reg/NLL": DistributionNLL(reduction="mean")}
             )
             self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/")
             self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/")
diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py
index ac11c02a..3872ee6d 100644
--- a/torch_uncertainty/utils/evaluation_loop.py
+++ b/torch_uncertainty/utils/evaluation_loop.py
@@ -1,126 +1,201 @@
-import os
-import shutil
-import sys
-from typing import Any
+from collections import OrderedDict
 
-from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop
 from lightning.pytorch.trainer.connectors.logger_connector.result import (
     _OUT_DICT,
 )
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
+from rich import get_console
+from rich.console import Group
+from rich.table import Table
 
 
 class TUEvaluationLoop(_EvaluationLoop):
     @staticmethod
     def _print_results(results: list[_OUT_DICT], stage: str) -> None:
-        # remove the dl idx suffix
-        results = [
-            {k.split("/dataloader_idx_")[0]: v for k, v in result.items()}
-            for result in results
+        # test/cls: Classification Metrics
+        # test/cal: Calibration Metrics
+        # ood: OOD Detection Metrics
+        # test/sc: Selective Classification Metrics
+        # test/post: Post-Processing Metrics
+        # test/seg: Segmentation Metrics
+
+        # In percentage
+        percentage_metrics = [
+            "Acc",
+            "AUPR",
+            "AUROC",
+            "FPR95",
+            "Cov@5Risk",
+            "Risk@80Cov",
+            "pixAcc",
+            "mIoU",
         ]
-        metrics_paths = {
-            k
-            for keys in apply_to_collection(
-                results, dict, _EvaluationLoop._get_keys
-            )
-            for k in keys
-        }
-        if not metrics_paths:
-            return
-
-        metrics_strs = [":".join(metric) for metric in metrics_paths]
-        # sort both lists based on metrics_strs
-        metrics_strs, metrics_paths = zip(
-            *sorted(zip(metrics_strs, metrics_paths, strict=False)),
-            strict=False,
-        )
-
-        if len(results) == 2:
-            headers = ["In-Distribution", "Out-of-Distribution"]
-        else:
-            headers = [f"DataLoader {i}" for i in range(len(results))]
-
-        # fallback is useful for testing of printed output
-        term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120
-        max_length = int(
-            min(
-                max(
-                    len(max(metrics_strs, key=len)),
-                    len(max(headers, key=len)),
-                    25,
-                ),
-                term_size / 2,
-            )
-        )
-
-        rows: list[list[Any]] = [[] for _ in metrics_paths]
+        metrics = {}
         for result in results:
-            for metric, row in zip(metrics_paths, rows, strict=False):
-                val = _EvaluationLoop._find_value(result, metric)
-                if val is not None:
-                    if isinstance(val, Tensor):
-                        val = val.item() if val.numel() == 1 else val.tolist()
-                    row.append(f"{val:.5f}")
+            for key, value in result.items():
+                if key.startswith("test/cls"):
+                    if "cls" not in metrics:
+                        metrics["cls"] = {}
+                    metric_name = key.split("/")[-1]
+                    metrics["cls"].update({metric_name: value})
+                elif key.startswith("test/cal"):
+                    if "cal" not in metrics:
+                        metrics["cal"] = {}
+                    metric_name = key.split("/")[-1]
+                    metrics["cal"].update({metric_name: value})
+                elif key.startswith("ood"):
+                    if "ood" not in metrics:
+                        metrics["ood"] = {}
+                    metric_name = key.split("/")[-1]
+                    metrics["ood"].update({metric_name: value})
+                elif key.startswith("test/sc"):
+                    if "sc" not in metrics:
+                        metrics["sc"] = {}
+                    metric_name = key.split("/")[-1]
+                    metrics["sc"].update({metric_name: value})
+                elif key.startswith("test/post"):
+                    if "post" not in metrics:
+                        metrics["post"] = {}
+                    metric_name = key.split("/")[-1]
+                    metrics["post"].update({metric_name: value})
+                elif key.startswith("test/seg"):
+                    if "seg" not in metrics:
+                        metrics["seg"] = {}
+                    metric_name = key.split("/")[-1]
+                    metrics["seg"].update({metric_name: value})
+                elif key.startswith("test/reg"):
+                    if "reg" not in metrics:
+                        metrics["reg"] = {}
+                    metric_name = key.split("/")[-1]
+                    metrics["reg"].update({metric_name: value})
+
+        tables = []
+
+        first_col_name = f"{stage.capitalize()} metric"
+
+        if "cls" in metrics:
+            table = Table()
+            table.add_column(
+                first_col_name, justify="center", style="cyan", width=12
+            )
+            table.add_column(
+                "Classification", justify="center", style="magenta", width=25
+            )
+            cls_metrics = OrderedDict(sorted(metrics["cls"].items()))
+            for metric, value in cls_metrics.items():
+                if metric in percentage_metrics:
+                    value = value * 100
+                    table.add_row(metric, f"{value.item():.2f}%")
                 else:
-                    row.append(" ")
-
-        # keep one column with max length for metrics
-        num_cols = int((term_size - max_length) / max_length)
+                    table.add_row(metric, f"{value.item():.5f}")
+            tables.append(table)
 
-        for i in range(0, len(headers), num_cols):
-            table_headers = headers[i : (i + num_cols)]
-            table_rows = [row[i : (i + num_cols)] for row in rows]
-
-            table_headers.insert(0, f"{stage} Metric".capitalize())
+        if "seg" in metrics:
+            table = Table()
+            table.add_column(
+                first_col_name, justify="center", style="cyan", width=12
+            )
+            table.add_column(
+                "Segmentation", justify="center", style="magenta", width=25
+            )
+            seg_metrics = OrderedDict(sorted(metrics["seg"].items()))
+            for metric, value in seg_metrics.items():
+                if metric in percentage_metrics:
+                    value = value * 100
+                    table.add_row(metric, f"{value.item():.2f}%")
+                else:
+                    table.add_row(metric, f"{value.item():.5f}")
+            tables.append(table)
 
-            if _RICH_AVAILABLE:
-                from rich import get_console
-                from rich.table import Column, Table
+        if "reg" in metrics:
+            table = Table()
+            table.add_column(
+                first_col_name, justify="center", style="cyan", width=12
+            )
+            table.add_column(
+                "Regression", justify="center", style="magenta", width=25
+            )
+            reg_metrics = OrderedDict(sorted(metrics["reg"].items()))
+            for metric, value in reg_metrics.items():
+                if metric in percentage_metrics:  # coverage: ignore
+                    value = value * 100
+                    table.add_row(metric, f"{value.item():.2f}%")
+                else:
+                    table.add_row(metric, f"{value.item():.5f}")
+            tables.append(table)
 
-                columns = [
-                    Column(
-                        h, justify="center", style="magenta", width=max_length
-                    )
-                    for h in table_headers
-                ]
-                columns[0].style = "cyan"
+        if "cal" in metrics:
+            table = Table()
+            table.add_column(
+                first_col_name, justify="center", style="cyan", width=12
+            )
+            table.add_column(
+                "Calibration", justify="center", style="magenta", width=25
+            )
+            cal_metrics = OrderedDict(sorted(metrics["cal"].items()))
+            for metric, value in cal_metrics.items():
+                if metric in percentage_metrics:
+                    value = value * 100
+                    table.add_row(metric, f"{value.item():.2f}%")
+                else:
+                    table.add_row(metric, f"{value.item():.5f}")
+            tables.append(table)
 
-                table = Table(*columns)
-                for metric, row in zip(metrics_strs, table_rows, strict=False):
-                    row.insert(0, metric)
-                    table.add_row(*row)
+        if "ood" in metrics:
+            table = Table()
+            table.add_column(
+                first_col_name, justify="center", style="cyan", width=12
+            )
+            table.add_column(
+                "OOD Detection", justify="center", style="magenta", width=25
+            )
+            ood_metrics = OrderedDict(sorted(metrics["ood"].items()))
+            for metric, value in ood_metrics.items():
+                if metric in percentage_metrics:
+                    value = value * 100
+                    table.add_row(metric, f"{value.item():.2f}%")
+                else:
+                    table.add_row(metric, f"{value.item():.5f}")
+            tables.append(table)
 
-                console = get_console()
-                console.print(table)
-            else:  # coverage: ignore
-                row_format = f"{{:^{max_length}}}" * len(table_headers)
-                half_term_size = int(term_size / 2)
+        if "sc" in metrics:
+            table = Table()
+            table.add_column(
+                first_col_name, justify="center", style="cyan", width=12
+            )
+            table.add_column(
+                "Selective Classification",
+                justify="center",
+                style="magenta",
+                width=25,
+            )
+            sc_metrics = OrderedDict(sorted(metrics["sc"].items()))
+            for metric, value in sc_metrics.items():
+                if metric in percentage_metrics:
+                    value = value * 100
+                    table.add_row(metric, f"{value.item():.2f}%")
+                else:
+                    table.add_row(metric, f"{value.item():.5f}")
+            tables.append(table)
 
-                try:
-                    # some terminals do not support this character
-                    if sys.stdout.encoding is not None:
-                        "─".encode(sys.stdout.encoding)
-                except UnicodeEncodeError:
-                    bar_character = "-"
+        if "post" in metrics:
+            table = Table()
+            table.add_column(
+                first_col_name, justify="center", style="cyan", width=12
+            )
+            table.add_column(
+                "Post-Processing", justify="center", style="magenta", width=25
+            )
+            post_metrics = OrderedDict(sorted(metrics["post"].items()))
+            for metric, value in post_metrics.items():
+                if metric in percentage_metrics:
+                    value = value * 100
+                    table.add_row(metric, f"{value.item():.2f}%")
                 else:
-                    bar_character = "─"
-                bar = bar_character * term_size
+                    table.add_row(metric, f"{value.item():.5f}")
+            tables.append(table)
 
-                lines = [bar, row_format.format(*table_headers).rstrip(), bar]
-                for metric, row in zip(metrics_strs, table_rows, strict=False):
-                    # deal with column overflow
-                    if len(metric) > half_term_size:
-                        while len(metric) > half_term_size:
-                            row_metric = metric[:half_term_size]
-                            metric = metric[half_term_size:]
-                            lines.append(
-                                row_format.format(row_metric, *row).rstrip()
-                            )
-                        lines.append(row_format.format(metric, " ").rstrip())
-                    else:
-                        lines.append(row_format.format(metric, *row).rstrip())
-                lines.append(bar)
-                print(os.linesep.join(lines))
+        console = get_console()
+        group = Group(*tables)
+        console.print(group)
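
For readers of this patch, a hedged usage sketch of the new `optimize_prior_precision` switch added to `torch_uncertainty/post_processing/laplace.py`. Only `model`, `batch_size`, `optimize_prior_precision`, and `fit(dataset)` appear in the diff above; the class name `LaplaceApprox`, its import path, and the `task` keyword are assumptions, and the toy model and calibration set are placeholders.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset

# Assumed import path and class name; adjust to the actual wrapper in laplace.py.
from torch_uncertainty.post_processing import LaplaceApprox

model = nn.Linear(8, 3)
calibration_set = TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,)))

# With the new flag set to False, fit() only fits the Laplace approximation and
# skips the marginal-likelihood ("marglik") prior-precision optimization step.
laplace = LaplaceApprox(
    task="classification",  # assumed keyword, not shown in the diff
    model=model,
    batch_size=16,
    optimize_prior_precision=False,
)
laplace.fit(calibration_set)
probs = laplace(torch.randn(4, 8))  # forward() is defined by the wrapper
```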
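And a minimal, self-contained sketch of the table-rendering pattern introduced in `TUEvaluationLoop._print_results`: metric keys are grouped by their `test/<group>/<name>` prefix and printed as one rich table per group, with selected metrics rendered as percentages. This illustrates the pattern only, not the library code; the metric values below are made up for the example.

```python
import torch
from rich import get_console
from rich.console import Group
from rich.table import Table

# Made-up example values, for illustration only.
results = {
    "test/cls/Acc": torch.tensor(0.912),
    "test/cls/Brier": torch.tensor(0.071),
    "test/cal/ECE": torch.tensor(0.034),
}
percentage_metrics = {"Acc"}  # metrics rendered as percentages

tables: dict[str, Table] = {}
for key, value in results.items():
    group_name = key.split("/")[1]  # e.g. "cls" or "cal"
    metric = key.split("/")[-1]
    table = tables.setdefault(group_name, Table("Test metric", group_name))
    if metric in percentage_metrics:
        table.add_row(metric, f"{value.item() * 100:.2f}%")
    else:
        table.add_row(metric, f"{value.item():.5f}")

# One rich table per metric group, printed together as a Group.
get_console().print(Group(*tables.values()))
```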