From 772082f4d77b8b0bc1baf62c65372224b3cb396e Mon Sep 17 00:00:00 2001 From: Russ Cam Date: Wed, 20 Nov 2024 16:32:35 +1000 Subject: [PATCH] Support CLI cancellation --- src/RankLib.Cli/EvaluateCommand.cs | 12 ++- src/RankLib/Eval/Evaluator.cs | 77 +++++++++++-------- src/RankLib/Learning/Boosting/AdaRank.cs | 20 +++-- src/RankLib/Learning/Boosting/RankBoost.cs | 15 +++- src/RankLib/Learning/CoordinateAscent.cs | 9 ++- src/RankLib/Learning/IRanker.cs | 9 ++- src/RankLib/Learning/LinearRegression.cs | 9 ++- src/RankLib/Learning/NeuralNet/ListNet.cs | 10 ++- src/RankLib/Learning/NeuralNet/RankNet.cs | 12 ++- src/RankLib/Learning/Ranker.cs | 18 ++++- src/RankLib/Learning/RankerTrainer.cs | 16 ++-- src/RankLib/Learning/Tree/FeatureHistogram.cs | 49 ++++++------ src/RankLib/Learning/Tree/LambdaMART.cs | 38 +++++---- src/RankLib/Learning/Tree/MART.cs | 2 +- src/RankLib/Learning/Tree/RandomForests.cs | 14 +++- src/RankLib/Learning/Tree/RegressionTree.cs | 6 +- src/RankLib/Learning/Tree/Split.cs | 4 +- 17 files changed, 208 insertions(+), 112 deletions(-) diff --git a/src/RankLib.Cli/EvaluateCommand.cs b/src/RankLib.Cli/EvaluateCommand.cs index f66ecf3..b3bee37 100644 --- a/src/RankLib.Cli/EvaluateCommand.cs +++ b/src/RankLib.Cli/EvaluateCommand.cs @@ -468,7 +468,8 @@ await evaluator.EvaluateAsync( tvSplit, kcvModelDir!.FullName, kcvModelFile!, - rankerParameters).ConfigureAwait(false); + rankerParameters, + cancellationToken).ConfigureAwait(false); } else { @@ -481,7 +482,8 @@ await evaluator.EvaluateAsync( featureDescriptionFile?.FullName, ttSplit, options.ModelOutputFile?.FullName, - rankerParameters).ConfigureAwait(false); + rankerParameters, + cancellationToken).ConfigureAwait(false); } else if (tvSplit > 0.0) { @@ -492,7 +494,8 @@ await evaluator.EvaluateAsync( testFiles.LastOrDefault(), featureDescriptionFile?.FullName, options.ModelOutputFile?.FullName, - rankerParameters).ConfigureAwait(false); + rankerParameters, + cancellationToken).ConfigureAwait(false); } else { @@ -503,7 +506,8 @@ await evaluator.EvaluateAsync( testFiles.LastOrDefault(), featureDescriptionFile?.FullName, options.ModelOutputFile?.FullName, - rankerParameters).ConfigureAwait(false); + rankerParameters, + cancellationToken).ConfigureAwait(false); } } } diff --git a/src/RankLib/Eval/Evaluator.cs b/src/RankLib/Eval/Evaluator.cs index 4362564..d57afab 100644 --- a/src/RankLib/Eval/Evaluator.cs +++ b/src/RankLib/Eval/Evaluator.cs @@ -145,6 +145,7 @@ public double Evaluate(IRanker? ranker, List rankLists) /// The feature definitions /// A path to save the trained ranker to /// The ranker parameters + /// Token that can be used to cancel the operation /// The ranker type is not an public async Task EvaluateAsync( Type rankerType, @@ -153,7 +154,8 @@ public async Task EvaluateAsync( string? testFile, string? featureDefinitionFile, string? modelFile = null, - IRankerParameters? parameters = default) + IRankerParameters? parameters = default, + CancellationToken cancellationToken = default) { if (!typeof(IRanker).IsAssignableFrom(rankerType)) throw new ArgumentException($"Ranker type {rankerType} is not a ranker"); @@ -172,7 +174,7 @@ public async Task EvaluateAsync( Normalize(test, features); } - var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters) + var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters, cancellationToken) .ConfigureAwait(false); if (test != null) @@ -183,7 +185,7 @@ public async Task EvaluateAsync( if (!string.IsNullOrEmpty(modelFile)) { - await ranker.SaveAsync(modelFile); + await ranker.SaveAsync(modelFile, cancellationToken).ConfigureAwait(false); _logger.LogInformation("Model saved to: {ModelFile}", modelFile); } } @@ -198,16 +200,18 @@ public async Task EvaluateAsync( /// The feature definitions /// A path to save the trained ranker to /// The ranker parameters + /// Token that can be used to cancel the operation public Task EvaluateAsync( string trainFile, string? validationFile, string? testFile, string? featureDefinitionFile, string? modelFile = null, - TRankerParameters? parameters = default) + TRankerParameters? parameters = default, + CancellationToken cancellationToken = default) where TRanker : IRanker where TRankerParameters : IRankerParameters => - EvaluateAsync(typeof(TRanker), trainFile, validationFile, testFile, featureDefinitionFile, modelFile, parameters); + EvaluateAsync(typeof(TRanker), trainFile, validationFile, testFile, featureDefinitionFile, modelFile, parameters, cancellationToken); /// /// Evaluates a new instance of specified by and @@ -220,6 +224,7 @@ public Task EvaluateAsync( /// The percentage of to use for training data. /// A path to save the trained ranker to /// The ranker parameters + /// Token that can be used to cancel the operation /// The ranker type is not an public async Task EvaluateAsync( Type rankerType, @@ -228,7 +233,8 @@ public async Task EvaluateAsync( string? featureDefinitionFile, double percentTrain, string? modelFile = null, - IRankerParameters? parameters = default) + IRankerParameters? parameters = default, + CancellationToken cancellationToken = default) { var train = new List(); var test = new List(); @@ -238,7 +244,7 @@ public async Task EvaluateAsync( if (_normalize && validation != null) Normalize(validation, features); - var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters) + var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters, cancellationToken) .ConfigureAwait(false); var rankScore = Evaluate(ranker, test); @@ -246,7 +252,7 @@ public async Task EvaluateAsync( if (!string.IsNullOrEmpty(modelFile)) { - await ranker.SaveAsync(modelFile); + await ranker.SaveAsync(modelFile, cancellationToken); _logger.LogInformation("Model saved to: {ModelFile}", modelFile); } } @@ -261,16 +267,18 @@ public async Task EvaluateAsync( /// The percentage of to use for training data. /// A path to save the trained ranker to /// The ranker parameters + /// Token that can be used to cancel the operation public Task EvaluateAsync( string sampleFile, string? validationFile, string featureDefinitionFile, double percentTrain, string? modelFile = null, - TRankerParameters? parameters = default) + TRankerParameters? parameters = default, + CancellationToken cancellationToken = default) where TRanker : IRanker where TRankerParameters : IRankerParameters => - EvaluateAsync(typeof(TRanker), sampleFile, validationFile, featureDefinitionFile, percentTrain, modelFile, parameters); + EvaluateAsync(typeof(TRanker), sampleFile, validationFile, featureDefinitionFile, percentTrain, modelFile, parameters, cancellationToken); /// /// Evaluates a new instance of specified by and @@ -283,6 +291,7 @@ public Task EvaluateAsync( /// The feature definitions /// A path to save the trained ranker to /// The ranker parameters + /// Token that can be used to cancel the operation /// The ranker type is not an public async Task EvaluateAsync( Type rankerType, @@ -291,7 +300,8 @@ public async Task EvaluateAsync( string? testFile, string? featureDefinitionFile, string? modelFile = null, - IRankerParameters? parameters = default) + IRankerParameters? parameters = default, + CancellationToken cancellationToken = default) { var train = new List(); var validation = new List(); @@ -301,7 +311,7 @@ public async Task EvaluateAsync( if (_normalize && test != null) Normalize(test, features); - var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters) + var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters, cancellationToken) .ConfigureAwait(false); if (test != null) @@ -312,7 +322,7 @@ public async Task EvaluateAsync( if (!string.IsNullOrEmpty(modelFile)) { - await ranker.SaveAsync(modelFile); + await ranker.SaveAsync(modelFile, cancellationToken).ConfigureAwait(false); _logger.LogInformation("Model saved to: {ModelFile}", modelFile); } } @@ -327,16 +337,18 @@ public async Task EvaluateAsync( /// The feature definitions /// A path to save the trained ranker to /// The ranker parameters + /// Token that can be used to cancel the operation public Task EvaluateAsync( string trainFile, double percentTrain, string? testFile, string featureDefinitionFile, string? modelFile = null, - TRankerParameters? parameters = default) + TRankerParameters? parameters = default, + CancellationToken cancellationToken = default) where TRanker : IRanker where TRankerParameters : IRankerParameters => - EvaluateAsync(typeof(TRanker), trainFile, percentTrain, testFile, featureDefinitionFile, modelFile, parameters); + EvaluateAsync(typeof(TRanker), trainFile, percentTrain, testFile, featureDefinitionFile, modelFile, parameters, cancellationToken); /// /// Evaluates a new instance of specified by and @@ -348,16 +360,18 @@ public Task EvaluateAsync( /// The directory to save trained ranker models /// The name prefix of trained ranker models /// The ranker parameters + /// Token that can be used to cancel the operation public Task EvaluateAsync( string sampleFile, string featureDefinitionFile, int foldCount, string modelDir, string modelFile, - TRankerParameters? parameters = default) + TRankerParameters? parameters = default, + CancellationToken cancellationToken = default) where TRanker : IRanker where TRankerParameters : IRankerParameters => - EvaluateAsync(typeof(TRanker), sampleFile, featureDefinitionFile, foldCount, -1, modelDir, modelFile, parameters); + EvaluateAsync(typeof(TRanker), sampleFile, featureDefinitionFile, foldCount, -1, modelDir, modelFile, parameters, cancellationToken); /// /// Evaluates a new instance of specified by and @@ -371,6 +385,7 @@ public Task EvaluateAsync( /// The directory to save trained ranker models /// The name prefix of trained ranker models /// The ranker parameters + /// Token that can be used to cancel the operation public async Task EvaluateAsync( Type rankerType, string sampleFile, @@ -379,7 +394,8 @@ public async Task EvaluateAsync( float trainValidationSplit, string modelDir, string modelFile, - IRankerParameters? parameters = default) + IRankerParameters? parameters = default, + CancellationToken cancellationToken = default) { var trainingData = new List>(); var validationData = new List>(); @@ -412,7 +428,7 @@ public async Task EvaluateAsync( var validation = trainValidationSplit > 0 ? validationData[i] : null; var test = testData[i]; - var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters) + var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters, cancellationToken) .ConfigureAwait(false); var testScore = Evaluate(ranker, test); @@ -427,7 +443,7 @@ public async Task EvaluateAsync( if (!string.IsNullOrEmpty(modelDir)) { var foldModelFile = Path.Combine(modelDir, $"f{i + 1}.{modelFile}"); - await ranker.SaveAsync(foldModelFile).ConfigureAwait(false); + await ranker.SaveAsync(foldModelFile, cancellationToken).ConfigureAwait(false); _logger.LogInformation("Fold-{Fold} model saved to: {FoldModelFile}", i + 1, foldModelFile); } } @@ -453,6 +469,7 @@ public async Task EvaluateAsync( /// The directory to save trained ranker models /// The name prefix of trained ranker models /// The ranker parameters + /// Token that can be used to cancel the operation public Task EvaluateAsync( string sampleFile, string? featureDefinitionFile, @@ -460,10 +477,11 @@ public Task EvaluateAsync( float trainValidationSplit, string modelDir, string modelFile, - TRankerParameters? parameters = default) + TRankerParameters? parameters = default, + CancellationToken cancellationToken = default) where TRanker : IRanker where TRankerParameters : IRankerParameters => - EvaluateAsync(typeof(TRanker), sampleFile, featureDefinitionFile, foldCount, trainValidationSplit, modelDir, modelFile, parameters); + EvaluateAsync(typeof(TRanker), sampleFile, featureDefinitionFile, foldCount, trainValidationSplit, modelDir, modelFile, parameters, cancellationToken); public void Test(string testFile) { @@ -776,8 +794,7 @@ public void Rank(string modelFile, string testFile, string indriRankingFile) for (var j = 0; j < idx.Length; j++) { var k = idx[j]; - var str = $"{l.Id} Q0 {l[k].Description.Replace("#", "").Trim()} {j + 1} {SimpleMath.Round(scores[k], 5)} indri"; - outWriter.WriteLine(str); + outWriter.WriteLine($"{l.Id} Q0 {l[k].Description.AsSpan().Trim("#").Trim().ToString()} {j + 1} {SimpleMath.Round(scores[k], 5)} indri"); } } } @@ -797,10 +814,7 @@ public void Rank(string testFile, string indriRankingFile) foreach (var l in test) { for (var j = 0; j < l.Count; j++) - { - var str = $"{l.Id} Q0 {l[j].Description.Replace("#", "").Trim()} {j + 1} {SimpleMath.Round(1.0 - 0.0001 * j, 5)} indri"; - outWriter.WriteLine(str); - } + outWriter.WriteLine($"{l.Id} Q0 {l[j].Description.AsSpan().Trim("#").Trim().ToString()} {j + 1} {SimpleMath.Round(1.0 - 0.0001 * j, 5)} indri"); } } catch (IOException ex) @@ -848,7 +862,7 @@ public void Rank(List modelFiles, string testFile, string indriRankingFi } catch (IOException ex) { - throw RankLibException.Create("Error in Evaluator::Rank(): ", ex); + throw RankLibException.Create("Error ranking and writing indri ranking file", ex); } } @@ -878,15 +892,14 @@ public void Rank(List modelFiles, List testFiles, string indriRa for (var j = 0; j < idx.Length; j++) { var k = idx[j]; - var str = $"{l.Id} Q0 {l[k].Description.Replace("#", "").Trim()} {j + 1} {SimpleMath.Round(scores[k], 5)} indri"; - outWriter.WriteLine(str); + outWriter.WriteLine($"{l.Id} Q0 {l[k].Description.AsSpan().Trim("#").Trim().ToString()} {j + 1} {SimpleMath.Round(scores[k], 5)} indri"); } } } } catch (IOException ex) { - throw RankLibException.Create("Error in Evaluator::Rank(): ", ex); + throw RankLibException.Create("Error ranking and writing indri ranking file", ex); } } diff --git a/src/RankLib/Learning/Boosting/AdaRank.cs b/src/RankLib/Learning/Boosting/AdaRank.cs index 4c7b9b6..ea99bdd 100644 --- a/src/RankLib/Learning/Boosting/AdaRank.cs +++ b/src/RankLib/Learning/Boosting/AdaRank.cs @@ -142,12 +142,14 @@ private void UpdateBestModelOnValidation() return bestWeakRanker; } - private int Learn(int startIteration, bool withEnqueue) + private int Learn(int startIteration, bool withEnqueue, CancellationToken cancellationToken) { var t = startIteration; var bufferedLogger = new BufferedLogger(_logger, new StringBuilder()); for (; t <= Parameters.IterationCount; t++) { + CheckCancellation(_logger, cancellationToken); + bufferedLogger.PrintLog([7], [t.ToString()]); var bestWeakRanker = LearnWeakRanker(); @@ -266,7 +268,7 @@ private int Learn(int startIteration, bool withEnqueue) return t; } - public override Task InitAsync() + public override Task InitAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Initializing..."); _usedFeatures.Clear(); @@ -289,22 +291,24 @@ public override Task InitAsync() return Task.CompletedTask; } - public override Task LearnAsync() + public override Task LearnAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Training starts..."); _logger.PrintLog([7, 8, 9, 9, 9], ["#iter", "Sel. F.", Scorer.Name + "-T", Scorer.Name + "-V", "Status"]); + CheckCancellation(_logger, cancellationToken); + if (Parameters.TrainWithEnqueue) { - var t = Learn(1, true); + var t = Learn(1, true, cancellationToken); for (var i = _featureQueue.Count - 1; i >= 0; i--) { _featureQueue.RemoveAt(i); - t = Learn(t, false); + t = Learn(t, false, cancellationToken); } } else - Learn(1, false); + Learn(1, false, cancellationToken); if (ValidationSamples != null && _bestModelRankers.Count > 0) { @@ -318,6 +322,8 @@ public override Task LearnAsync() _logger.LogInformation("Finished successfully."); _logger.LogInformation($"{Scorer.Name} on training data: {TrainingDataScore}"); + CheckCancellation(_logger, cancellationToken); + if (ValidationSamples != null) { ValidationDataScore = Scorer.Score(Rank(ValidationSamples)); @@ -369,7 +375,7 @@ public override void LoadFromString(string model) } if (kvp == null) - throw new InvalidOperationException("Error in AdaRank::LoadFromString: Unable to load model"); + throw new InvalidOperationException("No model data found."); _rankerWeights = new List(); _rankers = new List(); diff --git a/src/RankLib/Learning/Boosting/RankBoost.cs b/src/RankLib/Learning/Boosting/RankBoost.cs index 09da1a8..c9f3c8b 100644 --- a/src/RankLib/Learning/Boosting/RankBoost.cs +++ b/src/RankLib/Learning/Boosting/RankBoost.cs @@ -203,7 +203,7 @@ private void UpdateSampleWeights(double alphaT) } } - public override Task InitAsync() + public override Task InitAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Initializing..."); @@ -212,6 +212,8 @@ public override Task InitAsync() _totalCorrectPairs = 0; for (var i = 0; i < Samples.Count; i++) { + CheckCancellation(_logger, cancellationToken); + Samples[i] = Samples[i].GetCorrectRanking(); // Ensure training samples are correctly ranked var rl = Samples[i]; for (var j = 0; j < rl.Count - 1; j++) @@ -224,6 +226,8 @@ public override Task InitAsync() _sweight = new double[Samples.Count][][]; for (var i = 0; i < Samples.Count; i++) { + CheckCancellation(_logger, cancellationToken); + var rl = Samples[i]; _sweight[i] = new double[rl.Count][]; for (var j = 0; j < rl.Count - 1; j++) @@ -238,6 +242,8 @@ public override Task InitAsync() for (var i = 0; i < Samples.Count; i++) _potential[i] = new double[Samples[i].Count]; + CheckCancellation(_logger, cancellationToken); + if (Parameters.Threshold <= 0) { var count = 0; @@ -304,6 +310,8 @@ public override Task InitAsync() for (var i = 0; i < Features.Length; i++) _tSortedIdx[i] = MergeSorter.Sort(_thresholds[i], false); + CheckCancellation(_logger, cancellationToken); + for (var i = 0; i < Features.Length; i++) { var idx = new List(); @@ -316,7 +324,7 @@ public override Task InitAsync() return Task.CompletedTask; } - public override Task LearnAsync() + public override Task LearnAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Training starts..."); _logger.PrintLog([7, 8, 9, 9, 9, 9], ["#iter", @@ -331,6 +339,7 @@ public override Task LearnAsync() for (var t = 1; t <= Parameters.IterationCount; t++) { + CheckCancellation(_logger, cancellationToken); UpdatePotential(); var wr = LearnWeakRanker(); if (wr == null) @@ -370,6 +379,8 @@ public override Task LearnAsync() bufferedLogger.FlushLog(); } + CheckCancellation(_logger, cancellationToken); + if (ValidationSamples != null && _bestModelRankers.Count > 0) { _wRankers.Clear(); diff --git a/src/RankLib/Learning/CoordinateAscent.cs b/src/RankLib/Learning/CoordinateAscent.cs index 085497b..dc4327a 100644 --- a/src/RankLib/Learning/CoordinateAscent.cs +++ b/src/RankLib/Learning/CoordinateAscent.cs @@ -115,7 +115,7 @@ public CoordinateAscent(List samples, int[] features, MetricScorer sco : base(samples, features, scorer) => _logger = logger ?? NullLogger.Instance; - public override Task InitAsync() + public override Task InitAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Initializing..."); Weight = new double[Features.Length]; @@ -123,7 +123,7 @@ public override Task InitAsync() return Task.CompletedTask; } - public override Task LearnAsync() + public override Task LearnAsync(CancellationToken cancellationToken = default) { var regVector = new double[Weight.Length]; Array.Copy(Weight, regVector, Weight.Length); // Uniform weight distribution @@ -136,6 +136,8 @@ public override Task LearnAsync() for (var r = 0; r < Parameters.RandomRestartCount; r++) { + CheckCancellation(_logger, cancellationToken); + _logger.LogInformation($"[+] Random restart #{r + 1}/{Parameters.RandomRestartCount}..."); var consecutiveFails = 0; @@ -151,6 +153,8 @@ public override Task LearnAsync() while ((Weight.Length > 1 && consecutiveFails < Weight.Length - 1) || (Weight.Length == 1 && consecutiveFails == 0)) { + CheckCancellation(_logger, cancellationToken); + _logger.LogInformation("Shuffling features' order..."); _logger.LogInformation("Optimizing weight vector... "); _logger.PrintLog([7, 8, 7], ["Feature", "weight", Scorer.Name]); @@ -255,6 +259,7 @@ public override Task LearnAsync() } } + CheckCancellation(_logger, cancellationToken); Array.Copy(bestModel!, Weight, bestModel!.Length); _currentFeature = -1; TrainingDataScore = Math.Round(Scorer.Score(Rank(Samples)), 4); diff --git a/src/RankLib/Learning/IRanker.cs b/src/RankLib/Learning/IRanker.cs index c19c774..22a85eb 100644 --- a/src/RankLib/Learning/IRanker.cs +++ b/src/RankLib/Learning/IRanker.cs @@ -46,14 +46,16 @@ public interface IRanker /// /// Initializes the ranker for training. /// + /// Token used to cancel the operation /// a new instance of that can be awaited. - Task InitAsync(); + Task InitAsync(CancellationToken cancellationToken = default); /// /// Trains the ranker to learn from the training samples. /// + /// Token used to cancel the operation /// a new instance of that can be awaited. - Task LearnAsync(); + Task LearnAsync(CancellationToken cancellationToken = default); /// /// Evaluates a data point. @@ -108,8 +110,9 @@ public interface IRanker /// Saves the model to file. /// /// The file path to save the model to. + /// Token used to cancel the operation /// a new instance of that can be awaited. - Task SaveAsync(string modelFile); + Task SaveAsync(string modelFile, CancellationToken cancellationToken = default); /// /// Gets the score from evaluation on the training data. diff --git a/src/RankLib/Learning/LinearRegression.cs b/src/RankLib/Learning/LinearRegression.cs index 41611c6..91355d4 100644 --- a/src/RankLib/Learning/LinearRegression.cs +++ b/src/RankLib/Learning/LinearRegression.cs @@ -47,15 +47,17 @@ public LinearRegression(List samples, int[] features, MetricScorer sco /// public override string Name => RankerName; + /// /// - public override Task InitAsync() + public override Task InitAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Initializing..."); return Task.CompletedTask; } + /// /// - public override Task LearnAsync() + public override Task LearnAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Training starts..."); _logger.LogInformation("Learning the least square model..."); @@ -81,6 +83,8 @@ public override Task LearnAsync() for (var s = 0; s < Samples.Count; s++) { + CheckCancellation(_logger, cancellationToken); + var rl = Samples[s]; for (var i = 0; i < rl.Count; i++) { @@ -109,6 +113,7 @@ public override Task LearnAsync() xTx[i][i] += Parameters.Lambda; } + CheckCancellation(_logger, cancellationToken); _weight = Solve(xTx, xTy); TrainingDataScore = SimpleMath.Round(Scorer.Score(Rank(Samples)), 4); diff --git a/src/RankLib/Learning/NeuralNet/ListNet.cs b/src/RankLib/Learning/NeuralNet/ListNet.cs index 842d146..2cf5ef2 100644 --- a/src/RankLib/Learning/NeuralNet/ListNet.cs +++ b/src/RankLib/Learning/NeuralNet/ListNet.cs @@ -95,7 +95,7 @@ protected override void EstimateLoss() _lastError = _error; } - public override Task InitAsync() + public override Task InitAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Initializing..."); @@ -106,13 +106,13 @@ public override Task InitAsync() if (ValidationSamples != null) { for (var i = 0; i < _layers.Count; i++) - _bestModelOnValidation.Add(new List()); + _bestModelOnValidation.Add([]); } return Task.CompletedTask; } - public override Task LearnAsync() + public override Task LearnAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Training starts..."); _logger.PrintLog([7, 14, 9, 9], ["#epoch", "C.E. Loss", Scorer.Name + "-T", Scorer.Name + "-V"]); @@ -121,6 +121,8 @@ public override Task LearnAsync() for (var i = 1; i <= Parameters.IterationCount; i++) { + CheckCancellation(_logger, cancellationToken); + for (var j = 0; j < Samples.Count; j++) { var labels = FeedForward(Samples[j]); @@ -147,6 +149,8 @@ public override Task LearnAsync() bufferedLogger.FlushLog(); } + CheckCancellation(_logger, cancellationToken); + // Restore the best model if validation data is used if (ValidationSamples != null) RestoreBestModelOnValidation(); diff --git a/src/RankLib/Learning/NeuralNet/RankNet.cs b/src/RankLib/Learning/NeuralNet/RankNet.cs index 96b7312..46113a5 100644 --- a/src/RankLib/Learning/NeuralNet/RankNet.cs +++ b/src/RankLib/Learning/NeuralNet/RankNet.cs @@ -291,16 +291,20 @@ protected virtual void EstimateLoss() _lastError = _error; } - public override Task InitAsync() + public override Task InitAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Initializing..."); + CheckCancellation(_logger, cancellationToken); + SetInputOutput(Features.Length, 1); for (var i = 0; i < Parameters.HiddenLayerCount; i++) AddHiddenLayer(Parameters.HiddenNodePerLayerCount); Wire(); + CheckCancellation(_logger, cancellationToken); + _totalPairs = 0; foreach (var rl in Samples) { @@ -318,7 +322,7 @@ public override Task InitAsync() return Task.CompletedTask; } - public override Task LearnAsync() + public override Task LearnAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Training starts..."); _logger.PrintLog([7, 14, 9, 9], @@ -329,6 +333,8 @@ public override Task LearnAsync() for (var i = 1; i <= Parameters.IterationCount; i++) { + CheckCancellation(_logger, cancellationToken); + for (var j = 0; j < Samples.Count; j++) { var rl = InternalReorder(Samples[j]); @@ -360,6 +366,8 @@ public override Task LearnAsync() bufferedLogger.FlushLog(); } + CheckCancellation(_logger, cancellationToken); + // Restore the best model if validation data was specified if (ValidationSamples != null) RestoreBestModelOnValidation(); diff --git a/src/RankLib/Learning/Ranker.cs b/src/RankLib/Learning/Ranker.cs index 1e66a60..541ea09 100644 --- a/src/RankLib/Learning/Ranker.cs +++ b/src/RankLib/Learning/Ranker.cs @@ -1,4 +1,5 @@ using System.Text; +using Microsoft.Extensions.Logging; using RankLib.Eval; using RankLib.Metric; using RankLib.Utilities; @@ -103,18 +104,20 @@ public List Rank(List rankLists) } /// - public async Task SaveAsync(string modelFile) + public async Task SaveAsync(string modelFile, CancellationToken cancellationToken = default) { var directory = Path.GetDirectoryName(Path.GetFullPath(modelFile)); Directory.CreateDirectory(directory!); await File.WriteAllTextAsync(modelFile, GetModel(), Encoding.ASCII); } + /// /// - public abstract Task InitAsync(); + public abstract Task InitAsync(CancellationToken cancellationToken = default); + /// /// - public abstract Task LearnAsync(); + public abstract Task LearnAsync(CancellationToken cancellationToken = default); /// public abstract double Eval(DataPoint dataPoint); @@ -127,4 +130,13 @@ public async Task SaveAsync(string modelFile) /// public abstract string Name { get; } + + internal static void CheckCancellation(ILogger logger, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + logger.LogInformation("The operation was cancelled."); + cancellationToken.ThrowIfCancellationRequested(); + } + } } diff --git a/src/RankLib/Learning/RankerTrainer.cs b/src/RankLib/Learning/RankerTrainer.cs index 2c23da1..72e4ab0 100644 --- a/src/RankLib/Learning/RankerTrainer.cs +++ b/src/RankLib/Learning/RankerTrainer.cs @@ -24,6 +24,7 @@ public class RankerTrainer /// The features /// the scorer used to measure the effectiveness of the ranker /// The ranking parameters + /// Token that can be used to cancel the operation /// A new instance of a trained public async Task TrainAsync( Type rankerType, @@ -31,12 +32,13 @@ public async Task TrainAsync( List? validationSamples, int[] features, MetricScorer scorer, - IRankerParameters? parameters = default) + IRankerParameters? parameters = default, + CancellationToken cancellationToken = default) { var ranker = _rankerFactory.CreateRanker(rankerType, trainingSamples, features, scorer, parameters); ranker.ValidationSamples = validationSamples; - await ranker.InitAsync().ConfigureAwait(false); - await ranker.LearnAsync().ConfigureAwait(false); + await ranker.InitAsync(cancellationToken).ConfigureAwait(false); + await ranker.LearnAsync(cancellationToken).ConfigureAwait(false); return ranker; } @@ -48,6 +50,7 @@ public async Task TrainAsync( /// The features /// the scorer used to measure the effectiveness of the ranker /// The ranking parameters + /// Token that can be used to cancel the operation /// A new instance of a trained /// The type of ranker /// The type of ranker parameters @@ -57,14 +60,15 @@ public async Task TrainAsync( List? validationSamples, int[] features, MetricScorer scorer, - TRankerParameters? parameters = default) + TRankerParameters? parameters = default, + CancellationToken cancellationToken = default) where TRanker : IRanker where TRankerParameters : IRankerParameters { var ranker = _rankerFactory.CreateRanker(trainingSamples, features, scorer, parameters); ranker.ValidationSamples = validationSamples; - await ranker.InitAsync().ConfigureAwait(false); - await ranker.LearnAsync().ConfigureAwait(false); + await ranker.InitAsync(cancellationToken).ConfigureAwait(false); + await ranker.LearnAsync(cancellationToken).ConfigureAwait(false); return ranker; } } diff --git a/src/RankLib/Learning/Tree/FeatureHistogram.cs b/src/RankLib/Learning/Tree/FeatureHistogram.cs index ae63bc5..e394d4d 100644 --- a/src/RankLib/Learning/Tree/FeatureHistogram.cs +++ b/src/RankLib/Learning/Tree/FeatureHistogram.cs @@ -39,7 +39,7 @@ public FeatureHistogram(float samplingRate, int? maxDegreesOfParallelism = null) _maxDegreesOfParallelism = maxDegreesOfParallelism ?? Environment.ProcessorCount; } - public async Task ConstructAsync(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds, double[] impacts) + public async Task ConstructAsync(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds, double[] impacts, CancellationToken cancellationToken = default) { _features = features; _thresholds = thresholds; @@ -58,17 +58,17 @@ public async Task ConstructAsync(DataPoint[] samples, double[] labels, int[][] s Partitioner.PartitionEnumerable(_features.Length, _maxDegreesOfParallelism); await Parallel.ForEachAsync( partitions, - new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism }, - async (range, cancellationToken) => + new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism, CancellationToken = cancellationToken }, + async (range, ct) => { await Task.Run( - () => Construct(samples, labels, sampleSortedIdx, thresholds, range.Start.Value, - range.End.Value), cancellationToken).ConfigureAwait(false); + () => Construct(samples, labels, sampleSortedIdx, thresholds, range.Start.Value, range.End.Value), + ct).ConfigureAwait(false); }).ConfigureAwait(false); } } - public async Task FindBestSplitAsync(Split sp, double[] labels, int minLeafSupport) + public async Task FindBestSplitAsync(Split sp, double[] labels, int minLeafSupport, CancellationToken cancellationToken = default) { if (sp.Deviance == 0) return false; // No need to split @@ -109,7 +109,8 @@ public async Task FindBestSplitAsync(Split sp, double[] labels, int minLea .Select>(range => new Task(() => FindBestSplit(usedFeatures, minLeafSupport, range.Start.Value, range.End.Value))) .ToList(); - await Parallel.ForEachAsync(tasks, new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism }, async (task, _) => + // TODO: pass cancellation token... + await Parallel.ForEachAsync(tasks, new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism, CancellationToken = cancellationToken }, async (task, ct) => { task.Start(); await task.ConfigureAwait(false); @@ -157,9 +158,9 @@ public async Task FindBestSplitAsync(Split sp, double[] labels, int minLea _impacts[best.FeatureIdx] += best.ErrReduced; var lh = new FeatureHistogram(_samplingRate, _maxDegreesOfParallelism); - await lh.ConstructAsync(sp.Histogram!, left, labels).ConfigureAwait(false); + await lh.ConstructAsync(sp.Histogram!, left, labels, cancellationToken).ConfigureAwait(false); var rh = new FeatureHistogram(_samplingRate, _maxDegreesOfParallelism); - await rh.ConstructAsync(sp.Histogram!, lh, !sp.IsRoot).ConfigureAwait(false); + await rh.ConstructAsync(sp.Histogram!, lh, !sp.IsRoot, cancellationToken).ConfigureAwait(false); var var = _sqSumResponse - _sumResponse * _sumResponse / idx.Length; var varLeft = lh._sqSumResponse - lh._sumResponse * lh._sumResponse / left.Length; @@ -214,7 +215,7 @@ private void Construct(DataPoint[] samples, double[] labels, int[][] sampleSorte } } - internal async Task UpdateAsync(double[] labels) + internal async Task UpdateAsync(double[] labels, CancellationToken cancellationToken = default) { _sumResponse = 0; _sqSumResponse = 0; @@ -227,10 +228,10 @@ internal async Task UpdateAsync(double[] labels) Partitioner.PartitionEnumerable(_features.Length, _maxDegreesOfParallelism); await Parallel.ForEachAsync( partitions, - new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism }, - async (range, cancellationToken) => + new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism, CancellationToken = cancellationToken }, + async (range, ct) => { - await Task.Run(() => Update(labels, range.Start.Value, range.End.Value), cancellationToken) + await Task.Run(() => Update(labels, range.Start.Value, range.End.Value), ct) .ConfigureAwait(false); }).ConfigureAwait(false); } @@ -276,7 +277,7 @@ private void Update(double[] labels, int start, int end) } } - private async Task ConstructAsync(FeatureHistogram parent, int[] soi, double[] labels) + private async Task ConstructAsync(FeatureHistogram parent, int[] soi, double[] labels, CancellationToken cancellationToken = default) { _features = parent._features; _thresholds = parent._thresholds; @@ -296,11 +297,12 @@ private async Task ConstructAsync(FeatureHistogram parent, int[] soi, double[] l await Parallel.ForEachAsync( partitions, - new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism }, - async (range, cancellationToken) => + new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism, CancellationToken = cancellationToken }, + async (range, ct) => { - await Task.Run(() => Construct(soi, labels, range.Start.Value, range.End.Value), - cancellationToken).ConfigureAwait(false); + await Task.Run( + () => Construct(soi, labels, range.Start.Value, range.End.Value), ct) + .ConfigureAwait(false); }).ConfigureAwait(false); } } @@ -342,7 +344,7 @@ private void Construct(int[] soi, double[] labels, int start, int end) } } - private async Task ConstructAsync(FeatureHistogram parent, FeatureHistogram leftSibling, bool reuseParent) + private async Task ConstructAsync(FeatureHistogram parent, FeatureHistogram leftSibling, bool reuseParent, CancellationToken cancellationToken = default) { _reuseParent = reuseParent; _features = parent._features; @@ -371,11 +373,12 @@ private async Task ConstructAsync(FeatureHistogram parent, FeatureHistogram left Partitioner.PartitionEnumerable(_features.Length, _maxDegreesOfParallelism); await Parallel.ForEachAsync( partitions, - new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism }, - async (range, cancellationToken) => + new ParallelOptions { MaxDegreeOfParallelism = _maxDegreesOfParallelism, CancellationToken = cancellationToken }, + async (range, ct) => { - await Task.Run(() => Construct(parent, leftSibling, range.Start.Value, range.End.Value), - cancellationToken).ConfigureAwait(false); + await Task.Run( + () => Construct(parent, leftSibling, range.Start.Value, range.End.Value), + ct).ConfigureAwait(false); }).ConfigureAwait(false); } } diff --git a/src/RankLib/Learning/Tree/LambdaMART.cs b/src/RankLib/Learning/Tree/LambdaMART.cs index 3e9dcf1..5382fb5 100644 --- a/src/RankLib/Learning/Tree/LambdaMART.cs +++ b/src/RankLib/Learning/Tree/LambdaMART.cs @@ -211,7 +211,7 @@ public LambdaMART(List samples, int[] features, MetricScorer scorer, I public Ensemble Ensemble => _ensemble; /// - public override async Task InitAsync() + public override async Task InitAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Initializing..."); @@ -239,6 +239,7 @@ public override async Task InitAsync() // Sort samples by each feature _sortedIdx = new int[Features.Length][]; + CheckCancellation(_logger, cancellationToken); if (Parameters.MaxDegreeOfParallelism == 1) SortSamplesByFeature(0, Features.Length - 1); else @@ -247,11 +248,11 @@ public override async Task InitAsync() Partitioner.PartitionEnumerable(Features.Length, Parameters.MaxDegreeOfParallelism); await Parallel.ForEachAsync( partitions, - new ParallelOptions { MaxDegreeOfParallelism = Parameters.MaxDegreeOfParallelism }, - async (range, cancellationToken) => + new ParallelOptions { MaxDegreeOfParallelism = Parameters.MaxDegreeOfParallelism, CancellationToken = cancellationToken }, + async (range, ct) => { await Task.Run(() => - SortSamplesByFeature(range.Start.Value, range.End.Value), cancellationToken).ConfigureAwait(false); + SortSamplesByFeature(range.Start.Value, range.End.Value), ct).ConfigureAwait(false); }).ConfigureAwait(false); } @@ -259,6 +260,8 @@ await Task.Run(() => _thresholds = new float[Features.Length][]; for (var f = 0; f < Features.Length; f++) { + CheckCancellation(_logger, cancellationToken); + //For this feature, keep track of the list of unique values and the max/min var values = new List(MARTSamples.Length); var fMax = float.NegativeInfinity; @@ -309,6 +312,7 @@ await Task.Run(() => } } + CheckCancellation(_logger, cancellationToken); if (ValidationSamples != null) { _modelScoresOnValidation = new double[ValidationSamples.Count][]; @@ -319,16 +323,19 @@ await Task.Run(() => } } + CheckCancellation(_logger, cancellationToken); _histogram = new FeatureHistogram(Parameters.SamplingRate, Parameters.MaxDegreeOfParallelism); - await _histogram.ConstructAsync(MARTSamples, PseudoResponses, _sortedIdx, Features, _thresholds, Impacts).ConfigureAwait(false); + await _histogram.ConstructAsync(MARTSamples, PseudoResponses, _sortedIdx, Features, _thresholds, Impacts, cancellationToken).ConfigureAwait(false); //we no longer need the sorted indexes of samples _sortedIdx = []; } /// - public override async Task LearnAsync() + public override async Task LearnAsync(CancellationToken cancellationToken = default) { + CheckCancellation(_logger, cancellationToken); + _ensemble = new Ensemble(); _logger.LogInformation("Training starts..."); @@ -340,20 +347,25 @@ public override async Task LearnAsync() var bufferedLogger = new BufferedLogger(_logger, new StringBuilder()); for (var m = 0; m < Parameters.TreeCount; m++) { + CheckCancellation(_logger, cancellationToken); bufferedLogger.PrintLog([7], [(m + 1).ToString()]); //Compute lambdas (which act as the "pseudo responses") //Create training instances for MART: // - Each document is a training sample // - The lambda for this document serves as its training label - await ComputePseudoResponsesAsync().ConfigureAwait(false); + await ComputePseudoResponsesAsync(cancellationToken).ConfigureAwait(false); + + CheckCancellation(_logger, cancellationToken); //update the histogram with these training labels (the feature histogram will be used to find the best tree split) - await _histogram.UpdateAsync(PseudoResponses).ConfigureAwait(false); + await _histogram.UpdateAsync(PseudoResponses, cancellationToken).ConfigureAwait(false); + + CheckCancellation(_logger, cancellationToken); //Fit a regression tree var tree = new RegressionTree(Parameters.TreeLeavesCount, MARTSamples, PseudoResponses, _histogram, Parameters.MinimumLeafSupport); - await tree.FitAsync().ConfigureAwait(false); + await tree.FitAsync(cancellationToken).ConfigureAwait(false); //Add this tree to the ensemble (our model) _ensemble.Add(tree, Parameters.LearningRate); @@ -454,7 +466,7 @@ public override void LoadFromString(string model) /// /// Computes the pseudo responses for an iteration of learning. /// - protected virtual async Task ComputePseudoResponsesAsync() + protected virtual async Task ComputePseudoResponsesAsync(CancellationToken cancellationToken = default) { Array.Fill(PseudoResponses, 0); Array.Fill(_weights, 0); @@ -483,11 +495,11 @@ protected virtual async Task ComputePseudoResponsesAsync() var parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = Parameters.MaxDegreeOfParallelism, - CancellationToken = default + CancellationToken = cancellationToken }; - await Parallel.ForEachAsync(tuples, parallelOptions, async (values, cancellationToken) => - await Task.Run(() => ComputePseudoResponses(values.start, values.end, values.current), cancellationToken).ConfigureAwait(false)) + await Parallel.ForEachAsync(tuples, parallelOptions, async (values, ct) => + await Task.Run(() => ComputePseudoResponses(values.start, values.end, values.current), ct).ConfigureAwait(false)) .ConfigureAwait(false); } } diff --git a/src/RankLib/Learning/Tree/MART.cs b/src/RankLib/Learning/Tree/MART.cs index bd40594..bb6cd79 100644 --- a/src/RankLib/Learning/Tree/MART.cs +++ b/src/RankLib/Learning/Tree/MART.cs @@ -65,7 +65,7 @@ public MART(LambdaMARTParameters parameters, List samples, int[] featu public override string Name => RankerName; /// - protected override Task ComputePseudoResponsesAsync() + protected override Task ComputePseudoResponsesAsync(CancellationToken cancellationToken = default) { for (var i = 0; i < MARTSamples.Length; i++) PseudoResponses[i] = MARTSamples[i].Label - ModelScores[i]; diff --git a/src/RankLib/Learning/Tree/RandomForests.cs b/src/RankLib/Learning/Tree/RandomForests.cs index 3579e19..7c4418e 100644 --- a/src/RankLib/Learning/Tree/RandomForests.cs +++ b/src/RankLib/Learning/Tree/RandomForests.cs @@ -161,7 +161,7 @@ public RandomForests(List samples, int[] features, MetricScorer scorer _logger = _loggerFactory.CreateLogger(); } - public override Task InitAsync() + public override Task InitAsync(CancellationToken cancellationToken = default) { _logger.LogInformation("Initializing..."); Ensembles = new Ensemble[Parameters.BagCount]; @@ -182,7 +182,7 @@ public override Task InitAsync() return Task.CompletedTask; } - public override async Task LearnAsync() + public override async Task LearnAsync(CancellationToken cancellationToken = default) { var rankerFactory = new RankerFactory(_loggerFactory); _logger.LogInformation("Training starts..."); @@ -193,11 +193,13 @@ public override async Task LearnAsync() // Start the bagging process for (var i = 0; i < Parameters.BagCount; i++) { + CheckCancellation(_logger, cancellationToken); + // Create a "bag" of samples by random sampling from the training set var (bag, _) = Sampler.Sample(Samples, Parameters.SubSamplingRate, true); var ranker = (LambdaMART)rankerFactory.CreateRanker(Parameters.RankerType, bag, Features, Scorer, _lambdaMARTParameters); - await ranker.InitAsync().ConfigureAwait(false); - await ranker.LearnAsync().ConfigureAwait(false); + await ranker.InitAsync(cancellationToken).ConfigureAwait(false); + await ranker.LearnAsync(cancellationToken).ConfigureAwait(false); // Accumulate impacts if (impacts == null) @@ -211,6 +213,8 @@ public override async Task LearnAsync() Ensembles[i] = ranker.Ensemble; } + CheckCancellation(_logger, cancellationToken); + // Finishing up TrainingDataScore = Scorer.Score(Rank(Samples)); _logger.LogInformation("Finished successfully."); @@ -222,6 +226,8 @@ public override async Task LearnAsync() _logger.LogInformation(Scorer.Name + " on validation data: " + SimpleMath.Round(ValidationDataScore, 4)); } + CheckCancellation(_logger, cancellationToken); + // Print feature impacts _logger.LogInformation("-- FEATURE IMPACTS"); if (_logger.IsEnabled(LogLevel.Information)) diff --git a/src/RankLib/Learning/Tree/RegressionTree.cs b/src/RankLib/Learning/Tree/RegressionTree.cs index cce17dd..6ae8e54 100644 --- a/src/RankLib/Learning/Tree/RegressionTree.cs +++ b/src/RankLib/Learning/Tree/RegressionTree.cs @@ -39,7 +39,7 @@ public RegressionTree(int treeLeavesCount, DataPoint[] trainingSamples, double[] /// /// Fits the tree from the specified training data. /// - public async Task FitAsync() + public async Task FitAsync(CancellationToken cancellationToken = default) { var queue = new List(); _root = new Split(_index, _hist, float.MaxValue, 0) @@ -48,7 +48,7 @@ public async Task FitAsync() }; // Ensure inserts occur only after successful splits - if (await _root.TrySplitAsync(_trainingLabels, _minLeafSupport).ConfigureAwait(false)) + if (await _root.TrySplitAsync(_trainingLabels, _minLeafSupport, cancellationToken).ConfigureAwait(false)) { Insert(queue, _root.Left); Insert(queue, _root.Right); @@ -67,7 +67,7 @@ public async Task FitAsync() } // unsplit-able (i.e. variance(s)==0; or after-split variance is higher than before) - if (!await leaf.TrySplitAsync(_trainingLabels, _minLeafSupport).ConfigureAwait(false)) + if (!await leaf.TrySplitAsync(_trainingLabels, _minLeafSupport, cancellationToken).ConfigureAwait(false)) taken++; else { diff --git a/src/RankLib/Learning/Tree/Split.cs b/src/RankLib/Learning/Tree/Split.cs index 58f34eb..e73508c 100644 --- a/src/RankLib/Learning/Tree/Split.cs +++ b/src/RankLib/Learning/Tree/Split.cs @@ -163,12 +163,12 @@ private string GetString(string indent) // Internal functions (ONLY used during learning) //*DO NOT* attempt to call them once the training is done - internal async Task TrySplitAsync(double[] trainingLabels, int minLeafSupport) + internal async Task TrySplitAsync(double[] trainingLabels, int minLeafSupport, CancellationToken cancellationToken = default) { if (Histogram is null) throw new InvalidOperationException("Histogram is null"); - return await Histogram.FindBestSplitAsync(this, trainingLabels, minLeafSupport).ConfigureAwait(false); + return await Histogram.FindBestSplitAsync(this, trainingLabels, minLeafSupport, cancellationToken).ConfigureAwait(false); } public int[] GetSamples() => _sortedSampleIDs != null ? _sortedSampleIDs[0] : _samples;