Skip to content

Commit

Permalink
Changed labeling mechanism - added model parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
iXab3r committed Mar 26, 2024
1 parent 239106b commit 9911389
Show file tree
Hide file tree
Showing 14 changed files with 333 additions and 49 deletions.
56 changes: 41 additions & 15 deletions Sources/YoloEase.UI/Core/Yolo8PredictAccessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
using System.Linq;
using System.Threading;
using JetBrains.Annotations;
using Microsoft.ML.OnnxRuntime;
using PoeShared.Dialogs.Services;
using PoeShared.Services;
using YoloEase.UI.Dto;
using YoloEase.UI.Scaffolding;
using YoloEase.UI.Yolo;

namespace YoloEase.UI.Core;



public class Yolo8PredictAccessor : RefreshableReactiveObject
{
private static readonly Binder<Yolo8PredictAccessor> Binder = new();
Expand All @@ -19,18 +21,15 @@ static Yolo8PredictAccessor()
}

private readonly Yolo8CliWrapper cliWrapper;
private readonly IUniqueIdGenerator idGenerator;
private readonly IOpenFileDialog openFileDialog;
private readonly IScheduler uiScheduler;

public Yolo8PredictAccessor(
Yolo8CliWrapper cliWrapper,
IUniqueIdGenerator idGenerator,
IOpenFileDialog openFileDialog,
[Dependency(WellKnownSchedulers.UI)] IScheduler uiScheduler)
{
this.cliWrapper = cliWrapper;
this.idGenerator = idGenerator;
this.openFileDialog = openFileDialog;
this.uiScheduler = uiScheduler;

Expand Down Expand Up @@ -75,7 +74,7 @@ public void LoadModel(FileInfo fileInfo)
};
}

