diff --git a/.vscode/launch.json b/.vscode/launch.json index 364a72e..3a3bd88 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -16,7 +16,7 @@ "websockets", ], "jinja": true, - "justMyCode": true, + "justMyCode": false, "env": { "PYTHONPATH": "${workspaceFolder}:${PYTHONPATH}", "LOG_LEVEL": "DEBUG" diff --git a/README.md b/README.md index 88873c2..bb3b8e7 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,15 @@ The Evaluator for Model Benchmark is a versatile application designed to assess the performance of various machine learning models in a consistent and reliable manner. This app provides a streamlined process for evaluating models and generating comprehensive reports to help you learn different metrics and make informed decisions. +The Evaluator app offers a range of evaluation metrics, including precision, recall, F1 score, mAP, and more. The app also includes a **Model Comparison** feature that allows you to compare the performance of multiple models side by side. + +**Changelog:** + +- **v0.1.0** – Public release (for object detection task type) +- **v0.1.2** – Support for instance segmentation task type +- **v0.1.4** – Speedtest benchmark added +- **v0.1.15** – Model Comparison feature added + ## Preparation Before running the Evaluator for Model Benchmark, please ensure that you have the following: @@ -31,10 +40,20 @@ Before running the Evaluator for Model Benchmark, please ensure that you have th ## How To Run -**Step 1:** Open the app from the Supervisely Ecosystem. +**Step 1:** Open and launch the app from the Supervisely Ecosystem. + +**Step 2**: + +- _Model Evaluation_: + + **Step 2.1:** Select the Ground Truth project and the model you want to evaluate. + + **Step 2.2:** Press the “Evaluate” button to start the evaluation process. After the evaluation is complete, you can find a link to the report in the app’s interface. + +- _Model Comparison:_ -**Step 2:** Select the project you wish to evaluate. + **Step 2.1:** Select the folder with the Ground Truth project name. -**Step 3:** Choose the model you want to evaluate from the list of served models. + **Step 2.1:** Select one or more evaluation folders with the model name. -**Step 4:** Start the evaluation process by clicking the “Run” button. The app will process the data and evaluate the model(s) based on the selected benchmarks. You can monitor the progress in the app’s interface. + **Step 2.2:** Press the “Compare” button to start the comparison process. After the comparison is complete, you can find a link to the report in the app’s interface. diff --git a/config.json b/config.json index 29b3ff6..e83dbd5 100644 --- a/config.json +++ b/config.json @@ -2,8 +2,8 @@ "type": "app", "version": "2.0.0", "name": "Evaluator for Model Benchmark", - "description": "Evaluate the performance of the NN model", - "categories": ["neural network", "images", "object detection"], + "description": "Evaluate the performance of the NN model and compare it with the results of other models", + "categories": ["neural network", "images", "object detection", "instance segmentation"], "icon": "https://github.com/supervisely-ecosystem/model-benchmark/releases/download/v0.0.4/icon-mb.png", "icon_cover": true, "headless": false, @@ -11,19 +11,6 @@ "task_location": "workspace_tasks", "entrypoint": "python -m uvicorn src.main:app --host 0.0.0.0 --port 8000", "port": 8000, - "docker_image": "supervisely/model-benchmark:1.0.13", - "instance_version": "6.11.19", - "context_menu": { - "target": ["images_project"] - }, - "modal_template": "src/modal.html", - "modal_template_state": { - "sessionId": null, - "autoStart": false, - "sessionOptions": { - "sessionTags": ["deployed_nn"], - "showLabel": false, - "size": "small" - } - } + "docker_image": "supervisely/model-benchmark:1.0.15", + "instance_version": "6.11.19" } diff --git a/descriptions.md b/descriptions.md deleted file mode 100644 index 7d37e34..0000000 --- a/descriptions.md +++ /dev/null @@ -1,332 +0,0 @@ -# Overview - -## Key Metrics - -Here, we comprehensively assess the model's performance by presenting a broad set of metrics, including mAP (mean Average Precision), Precision, Recall, IoU (Intersection over Union), Classification Accuracy, Calibration Score, and Inference Speed. - -- **Mean Average Precision (mAP)**: An overall measure of detection performance. mAP calculates the average precision across all classes at different levels of IoU thresholds and precision-recall trade-offs. -- **Precision**: Precision indicates how often the model's predictions are actually correct when it predicts an object. This calculates the ratio of correct detections to the total number of detections made by the model. -- **Recall**: Recall measures the model's ability to find all relevant objects in a dataset. This calculates the ratio of correct detections to the total number of instances in a dataset. -- **Intersection over Union (IoU)**: IoU measures how closely predicted bounding boxes match the actual (ground truth) bounding boxes. It is calculated as the area of overlap between the predicted bounding box and the ground truth bounding box, divided by the area of union of these bounding boxes. -- **Classification Accuracy**: We separately measure the model's capability to correctly classify objects. It’s calculated as a proportion of correctly classified objects among all matched detections. The predicted detection is considered matched if it overlaps a ground true bounding box with IoU higher than 0.5. -- **Calibration Score**: This score represents the consistency of predicted probabilities (or confidence scores) made by the model, evaluating how well the predicted probabilities align with actual outcomes. A well-calibrated model means that when it predicts a detection with, say, 80% confidence, approximately 80% of those predictions should actually be correct. -- **Inference Speed**: The number of frames per second (FPS) the model can process, measured with a batch size of 1. The inference speed is important in applications, where real-time object detection is required. Additionally, slower models pour more GPU resources, so their inference cost is higher. - - - -## Model Predictions - -In this section you can visually assess the model performance through examples. This helps users better understand model capabilities and limitations, giving an intuitive grasp of prediction quality in different scenarios. - -**(!) Info** - -You can choose one of the sorting method: - -- **Auto**: The algorithm is trying to gather a diverse set of images that illustrate the model's performance across various scenarios. -- **Least accurate**: Displays images where the model made more errors. -- **Most accurate**: Displays images where the model made fewer or no errors. -- **Dataset order**: Displays images in the original order of the dataset. - - - -**Prediction Table** - -The table helps you in finding samples with specific cases of interest. You can sort by parameters such as the number of predictions, or specific a metric, e.g, recall, then click on a row to view this image and predictions. - -**(!) Info** - -**Example**: you can sort by **FN (**False Negatives) in descending order to identify samples where the model failed to detect many objects. - - - -## What is YOLOv8 model (collapse)? - -Можно также добавить ссылку на наш блог пост, если есть - -!\[blog post link\] - -**О чем еще здесь можно рассказать:** - -- Ключевая инфа о модели текстом: год, конференция, paper, гитхаб, какой скор на лидерборде от авторов, в каком сценарии эта модель была или есть SOTA и в каком году. Что-то ещё из того что писали про свою модель сами авторы, взять из ридми на гитхабе. -- Особенности модели, чем отличается от остальных, какую проблему решали авторы этой моделью. -- Для чего эта модель идеально подходит, какие сценарии использования? Возможно авторы проектировали модель под специальный use case, описать это. Например, YOLO подходит для real-time object detection, для real-time detection на видео. -- Историческая справка, как развивалась модель, прошлые версии. -- Краткий анализ метрик. На чем модель фейлит, а в чем хорошо предсказывает. - -## Expert insights? - -linkedin - ответ на вопрос когда применять когда нет, что лучше или хуже, что нужно учитывать. текст в свободной форме - -## How To Use: Training, inference, evaluation loop (collapse) - -Однотипная диаграмка, и небольшой текст со ссылками - Sly apps, inference notebooks, docker images, … небольшой раздел со ссылками на документацию (embeddings sampling, improvement loop, active learning, labeling jobs, model comparison, .… – стандартизован для всех моделей). какие-то модели будут частично интегрированы - -Jupyter notebooks + python scripts + apps + videos + guides + … - - - -# Detailed Metrics Analysis - -**Note about confidence threshold:** - -To calculate various metrics, we must set a _confidence threshold_, which also is necessary in deploying a model and applying it to any task. This hyperparameter significantly influences the results of metrics. To eliminate human bias in this process, we automate the determination of the confidence threshold. The threshold is selected based on the best _f1-score_ (guaranteed to give the best f1-score on the given dataset), ensuring a balanced trade-off between precision and recall. - -**| F1-optimal confidence threshold = 0.35** (calculated for the given model and dataset) - -Подробнее о том как мы считаем best confidence threshold: \[link\] - - - -## Outcome Counts - -This chart is used to evaluate the overall model performance by breaking down all predictions into True Positives (TP), False Positives (FP), and False Negatives (FN). This helps to visually assess the type of errors the model often encounters. - - -## Recall - -This section measures the ability of the model to detect **all relevant instances in the dataset**. In other words, this answers the question: “Of all instances in the dataset, how many of them is the model managed to find out?” - -To measure this, we calculate **Recall.** Recall counts errors, when the model does not detect an object that actually is present in a dataset and should be detected. Recall is calculated as the portion of correct predictions (true positives) over all instances in the dataset (true positives + false negatives). - -More information: \[link\]: - -Там рассказать что recall считается отдельно по всем классам и для каждого IoU threshold, а затем берется среднее по всему этому. - -Recall is the portion of **correct** predictions (true positives) over all actual instances in the dataset (true positives + false negative). A recall of 0.7 indicates that the model identifies 70% of all actual positives in the dataset. - -Recall is averaged across all classes and IoU thresholds \[0.50:0.95\]. - -**| Recall** (?) **\= 0.51** _(green-red color scale)_ - -The model correctly found **4615 of 9012** total instances in the dataset. - -**Per-class Recall** - -This chart further analyzes Recall, breaking it down to each class in separate. - -**(!) Info** - -Since the overall recall is calculated as an average across all classes, we provide a chart showing the recall for each individual class. This illustrates how much each class contributes to the overall recall. - -_Bars in the chart are sorted by F1-score to keep a unified order of classes between different charts._ - - - -## Precision - -This section measures the accuracy of all predictions made by the model. In other words, this answers the question: “Of all predictions made by the model, how many of them are actually correct?”. - -To measure this, we calculate **Precision.** Precision counts errors, when the model predicts an object (bounding box), but the image has no objects in this place (or it has another class than the model predicted). Precision is calculated as a portion of correct predictions (true positives) over all model’s predictions (true positives + false positives). - -More information: \[link\]: - -(?) - Precision is the portion of **correct** predictions (true positives) over all model’s predictions (true positives + false positives). A precision of 0.8 means that 80% of the instances that the model predicted as positive (e.g., detected objects) are actually positive (correct detections). - -Precision is averaged across all classes and IoU thresholds \[0.50:0.95\]. - -**| Precision (?) = 0.66** - -The model correctly predicted **5012 of 6061** predictions made by the model in total. - -**Per-class Precision** - -This chart further analyzes Precision, breaking it down to each class in separate. - -**(!) Info** - -Since the overall precision is computed as an average across all classes, we provide a chart showing the precision for each class individually. This illustrates how much each class contributes to the overall precision. - -_Bars in the chart are sorted by F1-score to keep a unified order of classes between different charts._ - - - -## Recall vs. Precision - -This section compares Precision and Recall on a common graph, identifying **disbalance** between these two. - -_Bars in the chart are sorted by F1-score to keep a unified order of classes between different charts._ - - - -## Precision-Recall Curve - -Precision-Recall curve is an overall performance indicator. It helps to visually assess both precision and recall for all predictions made by the model on the whole dataset. This gives you an understanding of how precision changes as you attempt to increase recall, providing a view of **trade-offs between precision and recall**. Ideally, a high-quality model will maintain strong precision as recall increases. This means that as you move from left to right on the curve, there should not be a significant drop in precision. Such a model is capable of finding many relevant instances, maintaining a high level of precision. - -🔽(Collapse) **About Trade-offs between precision and recall** - -A system with high recall but low precision returns many results, but most of its predictions are incorrect or redundant (false positive). A system with high precision but low recall is just the opposite, returning very few results, most of its predictions are correct. An ideal system with high precision and high recall will return many results, with all results predicted correctly. - -🔽(Collapse) **What is PR curve?** - -Imagine you sort all the predictions by their confidence scores from highest to lowest and write it down in a table. As you iterate over each sorted prediction, you classify it as a true positive (TP) or a false positive (FP). For each prediction, you then calculate the cumulative precision and recall so far. Each prediction is plotted as a point on a graph, with recall on the x-axis and precision on the y-axis. Now you have a plot very similar to the PR-curve, but it appears as a zig-zag curve due to variations as you move from one prediction to the next. - -**Forming the Actual PR Curve**: The true PR curve is derived by plotting only the maximum precision value for each recall level across all thresholds. This means you connect only the highest points of precision for each segment of recall, smoothing out the zig-zags and forming a curve that typically slopes downward as recall increases. - -**mAP = 0.51** - - - -### Precision-Recall curve by Class - -In this plot, you can evaluate PR curve for each class individually. - - - -## Classification Accuracy - -This section investigates cases where the model correctly localizes a bounding box, but predicts a wrong class label. For example, the model might confuse a motorbike with a bicycle. In this case, the model correctly identified that the object is present on the image, but assigned a wrong label to it. - -To quantify it, we calculate **Classification accuracy**. This is a portion of correctly classified objects to the total number of correctly localized objects ?-_(the object is localized correctly if the IoU between a prediction and a ground truth box is more than 0.5)_. In other words, if the model correctly found that an object is present on the image, how often it assigns a correct label to it? - -**| Classification Accuracy: 0.96** - -The model correctly classified **52** predictions **of 54** total predictions, that are matched to the ground truth. - -### Confusion Matrix - -Confusion matrix helps to find the number of confusions between different classes made by the model. Each row of the matrix represents the instances in a ground truth class, while each column represents the instances in a predicted class. The diagonal elements represent the number of correct predictions for each class (True Positives), and the off-diagonal elements show misclassifications. - - - -**Mini Confusion Matrix** - -- skip for now - - -### Frequently Confused Classes - -This chart displays the most frequently confused pairs of classes. In general, it finds out which classes visually seem very similar to the model. - -The chart calculates the **probability of confusion** between different pairs of classes. For instance, if the probability of confusion for the pair “car - truck” is 0.15, this means that when the model predicts either “car” or “truck”, there is a 15% chance that the model might mistakenly predict one instead of the other. - -The measure is class-symmetric, meaning that the probability of confusing a car with a truck is equal to the probability of confusing a truck with a car. - -_switch: Probability / Amount_ - - - -## Localization Accuracy (IoU) - -This section measures how closely predicted bounding boxes generated by the model are aligned with the actual (ground truth) bounding boxes. - -To measure it, we calculate the **Intersection over Union (IoU).** Intuitively, the higher the IoU, the closer two bounding boxes are. IoU is calculated by dividing the **area of overlap** between the predicted bounding box and the ground truth bounding box by the **area of union** of these two boxes. - -**| Avg. IoU** (?) **= 0.86** - -### IoU Distribution - -This histogram represents the distribution of IoU scores between all predictions and their matched ground truth objects. This gives you a sense of how well the model aligns bounding boxes. Ideally, if the model aligns boxes very well, rightmost bars (from 0.9 to 1.0 IoU) should be much higher than others. - - - -## Calibration Score - -This section analyzes confidence scores (or predicted probabilities) that the model generates for every predicted bounding box. - -🔽(Collapse) **What is calibration?** - -In some applications, it's crucial for a model not only to make accurate predictions but also to provide reliable **confidence levels**. A well-calibrated model aligns its confidence scores with the actual likelihood of predictions being correct. For example, if a model claims 90% confidence for predictions but they are correct only half the time, it is **overconfident**. Conversely, **underconfidence** occurs when a model assigns lower confidence scores than the actual likelihood of its predictions. In the context of autonomous driving, this might cause a vehicle to brake or slow down too frequently, reducing travel efficiency and potentially causing traffic issues. - -To evaluate the calibration, we draw a **Reliability Diagram** and calculate **Expected Calibration Error** (ECE) and **Maximum Calibration Error** (MCE). - -### Reliability Diagram - -Reliability diagram, also known as a Calibration curve, helps in understanding whether the confidence scores of detections accurately represent the true probability of a correct detection. A well-calibrated model means that when it predicts a detection with, say, 80% confidence, approximately 80% of those predictions should actually be correct. - -🔽(Collapse) **How to interpret the Calibration curve:** - -1. **The curve is above the Ideal Line (Underconfidence):** If the calibration curve is consistently above the ideal line, this indicates underconfidence. The model’s predictions are more correct than the confidence scores suggest. For example, if the model predicts a detection with 70% confidence but, empirically, 90% of such detections are correct, the model is underconfident. -2. **The curve is below the Ideal Line (Overconfidence):** If the calibration curve is below the ideal line, the model exhibits overconfidence. This means it is too sure of its predictions. For instance, if the model predicts with 80% confidence but only 60% of these predictions are correct, it is overconfident. - -To quantify the calibration score, we calculate **Expected Calibration Error (ECE).** Intuitively, ECE can be viewed as a deviation of the Calibration curve from the Perfectly calibrated line. When ECE is high, we can not trust predicted probabilities so much. - -**| Expected Calibration Error (ECE)** (?) **= 0.15** - -## Confidence Score Profile - -This section is going deeper in analyzing confidence scores. It gives you an intuition about how these scores are distributed and helps to find the best confidence threshold suitable for your task or application. - -**Confidence Score Profile** - -This chart provides a comprehensive view about predicted confidence scores. It is used to determine an optimal _confidence threshold_ based on your requirements. - -This plot shows you what the metrics will be if you choose a specific confidence threshold. For example, if you set the threshold to 0.32, you can see on the plot what the precision, recall and f1-score will be for this threshold. - -🔽(Collapse) **How to plot Confidence score Profile?** - -First, we sort all predictions by confidence scores from highest to lowest. As we iterate over each prediction we calculate the cumulative precision, recall and f1-score so far. Each prediction is plotted as a point on a graph, with a confidence score on the x-axis and one of three metrics on the y-axis (precision, recall, f1-score). - -**To find an optimal threshold**, you can select the confidence score under the maximum of the f1-score line. This f1-optimal threshold ensures the balance between precision and recall. You can select a threshold according to your desired trade-offs. - -**F1-optimal confidence threshold = _0.263_** - - - -### Confidence Distribution - -This graph helps to assess whether high confidence scores correlate with correct detections (True Positives) and whether low confidence scores are mostly associated with incorrect detections (False Positives). - -Additionally, it provides a view of how predicted probabilities are distributed. Whether the model skews probabilities to lower or higher values, leading to imbalance? - -Ideally, the histogram for TP predictions should have higher confidence, indicating that the model is sure about its correct predictions, and the FP predictions should have very low confidence, or not present at all. - -В описании приложить схематично идеальный график. Объяснения - -_Сделать stacked bar chart._ - - - -## Class Comparison - -This section analyzes the model's performance for all classes in a common plot. It discovers which classes the model identifies correctly, and which ones it often gets wrong. - -### Average Precision by Class - -A quick visual comparison of the model performance across all classes. Each axis in the chart represents a different class, and the distance to the center indicates the Average Precision for that class. - -### Outcome Counts by Class - -This chart breaks down all predictions into True Positives (TP), False Positives (FP), and False Negatives (FN) by classes. This helps to visually assess the type of errors the model often encounters for each class. - -**Normalization:** by default the normalization is used for better intraclass comparison. The total outcome counts are divided by the number of ground truth instances of the corresponding class. This is useful, because the sum of TP + FN always gives 1.0, representing all ground truth instances for a class, that gives a visual understanding of what portion of instances the model detected. So, if a green bar (TP outcomes) reaches the 1.0, this means the model is managed to predict all objects for the class. Everything that is higher than 1.0 corresponds to False Positives, i.e, redundant predictions. You can turn off the normalization switching to absolute values. - -_Bars in the chart are sorted by F1-score to keep a unified order of classes between different charts._ - - - -_switch to absolute values:_ - - - -## Inference speed - -We evaluate the inference speed in two scenarios: real-time inference (batch size is 1), and batch processing. We also run the model in optimized runtime environments, such as ONNX Runtime and Tensor RT, using consistent hardware. This approach provides a fair comparison of model efficiency and speed. To assess the inference speed we run the model forward 100 times and average it. - -\[**Methodology**\] 🔽 (collapsable) - -Setting 1: **Real-time processing** - -We measure the time spent processing each image individually by setting batch size to 1. This simulates real-time data processing conditions, such as those encountered in video streams, ensuring the model performs effectively in scenarios where data is processed frame by frame. - -Setting 2: **Parallel processing** - -To evaluate the model’s efficiency in parallel processing, we measure the processing speed with batch size of 8 and 16. This helps us understand how well the model scales when processing multiple images simultaneously, which is crucial for applications requiring high throughput. - -Setting 3: **Optimized runtime** - -We run the model in various runtime environments, including **ONNX Runtime** and **TensorRT**. This is important because python code can be suboptimal. These runtimes often provide significant performance improvements. - -**Consistent hardware for fair comparison** - -To ensure a fair comparison, we use a single hardware setup, specifically an NVIDIA RTX 3060 GPU. - -**Inference details** - -We divide the inference process into three stages: **preprocess, inference,** and **postprocess** to provide insights into where optimization efforts should be focused. Additionally, it gives us another verification level to ensure that time is measured correctly for each model. - -**Preprocess**: The stage where images are prepared for input into the model. This includes image reading, resizing, and any necessary transformations. - -**Inference**: The main computation phase where the _forward_ pass of the model is running. **Note:** we include not only the forward pass, but also modules like NMS (Non-Maximum Suppression), decoding module, and everything that is done to get a **meaningful** prediction. - -**Postprocess**: This stage includes tasks such as resizing output masks, aligning predictions with the input image, converting bounding boxes into a specific format or filtering out low-confidence detections. - diff --git a/dev_requirements.txt b/dev_requirements.txt index 19db2b7..f8e4733 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,2 +1,2 @@ # git+https://github.com/supervisely/supervisely.git@model-benchmark -supervisely[model-benchmark]==6.73.208 \ No newline at end of file +supervisely[model-benchmark]==6.73.215 \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index b6cba3d..0bbb6be 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,5 +1,5 @@ FROM supervisely/base-py-sdk:6.73.208 -RUN python3 -m pip install supervisely[model-benchmark]==6.73.208 +RUN python3 -m pip install supervisely[model-benchmark]==6.73.215 -LABEL python_sdk_version=6.73.208 \ No newline at end of file +LABEL python_sdk_version=6.73.215 \ No newline at end of file diff --git a/local.env b/local.env index 000e343..224b8fe 100644 --- a/local.env +++ b/local.env @@ -1,14 +1,7 @@ -# Ars -# TEAM_ID = 449 # * Required! Copy your team id here: https://dev.supervisely.com/teams/my -# WORKSPACE_ID = 691 # * Required! Copy your workspace id here: https://dev.supervisely.com/workspaces/ -# PROJECT_ID = 32796 # * Optional. Copy your project id here if needed: https://dev.supervisely.com/projects/ -# DATASET_ID = 81589 # * Optional. Copy your dataset id here if needed: https://dev.supervisely.com/projects//datasets/ -# SLY_APP_DATA_DIR = "/home/grokhi/supervisely/tasks/model-benchmark/APP_DATA" # * Optional. Path to the local folder for application data. Make sure the app will have rights to read/write from/to this folder. - -# PB -TEAM_ID = 440 -WORKSPACE_ID = 1105 -PROJECT_ID = 40299 +TEAM_ID = 447 +WORKSPACE_ID = 680 +# PROJECT_ID = 41021 SLY_APP_DATA_DIR = "APP_DATA" -TASK_ID = 64030 \ No newline at end of file +TASK_ID = 60447 +# modal.state.sessionId=66693 \ No newline at end of file diff --git a/src/functions.py b/src/functions.py index fb6d5a3..9580ad0 100644 --- a/src/functions.py +++ b/src/functions.py @@ -1,3 +1,6 @@ +import os +from typing import List, Tuple + import src.globals as g import supervisely as sly from supervisely.nn import TaskType @@ -47,3 +50,55 @@ def get_classes(): not_matched_model_cls.append(obj_class) return (matched_proj_cls, matched_model_cls), (not_matched_proj_cls, not_matched_model_cls) + + +def validate_paths(paths: List[str]): + if not paths: + raise ValueError("No paths selected") + + split_paths = [path.strip("/").split(os.sep) for path in paths] + path_length = min(len(p) for p in split_paths) + + if not all(len(p) == path_length for p in split_paths): + raise ValueError(f"Selected paths not on the correct level: {paths}") + + if not all(p.startswith("/model-benchmark") for p in paths): + raise ValueError(f"Selected paths are not in the benchmark directory: {paths}") + + if not all(p[1] == split_paths[0][1] for p in split_paths): + raise ValueError(f"Project names are different: {paths}") + + +def get_parent_paths(paths: List[str]) -> Tuple[str, List[str]]: + split_paths = [path.strip("/").split(os.sep) for path in paths] + project_name = split_paths[0][1] + eval_dirs = [p[2] for p in split_paths] + + return project_name, eval_dirs + + +def get_res_dir(eval_dirs: List[str]) -> str: + + res_dir = "/model-comparison" + project_name, eval_dirs = get_parent_paths(eval_dirs) + res_dir += "/" + project_name + "/" + res_dir += " vs ".join(eval_dirs) + + res_dir = g.api.file.get_free_dir_name(g.team_id, res_dir) + + return res_dir + + +# ! temp fix (to allow the app to receive requests) +def with_clean_up_progress(pbar): + def decorator(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + finally: + with pbar(message="Application is started ...", total=1) as pb: + pb.update(1) + + return wrapper + + return decorator diff --git a/src/globals.py b/src/globals.py index 34d190d..28f90ea 100644 --- a/src/globals.py +++ b/src/globals.py @@ -15,7 +15,6 @@ STORAGE_DIR = sly.app.get_data_dir() STATIC_DIR = os.path.join(STORAGE_DIR, "static") sly.fs.mkdir(STATIC_DIR) -TF_RESULT_DIR = "/model-benchmark/layout" deployed_nn_tags = ["deployed_nn"] @@ -27,6 +26,6 @@ if session_id is not None: session_id = int(session_id) session = None -autostart = bool(strtobool(os.environ.get("modal.state.autoStart", "false"))) -selected_classes = None \ No newline at end of file +selected_classes = None +eval_dirs = None diff --git a/src/main.py b/src/main.py index ffbb058..a5b45fc 100644 --- a/src/main.py +++ b/src/main.py @@ -1,263 +1,58 @@ -from typing import Optional +from fastapi import Request -import yaml - -import src.functions as f import src.globals as g -import src.workflow as w import supervisely as sly import supervisely.app.widgets as widgets -from supervisely.nn import TaskType -from supervisely.nn.benchmark import ( - InstanceSegmentationBenchmark, - ObjectDetectionBenchmark, -) -from supervisely.nn.benchmark.evaluation.base_evaluator import BaseEvaluator -from supervisely.nn.benchmark.evaluation.instance_segmentation_evaluator import ( - InstanceSegmentationEvaluator, -) -from supervisely.nn.benchmark.evaluation.object_detection_evaluator import ( - ObjectDetectionEvaluator, -) -from supervisely.nn.inference.session import SessionJSON - - -def main_func(): - api = g.api - project = api.project.get_info_by_id(g.project_id) - if g.session is None: - g.session = SessionJSON(api, g.session_id) - task_type = g.session.get_deploy_info()["task_type"] - - # ==================== Workflow input ==================== - w.workflow_input(api, project, g.session_id) - # ======================================================= - - report_model_benchmark.hide() - - set_selected_classes_and_show_info() - if g.selected_classes is None or len(g.selected_classes) == 0: - return - - pbar.show() - sec_pbar.show() - evaluation_parameters = yaml.safe_load(eval_params.get_value()) - if task_type == "object detection": - bm = ObjectDetectionBenchmark( - api, - project.id, - output_dir=g.STORAGE_DIR + "/benchmark", - progress=pbar, - progress_secondary=sec_pbar, - classes_whitelist=g.selected_classes, - evaluation_params=evaluation_parameters, - ) - elif task_type == "instance segmentation": - bm = InstanceSegmentationBenchmark( - api, - project.id, - output_dir=g.STORAGE_DIR + "/benchmark", - progress=pbar, - progress_secondary=sec_pbar, - classes_whitelist=g.selected_classes, - evaluation_params=evaluation_parameters, - ) - sly.logger.info(f"{g.session_id = }") - - task_info = api.task.get_info_by_id(g.session_id) - task_dir = f"{g.session_id}_{task_info['meta']['app']['name']}" - - res_dir = f"/model-benchmark/{project.id}_{project.name}/{task_dir}/" - res_dir = api.storage.get_free_dir_name(g.team_id, res_dir) - - session_info = g.session.get_session_info() - support_batch_inference = session_info.get("batch_inference_support", False) - max_batch_size = session_info.get("max_batch_size") - batch_size = 16 - if not support_batch_inference: - batch_size = 1 - if max_batch_size is not None: - batch_size = min(max_batch_size, 16) - bm.run_evaluation(model_session=g.session_id, batch_size=batch_size) - - try: - batch_sizes = (1, 8, 16) - if not support_batch_inference: - batch_sizes = (1,) - elif max_batch_size is not None: - batch_sizes = tuple([bs for bs in batch_sizes if bs <= max_batch_size]) - bm.run_speedtest(g.session_id, g.project_id, batch_sizes=batch_sizes) - sec_pbar.hide() - bm.upload_speedtest_results(res_dir + "/speedtest/") - except Exception as e: - sly.logger.warn(f"Speedtest failed. Skipping. {e}") - - bm.visualize() - - bm.upload_eval_results(res_dir + "/evaluation/") - remote_dir = bm.upload_visualizations(res_dir + "/visualizations/") - - report = bm.upload_report_link(remote_dir) - api.task.set_output_report(g.task_id, report.id, report.name) - - template_vis_file = api.file.get_info_by_path( - sly.env.team_id(), res_dir + "/visualizations/template.vue" - ) - report_model_benchmark.set(template_vis_file) - report_model_benchmark.show() - pbar.hide() - - # ==================== Workflow output ==================== - w.workflow_output(api, res_dir, template_vis_file) - # ======================================================= - - sly.logger.info( - f"Predictions project: " - f" name {bm.dt_project_info.name}, " - f" workspace_id {bm.dt_project_info.workspace_id}. " - f"Differences project: " - f" name {bm.diff_project_info.name}, " - f" workspace_id {bm.diff_project_info.workspace_id}" - ) - - button.loading = False - app.stop() - - -no_classes_label = widgets.Text( - "Not found any classes in the project that are present in the model", status="error" -) -no_classes_label.hide() -total_classes_text = widgets.Text(status="info") -selected_matched_text = widgets.Text(status="success") -not_matched_text = widgets.Text(status="warning") - -sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True) -sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id) +from src.ui.compare import compare_button, compare_contatiner, run_compare +from src.ui.evaluation import eval_button, evaluation_container, run_evaluation -eval_params = widgets.Editor( - initial_text=None, - language_mode="yaml", - height_lines=16, +tabs = widgets.Tabs( + labels=["Model Evaluation", "Model Comparison"], + contents=[evaluation_container, compare_contatiner], ) -eval_params_card = widgets.Card( - title="Evaluation parameters", - content=eval_params, - collapsable=True, +tabs_card = widgets.Card( + title="Model Benchmark", + content=tabs, + description="Select the task you want to perform", ) -eval_params_card.collapse() - - -button = widgets.Button("Evaluate") -button.disable() - -pbar = widgets.SlyTqdm() -sec_pbar = widgets.Progress("") - -report_model_benchmark = widgets.ReportThumbnail() -report_model_benchmark.hide() - -controls_card = widgets.Card( - title="Settings", - description="Select Ground Truth project and deployed model session", - content=widgets.Container( - [ - sel_project, - sel_app_session, - eval_params_card, - button, - report_model_benchmark, - pbar, - sec_pbar, - ] - ), -) - layout = widgets.Container( - widgets=[controls_card, widgets.Empty(), widgets.Empty()], # , matched_card, not_matched_card], + widgets=[tabs_card, widgets.Empty(), widgets.Empty()], direction="horizontal", fractions=[1, 1, 1], ) -main_layout = widgets.Container( - widgets=[layout, total_classes_text, selected_matched_text, not_matched_text, no_classes_label] -) - - -def set_selected_classes_and_show_info(): - matched, not_matched = f.get_classes() - _, matched_model_classes = matched - _, not_matched_model_classes = not_matched - total_classes_text.text = ( - f"{len(matched_model_classes) + len(not_matched_model_classes)} classes found in the model." - ) - selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation." - not_matched_text.text = f"{len(not_matched_model_classes)} classes are not available for evaluation (not found in the GT project or have different geometry type)." - if len(matched_model_classes) > 0: - g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes] - selected_matched_text.show() - if len(not_matched_model_classes) > 0: - not_matched_text.show() - else: - no_classes_label.show() - +app = sly.Application(layout=layout, static_dir=g.STATIC_DIR) +server = app.get_server() -def update_eval_params(): - if g.session is None: - g.session = SessionJSON(g.api, g.session_id) - task_type = g.session.get_deploy_info()["task_type"] - if task_type == TaskType.OBJECT_DETECTION: - params = ObjectDetectionEvaluator.load_yaml_evaluation_params() - elif task_type == TaskType.INSTANCE_SEGMENTATION: - params = InstanceSegmentationEvaluator.load_yaml_evaluation_params() - eval_params.set_text(params, language_mode="yaml") - eval_params_card.uncollapse() - -def handle_selectors(active: bool): - no_classes_label.hide() - selected_matched_text.hide() - not_matched_text.hide() - if active: - button.enable() - else: - button.disable() - - -@sel_project.value_changed -def handle_sel_project(project_id: Optional[int]): - g.project_id = project_id - active = project_id is not None and g.session_id is not None - handle_selectors(active) - - -@sel_app_session.value_changed -def handle_sel_app_session(session_id: Optional[int]): - g.session_id = session_id - active = session_id is not None and g.project_id is not None - handle_selectors(active) - - if g.session_id: - update_eval_params() - - -@button.click +@eval_button.click def start_evaluation(): - main_func() + run_evaluation() -app = sly.Application(layout=main_layout, static_dir=g.STATIC_DIR) +@compare_button.click +def start_comparison(): + run_compare() -if g.project_id: - sel_project.set_project_id(g.project_id) -if g.session_id: - sel_app_session.set_session_id(g.session_id) - update_eval_params() +@server.post("/run_evaluation") +async def evaluate(request: Request): + req = await request.json() + try: + state = req["state"] + return {"data": run_evaluation(state["session_id"], state["project_id"])} + except Exception as e: + sly.logger.error(f"Error during model evaluation: {e}") + return {"error": str(e)} -if g.autostart: - start_evaluation() -if g.project_id and g.session_id: - handle_selectors(True) +@server.post("/run_comparison") +async def compare(request: Request): + req = await request.json() + try: + state = req["state"] + return {"data": run_compare(state["eval_dirs"])} + except Exception as e: + sly.logger.error(f"Error during model comparison: {e}") + return {"error": str(e)} diff --git a/src/modal.html b/src/modal.html deleted file mode 100644 index c5fffcd..0000000 --- a/src/modal.html +++ /dev/null @@ -1,17 +0,0 @@ -
- - - - - - Enable - -
diff --git a/src/ui/compare.py b/src/ui/compare.py new file mode 100644 index 0000000..36861b3 --- /dev/null +++ b/src/ui/compare.py @@ -0,0 +1,71 @@ +from typing import List + +import src.functions as f +import src.globals as g +import src.workflow as w +import supervisely as sly +import supervisely.app.widgets as widgets +from supervisely._utils import rand_str +from supervisely.nn.benchmark.comparison.model_comparison import ModelComparison + +compare_button = widgets.Button("Compare") +comp_pbar = widgets.SlyTqdm() +models_comparison_report = widgets.ReportThumbnail( + title="Models Comparison Report", + color="#ffc084", + bg_color="#fff2e6", +) +models_comparison_report.hide() +team_files_selector = widgets.TeamFilesSelector( + g.team_id, + multiple_selection=True, + selection_file_type="folder", + max_height=350, + initial_folder="/model-benchmark", +) + +compare_contatiner = widgets.Container( + [ + team_files_selector, + compare_button, + models_comparison_report, + comp_pbar, + ] +) + + +@f.with_clean_up_progress(comp_pbar) +def run_compare(eval_dirs: List[str] = None): + workdir = g.STORAGE_DIR + "/model-comparison-" + rand_str(6) + team_files_selector.disable() + models_comparison_report.hide() + comp_pbar.show() + + g.eval_dirs = eval_dirs or team_files_selector.get_selected_paths() + f.validate_paths(g.eval_dirs) + + # ==================== Workflow input ==================== + w.workflow_input(g.api, team_files_dirs=g.eval_dirs) + # ======================================================= + + comp = ModelComparison(g.api, g.eval_dirs, progress=comp_pbar, workdir=workdir) + comp.visualize() + res_dir = f.get_res_dir(g.eval_dirs) + res_dir = comp.upload_results(g.team_id, remote_dir=res_dir, progress=comp_pbar) + + report = g.api.file.get_info_by_path(g.team_id, comp.get_report_link()) + g.api.task.set_output_report(g.task_id, report.id, report.name) + + models_comparison_report.set(report) + models_comparison_report.show() + + # ==================== Workflow output ==================== + w.workflow_output(g.api, model_comparison_report=report) + # ======================================================= + + comp_pbar.hide() + compare_button.loading = False + + sly.logger.info(f"Model comparison report uploaded to: {res_dir}") + + return res_dir diff --git a/src/ui/evaluation.py b/src/ui/evaluation.py new file mode 100644 index 0000000..8ac69ee --- /dev/null +++ b/src/ui/evaluation.py @@ -0,0 +1,253 @@ +from typing import Dict, Optional, Union + +import yaml + +import src.functions as f +import src.globals as g +import src.workflow as w +import supervisely as sly +import supervisely.app.widgets as widgets +from supervisely._utils import rand_str +from supervisely.nn import TaskType +from supervisely.nn.benchmark import ( + InstanceSegmentationBenchmark, + ObjectDetectionBenchmark, +) +from supervisely.nn.benchmark.evaluation.instance_segmentation_evaluator import ( + InstanceSegmentationEvaluator, +) +from supervisely.nn.benchmark.evaluation.object_detection_evaluator import ( + ObjectDetectionEvaluator, +) +from supervisely.nn.inference.session import SessionJSON + + +no_classes_label = widgets.Text( + "Not found any classes in the project that are present in the model", status="error" +) +no_classes_label.hide() +total_classes_text = widgets.Text(status="info") +selected_matched_text = widgets.Text(status="success") +not_matched_text = widgets.Text(status="warning") + +sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True) +sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id) + +eval_params = widgets.Editor( + initial_text=None, + language_mode="yaml", + height_lines=16, +) +eval_params_card = widgets.Card( + title="Evaluation parameters", + content=eval_params, + collapsable=True, +) +eval_params_card.collapse() + + +eval_button = widgets.Button("Evaluate") +eval_button.disable() + +eval_pbar = widgets.SlyTqdm() +sec_eval_pbar = widgets.Progress("") + +report_model_benchmark = widgets.ReportThumbnail() +report_model_benchmark.hide() + +evaluation_container = widgets.Container( + [ + sel_project, + sel_app_session, + eval_params_card, + eval_button, + report_model_benchmark, + eval_pbar, + sec_eval_pbar, + total_classes_text, + selected_matched_text, + not_matched_text, + no_classes_label, + ] +) + + +@f.with_clean_up_progress(eval_pbar) +def run_evaluation( + session_id: Optional[int] = None, + project_id: Optional[int] = None, + params: Optional[Union[str, Dict]] = None, +): + work_dir = g.STORAGE_DIR + "/benchmark_" + rand_str(6) + + if session_id is not None: + g.session_id = session_id + if project_id is not None: + g.project_id = project_id + + project = g.api.project.get_info_by_id(g.project_id) + if g.session is None: + g.session = SessionJSON(g.api, g.session_id) + task_type = g.session.get_deploy_info()["task_type"] + + # ==================== Workflow input ==================== + w.workflow_input(g.api, project, g.session_id) + # ======================================================= + + report_model_benchmark.hide() + + set_selected_classes_and_show_info() + if g.selected_classes is None or len(g.selected_classes) == 0: + return + + eval_pbar.show() + sec_eval_pbar.show() + + evaluation_params = eval_params.get_value() or params + if isinstance(evaluation_params, str): + evaluation_params = yaml.safe_load(evaluation_params) + + if task_type == TaskType.OBJECT_DETECTION: + if evaluation_params is None: + evaluation_params = ObjectDetectionEvaluator.load_yaml_evaluation_params() + evaluation_params = yaml.safe_load(evaluation_params) + bm = ObjectDetectionBenchmark( + g.api, + project.id, + output_dir=work_dir, + progress=eval_pbar, + progress_secondary=sec_eval_pbar, + classes_whitelist=g.selected_classes, + evaluation_params=evaluation_params, + ) + elif task_type == TaskType.INSTANCE_SEGMENTATION: + if evaluation_params is None: + evaluation_params = InstanceSegmentationEvaluator.load_yaml_evaluation_params() + evaluation_params = yaml.safe_load(evaluation_params) + bm = InstanceSegmentationBenchmark( + g.api, + project.id, + output_dir=work_dir, + progress=eval_pbar, + progress_secondary=sec_eval_pbar, + classes_whitelist=g.selected_classes, + evaluation_params=evaluation_params, + ) + sly.logger.info(f"{g.session_id = }") + + task_info = g.api.task.get_info_by_id(g.session_id) + task_dir = f"{g.session_id}_{task_info['meta']['app']['name']}" + + res_dir = f"/model-benchmark/{project.id}_{project.name}/{task_dir}/" + res_dir = g.api.storage.get_free_dir_name(g.team_id, res_dir) + + session_info = g.session.get_session_info() + support_batch_inference = session_info.get("batch_inference_support", False) + max_batch_size = session_info.get("max_batch_size") + batch_size = 16 + if not support_batch_inference: + batch_size = 1 + if max_batch_size is not None: + batch_size = min(max_batch_size, 16) + bm.run_evaluation(model_session=g.session_id, batch_size=batch_size) + + try: + batch_sizes = (1, 8, 16) + if not support_batch_inference: + batch_sizes = (1,) + elif max_batch_size is not None: + batch_sizes = tuple([bs for bs in batch_sizes if bs <= max_batch_size]) + bm.run_speedtest(g.session_id, g.project_id, batch_sizes=batch_sizes) + sec_eval_pbar.hide() + bm.upload_speedtest_results(res_dir + "/speedtest/") + except Exception as e: + sly.logger.warning(f"Speedtest failed. Skipping. {e}") + + bm.visualize() + + bm.upload_eval_results(res_dir + "/evaluation/") + remote_dir = bm.upload_visualizations(res_dir + "/visualizations/") + + report = bm.upload_report_link(remote_dir) + g.api.task.set_output_report(g.task_id, report.id, report.name) + + template_vis_file = g.api.file.get_info_by_path( + sly.env.team_id(), res_dir + "/visualizations/template.vue" + ) + report_model_benchmark.set(template_vis_file) + report_model_benchmark.show() + eval_pbar.hide() + + # ==================== Workflow output ==================== + w.workflow_output(g.api, res_dir, template_vis_file) + # ======================================================= + + sly.logger.info( + f"Predictions project: " + f" name {bm.dt_project_info.name}, " + f" workspace_id {bm.dt_project_info.workspace_id}. " + f"Differences project: " + f" name {bm.diff_project_info.name}, " + f" workspace_id {bm.diff_project_info.workspace_id}" + ) + + eval_button.loading = False + + return res_dir + + +def set_selected_classes_and_show_info(): + matched, not_matched = f.get_classes() + _, matched_model_classes = matched + _, not_matched_model_classes = not_matched + total_classes_text.text = ( + f"{len(matched_model_classes) + len(not_matched_model_classes)} classes found in the model." + ) + selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation." + not_matched_text.text = f"{len(not_matched_model_classes)} classes are not available for evaluation (not found in the GT project or have different geometry type)." + if len(matched_model_classes) > 0: + g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes] + selected_matched_text.show() + if len(not_matched_model_classes) > 0: + not_matched_text.show() + else: + no_classes_label.show() + + +def update_eval_params(): + if g.session is None: + g.session = SessionJSON(g.api, g.session_id) + task_type = g.session.get_deploy_info()["task_type"] + if task_type == TaskType.OBJECT_DETECTION: + params = ObjectDetectionEvaluator.load_yaml_evaluation_params() + elif task_type == TaskType.INSTANCE_SEGMENTATION: + params = InstanceSegmentationEvaluator.load_yaml_evaluation_params() + eval_params.set_text(params, language_mode="yaml") + eval_params_card.uncollapse() + + +def handle_selectors(active: bool): + no_classes_label.hide() + selected_matched_text.hide() + not_matched_text.hide() + if active: + eval_button.enable() + else: + eval_button.disable() + + +@sel_project.value_changed +def handle_sel_project(project_id: Optional[int]): + g.project_id = project_id + active = project_id is not None and g.session_id is not None + handle_selectors(active) + + +@sel_app_session.value_changed +def handle_sel_app_session(session_id: Optional[int]): + g.session_id = session_id + active = session_id is not None and g.project_id is not None + handle_selectors(active) + + if g.session_id: + update_eval_params() diff --git a/src/ui/inference_speed.py b/src/ui/inference_speed.py deleted file mode 100644 index d9490b4..0000000 --- a/src/ui/inference_speed.py +++ /dev/null @@ -1,97 +0,0 @@ -import os -import random -from collections import defaultdict - -import numpy as np -import pandas as pd -import plotly.express as px -import plotly.graph_objects as go -from matplotlib import pyplot as plt -from pycocotools.coco import COCO -from pycocotools.cocoeval import COCOeval, Params - -import src.globals as g -import supervisely as sly -from supervisely.app.widgets import ( - Button, - Card, - Collapse, - Container, - DatasetThumbnail, - IFrame, - Markdown, - NotificationBox, - OneOf, - SelectDataset, - Switch, - Table, - Text, -) - -markdown_inference_speed_1 = Markdown( - """ -## Inference speed - -We evaluate the inference speed in two scenarios: real-time inference (batch size is 1), and batch processing. We also run the model in optimized runtime environments, such as ONNX Runtime and Tensor RT, using consistent hardware. This approach provides a fair comparison of model efficiency and speed. To assess the inference speed we run the model forward 100 times and average it. -""", - show_border=False, -) -collapsables = Collapse( - [ - Collapse.Item( - "Methodology", - "Methodology", - Container( - [ - Markdown( - """ -Setting 1: **Real-time processing** - -We measure the time spent processing each image individually by setting batch size to 1. This simulates real-time data processing conditions, such as those encountered in video streams, ensuring the model performs effectively in scenarios where data is processed frame by frame. - -Setting 2: **Parallel processing** - -To evaluate the model's efficiency in parallel processing, we measure the processing speed with batch size of 8 and 16. This helps us understand how well the model scales when processing multiple images simultaneously, which is crucial for applications requiring high throughput. - -Setting 3: **Optimized runtime** - -We run the model in various runtime environments, including **ONNX Runtime** and **TensorRT**. This is important because python code can be suboptimal. These runtimes often provide significant performance improvements. -""", - show_border=False, - ), - ] - ), - ) - ] -) -markdown_inference_speed_2 = Markdown( - """ -#### Consistent hardware for fair comparison - -To ensure a fair comparison, we use a single hardware setup, specifically an NVIDIA RTX 3060 GPU. - -#### Inference details - -We divide the inference process into three stages: **preprocess, inference,** and **postprocess** to provide insights into where optimization efforts should be focused. Additionally, it gives us another verification level to ensure that time is measured correctly for each model. - -#### Preprocess - -The stage where images are prepared for input into the model. This includes image reading, resizing, and any necessary transformations. - -#### Inference - -The main computation phase where the _forward_ pass of the model is running. **Note:** we include not only the forward pass, but also modules like NMS (Non-Maximum Suppression), decoding module, and everything that is done to get a **meaningful** prediction. - -#### Postprocess - -This stage includes tasks such as resizing output masks, aligning predictions with the input image, converting bounding boxes into a specific format or filtering out low-confidence detections. -""", - show_border=False, -) -container = Container( - widgets=[ - markdown_inference_speed_1, - collapsables, - markdown_inference_speed_2, - ] -) diff --git a/src/workflow.py b/src/workflow.py index c867ded..fa32ba8 100644 --- a/src/workflow.py +++ b/src/workflow.py @@ -1,69 +1,102 @@ # This module contains functions that are used to configure the input and output of the workflow for the current app, # and versioning feature that creates a project version before the task starts. +from typing import List, Optional + import supervisely as sly def workflow_input( api: sly.Api, - project_info: sly.ProjectInfo, - session_id: int, + project_info: Optional[sly.ProjectInfo] = None, + session_id: Optional[int] = None, + team_files_dirs: Optional[List[str]] = None, ): - # Create a project version before the task starts - try: - project_version_id = api.project.version.create( - project_info, - f"Evaluator for Model Benchmark", - f"This backup was created automatically by Supervisely before the Evaluator for Model Benchmark task with ID: {api.task_id}", - ) - except Exception as e: - sly.logger.debug(f"Failed to create a project version: {repr(e)}") - project_version_id = None + if project_info: + # Create a project version before the task starts + try: + project_version_id = api.project.version.create( + project_info, + f"Evaluator for Model Benchmark", + f"This backup was created automatically by Supervisely before the Evaluator for Model Benchmark task with ID: {api.task_id}", + ) + except Exception as e: + sly.logger.debug(f"Failed to create a project version: {repr(e)}") + project_version_id = None - # Add input project to the workflow - try: - if project_version_id is None: - project_version_id = ( - project_info.version.get("id", None) if project_info.version else None + # Add input project to the workflow + try: + if project_version_id is None: + project_version_id = ( + project_info.version.get("id", None) if project_info.version else None + ) + api.app.workflow.add_input_project(project_info.id, version_id=project_version_id) + sly.logger.debug( + f"Workflow Input: Project ID - {project_info.id}, Project Version ID - {project_version_id}" ) - api.app.workflow.add_input_project(project_info.id, version_id=project_version_id) - sly.logger.debug( - f"Workflow Input: Project ID - {project_info.id}, Project Version ID - {project_version_id}" - ) - except Exception as e: - sly.logger.debug(f"Failed to add input to the workflow: {repr(e)}") + except Exception as e: + sly.logger.debug(f"Failed to add input to the workflow: {repr(e)}") - # Add input model session to the workflow - try: - api.app.workflow.add_input_task(session_id) - sly.logger.debug(f"Workflow Input: Session ID - {session_id}") - except Exception as e: - sly.logger.debug(f"Failed to add input to the workflow: {repr(e)}") + # Add input model session to the workflow + try: + api.app.workflow.add_input_task(session_id) + sly.logger.debug(f"Workflow Input: Session ID - {session_id}") + except Exception as e: + sly.logger.debug(f"Failed to add input to the workflow: {repr(e)}") + + if team_files_dirs: + # Add input evaluation results folders to the workflow + try: + for team_files_dir in team_files_dirs: + api.app.workflow.add_input_folder(team_files_dir) + sly.logger.debug(f"Workflow Input: Team Files dir - {team_files_dir}") + except Exception as e: + sly.logger.debug(f"Failed to add input to the workflow: {repr(e)}") def workflow_output( api: sly.Api, - eval_team_files_dir: str, - model_benchmark_report: sly.api.file_api.FileInfo, + eval_team_files_dir: Optional[str] = None, + model_benchmark_report: Optional[sly.api.file_api.FileInfo] = None, + model_comparison_report: Optional[sly.api.file_api.FileInfo] = None, ): - try: - # Add output evaluation results folder to the workflow - eval_dir_relation_settings = sly.WorkflowSettings(title="Evaluation Artifacts") - eval_dir_meta = sly.WorkflowMeta(relation_settings=eval_dir_relation_settings) - api.app.workflow.add_output_folder(eval_team_files_dir, meta=eval_dir_meta) - sly.logger.debug(f"Workflow Output: Team Files dir - {eval_team_files_dir}") + if model_benchmark_report: + try: + # Add output evaluation results folder to the workflow + eval_dir_relation_settings = sly.WorkflowSettings(title="Evaluation Artifacts") + eval_dir_meta = sly.WorkflowMeta(relation_settings=eval_dir_relation_settings) + api.app.workflow.add_output_folder(eval_team_files_dir, meta=eval_dir_meta) + sly.logger.debug(f"Workflow Output: Team Files dir - {eval_team_files_dir}") - # Add output model benchmark report to the workflow - mb_relation_settings = sly.WorkflowSettings( - title="Model Benchmark", - icon="assignment", - icon_color="#674EA7", - icon_bg_color="#CCCCFF", - url=f"/model-benchmark?id={model_benchmark_report.id}", - url_title="Open Report", - ) - meta = sly.WorkflowMeta(relation_settings=mb_relation_settings) - api.app.workflow.add_output_file(model_benchmark_report, meta=meta) - sly.logger.debug("Model Benchmark Report ID - {model_benchmark_report.id}") + # Add output model benchmark report to the workflow + mb_relation_settings = sly.WorkflowSettings( + title="Model Evaluation", + icon="assignment", + icon_color="#dcb0ff", + icon_bg_color="#faebff", + url=f"/model-benchmark?id={model_benchmark_report.id}", + url_title="Open Benchmark Report", + ) + meta = sly.WorkflowMeta(relation_settings=mb_relation_settings) + api.app.workflow.add_output_file(model_benchmark_report, meta=meta) + sly.logger.debug(f"Model Evaluation Report ID - {model_benchmark_report.id}") + + except Exception as e: + sly.logger.debug(f"Failed to add output to the workflow: {repr(e)}") + + if model_comparison_report: + try: + # Add output model benchmark report to the workflow + comparison_relation_settings = sly.WorkflowSettings( + title="Model Evaluation", + icon="assignment", + icon_color="#ffc084", + icon_bg_color="#fff2e6", + url=f"/model-benchmark?id={model_comparison_report.id}", + url_title="Open Comparison Report", + ) + meta = sly.WorkflowMeta(relation_settings=comparison_relation_settings) + api.app.workflow.add_output_file(model_comparison_report, meta=meta) + sly.logger.debug(f"Model Comparison Report ID - {model_comparison_report.id}") - except Exception as e: - sly.logger.debug(f"Failed to add output to the workflow: {repr(e)}") + except Exception as e: + sly.logger.debug(f"Failed to add output to the workflow: {repr(e)}")