private async Task<DatasetPredictInfo> GetPredictionsOrDefault(TrainedModelFileInfo modelFileInfo, DirectoryInfo predictionsDir)
private async Task<DatasetPredictInfo> GetPredictionsOrDefault(TrainedModelFileInfo modelFileInfo, DirectoryInfo predictionsDir, YoloModelDescription yoloModelDescription)
{
var modelDirectory = GetModelPredictionsFolder(modelFileInfo);
if (!modelDirectory.Exists)
Expand All @@ -87,7 +86,7 @@ private async Task<DatasetPredictInfo> GetPredictionsOrDefault(TrainedModelFileI
{
return null;
}
var predictions = ParsePredictions(predictionsDir).ToArray();
var predictions = ParsePredictions(predictionsDir, yoloModelDescription).ToArray();
return new DatasetPredictInfo()
{
OutputDirectory = modelDirectory,
Expand All @@ -108,42 +107,56 @@ public async Task<DatasetPredictInfo> Predict(
throw new DirectoryNotFoundException($"Model file not found @ {modelFile.FullName}");
}

Log.Info($"Loading model from file {modelFile.FullName}");
var modelData = await File.ReadAllBytesAsync(modelFile.FullName, cancellationToken);
var modelOptions = new SessionOptions();
using var yoloModel = new YoloModel(modelData, modelOptions);
Log.Info($"Loaded model from file {modelFile.FullName}: {yoloModel.Description.Dump()}");

var modelDirectory = GetModelPredictionsFolder(modelFileInfo);
if (modelDirectory.Exists)
{
modelDirectory.Delete(recursive: true);
}

modelDirectory.Create();

var predictDirectory = await cliWrapper.Predict(new Yolo8PredictArguments()
{
Model = modelFileInfo.ModelFile.FullName,
WorkingDirectory = modelDirectory,
Source = inputDirectory.FullName,
Confidence = ConfidenceThresholdPercentage / 100,
IoU = IoUThresholdPercentage / 100,
ImageSize = yoloModel.Description.Size.Width.ToString(),
AdditionalArguments = PredictAdditionalArguments,
}, updateHandler: updateHandler, cancellationToken: cancellationToken);

return await GetPredictionsOrDefault(modelFileInfo, predictDirectory);
return await GetPredictionsOrDefault(modelFileInfo, predictDirectory, yoloModel.Description);
}

private static IEnumerable<PredictInfo> ParsePredictions(
DirectoryInfo predictDirectory)
DirectoryInfo predictDirectory,
YoloModelDescription modelDescription)
{
var result = new List<PredictInfo>();
var labelsDirectory = new DirectoryInfo(Path.Combine(predictDirectory.FullName, "labels"));
if (!labelsDirectory.Exists)
{
return Array.Empty<PredictInfo>();
}

if (modelDescription.Labels.IsEmpty())
{
throw new ArgumentException($"Model does not contain valid labels");
}
var labelsByClassIdx = modelDescription.Labels.ToDictionary(x => x.Id, x => x);

var predictImages = predictDirectory.GetFiles("*.*", SearchOption.TopDirectoryOnly);
foreach (var predictImage in predictImages)
{
var imageSize = ImageUtils.GetImageSize(predictImage);
var labels = new List<YoloLabel>();
var labels = new List<YoloPredictionInfo>();
var labelFileName = new FileInfo(Path.Combine(labelsDirectory.FullName, Path.ChangeExtension(predictImage.Name, "txt")));
if (labelFileName.Exists)
{
Expand All @@ -162,15 +175,28 @@ private static IEnumerable<PredictInfo> ParsePredictions(
var prediction = new PredictInfo()
{
File = predictImage,
Labels = labels.ToArray()
Labels = labels.Select(x =>
{
if (!labelsByClassIdx.TryGetValue(x.ClassIdx, out var label))
{
throw new ArgumentException($"Failed to map class Idx {x.ClassIdx}, to label, known labels: {labelsByClassIdx.DumpToString()}");
}

return new YoloPrediction()
{
BoundingBox = x.BoundingBox,
Score = x.Score,
Label = label
};
}).ToArray()
};
result.Add(prediction);
}

return result;
}

private static IEnumerable<YoloLabel> ParseYoloLabels(FileInfo file)
private static IEnumerable<YoloPredictionInfo> ParseYoloLabels(FileInfo file)
{
using var reader = file.OpenText();

Expand All @@ -191,11 +217,11 @@ private static IEnumerable<YoloLabel> ParseYoloLabels(FileInfo file)
var height = float.Parse(parts[4], CultureInfo.InvariantCulture);
var confidence = float.Parse(parts[5], CultureInfo.InvariantCulture);

yield return new YoloLabel
yield return new YoloPredictionInfo
{
Id = id,
ClassIdx = id,
BoundingBox = RectangleD.FromYolo(centerX, centerY, width, height),
Confidence = confidence
Score = confidence
};
}
}
Expand Down
4 changes: 3 additions & 1 deletion Sources/YoloEase.UI/Dto/PredictInfo.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using YoloEase.UI.Yolo;

namespace YoloEase.UI.Dto;

public sealed record PredictInfo
{
public FileInfo File { get; init; }

public YoloLabel[] Labels { get; init; } = Array.Empty<YoloLabel>();
public YoloPrediction[] Labels { get; init; } = Array.Empty<YoloPrediction>();
}
8 changes: 0 additions & 8 deletions Sources/YoloEase.UI/Dto/YoloLabel.cs

This file was deleted.

11 changes: 10 additions & 1 deletion Sources/YoloEase.UI/Properties/GlobalUsings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,13 @@
global using System;

global using FileInfo = System.IO.FileInfo;
global using Type = System.Type;
global using Type = System.Type;
global using WinSize = System.Drawing.Size;
global using WinSizeF = System.Drawing.SizeF;
global using WinPoint = System.Drawing.Point;
global using WinPointF = System.Drawing.PointF;
global using WinRect = System.Drawing.Rectangle;
global using WinRectangle = System.Drawing.Rectangle;
global using WinRectangleF = System.Drawing.RectangleF;
global using WpfColor = System.Windows.Media.Color;
global using WinColor = System.Drawing.Color;
7 changes: 4 additions & 3 deletions Sources/YoloEase.UI/TrainingTimeline/AutomaticTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using PoeShared.Common;
using YoloEase.UI.Core;
using YoloEase.UI.Dto;
using YoloEase.UI.Yolo;

namespace YoloEase.UI.TrainingTimeline;

Expand All @@ -19,7 +20,7 @@ public class AutomaticTrainer : RefreshableReactiveObject, ICanBeSelected
private readonly IFactory<PreTrainingTimelineEntry, TimelineController, DatasetInfo, Yolo8DatasetAccessor> preTrainingEntryFactory;
private readonly IFactory<TrainingTimelineEntry, TimelineController, DatasetInfo, Yolo8DatasetAccessor, Yolo8PredictAccessor> trainingEntryFactory;
private readonly IFactory<PredictTimelineEntry, TimelineController, TrainedModelFileInfo, DirectoryInfo, Yolo8DatasetAccessor, Yolo8PredictAccessor> predictEntryFactory;
private readonly CircularSourceList<TimelineEntry> timelineSource = new(200);
private readonly CircularSourceList<TimelineEntry> timelineSource = new(100);
private readonly TimelineController timelineController;
private CancellationTokenSource activeTrainingCancellationTokenSource;

Expand Down Expand Up @@ -112,7 +113,7 @@ private void RecalculateAutoAnnotationStats(float confidence, DatasetPredictInfo
var predictions = datasetPredictions.Predictions
.Select(x => x with
{
Labels = x.Labels.EmptyIfNull().Where(y => y.Confidence >= confidence).ToArray()
Labels = x.Labels.EmptyIfNull().Where(y => y.Score >= confidence).ToArray()
})
.Where(x => x.Labels.Any())
.ToDictionary(x => x.File.Name);
Expand All @@ -121,7 +122,7 @@ private void RecalculateAutoAnnotationStats(float confidence, DatasetPredictInfo
.Select(x =>
{
var predictionsForFile = predictions.GetValueOrDefault(x.Name);
return new {File = x, Labels = predictionsForFile?.Labels ?? Array.Empty<YoloLabel>(), Score = predictionsForFile?.Labels.Max(y => y.Confidence)};
return new {File = x, Labels = predictionsForFile?.Labels ?? Array.Empty<YoloPrediction>(), Score = predictionsForFile?.Labels.Max(y => y.Score)};
})
.OrderByDescending(x => x.Score)
.ToArray();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
}
else
{
var labels = predictions.Predictions.SelectMany(x => x.Labels).Select(x => x.Confidence ?? 0).ToArray();
var labels = predictions.Predictions.SelectMany(x => x.Labels).Select(x => x.Score).ToArray();
<div class="d-flex gap-1">
<div class="badge bg-success">
@($"prediction".ToQuantity(predictions.Predictions.Length))
Expand Down
26 changes: 8 additions & 18 deletions Sources/YoloEase.UI/TrainingTimeline/CreateTaskTimelineEntry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Humanizer;
using YoloEase.UI.Core;
using YoloEase.UI.Dto;
using YoloEase.UI.Yolo;

namespace YoloEase.UI.TrainingTimeline;

Expand Down Expand Up @@ -72,7 +73,7 @@ private async Task<FileInfo[]> PickAnnotated()
var predictions = this.DatasetPredictions.Predictions
.Select(x => x with
{
Labels = x.Labels.EmptyIfNull().Where(y => y.Confidence >= AutoAnnotateConfidenceThreshold).ToArray()
Labels = x.Labels.EmptyIfNull().Where(y => y.Score >= AutoAnnotateConfidenceThreshold).ToArray()
})
.Where(x => x.Labels.Any())
.ToDictionary(x => x.File.Name);
Expand All @@ -81,7 +82,7 @@ private async Task<FileInfo[]> PickAnnotated()
.Select(x =>
{
var predictionsForFile = predictions.GetValueOrDefault(x.Name);
return new { File = x, Labels = predictionsForFile?.Labels ?? Array.Empty<YoloLabel>(), Score = predictionsForFile?.Labels.Max(y => y.Confidence)};
return new { File = x, Labels = predictionsForFile?.Labels ?? Array.Empty<YoloPrediction>(), Score = predictionsForFile?.Labels.Max(y => y.Score)};
})
.OrderByDescending(x => x.Score)
.ToArray();
Expand Down Expand Up @@ -111,41 +112,30 @@ private async Task<AnnotationsRead> UploadAnnotations(FileInfo[] files, TaskRead

var projectLabelsByName = cvatProjectAccessor.Labels.Items
.ToDictionary(x => x.Name, x => x);

var yoloLabelsById = projectLabelsByName
.OrderBy(x => x.Value.Id)
.Select(x => x.Value.Name)
.Select((labelName, idx) => new {x = labelName, idx})
.ToDictionary(x => x.idx, x => x.x);

var labels = files
.Select(x => predictions.GetValueOrDefault(x.Name))
.Where(x => x != null)
.Select(x => x with
{
Labels = x.Labels.EmptyIfNull().Where(y => y.Confidence >= AutoAnnotateConfidenceThreshold).ToArray()
Labels = x.Labels.EmptyIfNull().Where(y => y.Score >= AutoAnnotateConfidenceThreshold).ToArray()
})
.Where(x => x.Labels.Any())
.Select(x => x.Labels.Select(label =>
.Select(x => x.Labels.Select(prediction =>
{
if (!taskFramesByFileName.TryGetValue(x.File.Name, out var taskFrame))
{
throw new InvalidStateException($"Failed to resolve frame using name {x.File.Name}");
}

if (!yoloLabelsById.TryGetValue(label.Id, out var labelName))
if (!projectLabelsByName.TryGetValue(prediction.Label.Name, out var cvatLabel))
{
throw new InvalidStateException($"Failed to resolve Yolo label using Id {label.Id}, known labels: {yoloLabelsById.DumpToString()}");
}

if (!projectLabelsByName.TryGetValue(labelName, out var cvatLabel))
{
throw new InvalidStateException($"Failed to resolve CVAT label using Name {labelName}, known labels: {projectLabelsByName.DumpToString()}");
throw new InvalidStateException($"Failed to resolve CVAT label using Name {prediction.Label.Name}, known labels: {projectLabelsByName.DumpToString()}");
}

return new CvatRectangleAnnotation()
{
BoundingBox = label.BoundingBox,
BoundingBox = prediction.BoundingBox,
LabelId = cvatLabel.Id.Value,
FrameIndex = taskFrame.FrameIdx
};
Expand Down
4 changes: 2 additions & 2 deletions Sources/YoloEase.UI/TrainingTimeline/PredictTimelineEntry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ protected override async Task<DatasetPredictInfo> RunInternal(CancellationToken
var labelsById = predictionResults.Predictions
.Select(x => x.Labels)
.SelectMany(x => x)
.GroupBy(x => x.Id)
.Select(x => new { Id = x.Key, Count = x.Count(), AvgConfidence = x.Average(y => y.Confidence) })
.GroupBy(x => x.Label.Name)
.Select(x => new { Id = x.Key, Count = x.Count(), AvgConfidence = x.Average(y => y.Score) })
.ToArray();

Text = $"Prediction completed in {sw.Elapsed.Humanize(culture: CultureInfo.InvariantCulture)}, images: {predictionResults.Predictions.Length}, labels: {labelsById.Select(x => $"Id: {x.Id}, Count: {x.Count}, AvgConf: {x.AvgConfidence}").DumpToString()}";
Expand Down
30 changes: 30 additions & 0 deletions Sources/YoloEase.UI/Yolo/YoloLabel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
namespace YoloEase.UI.Yolo;

/// <summary>
/// Represents a single label as identified by the YOLO (You Only Look Once) object detection system.
/// This record structure encapsulates the unique attributes of a detected object's label,
/// including its identifier, name, kind, and associated color.
/// </summary>
public readonly record struct YoloLabel
{
/// <summary>
/// Gets the unique identifier for this label.
/// The ID is an integer value that uniquely represents a specific class or type of object
/// detected by the YOLO system.
/// </summary>
public int Id { get; init; }

/// <summary>
/// Gets the name of the label.
/// This is a human-readable string that describes the class or type of the object detected,
/// such as 'car', 'person', etc.
/// </summary>
public string Name { get; init; }

/// <summary>
/// Gets the color associated with this label.
/// This color is typically used for visualization purposes, such as drawing bounding boxes
/// or segmentation masks in the color corresponding to the detected object's class.
/// </summary>
public WinColor Color { get; init; }
}
Loading

0 comments on commit 9911389

Please sign in to comment.