Skip to content

Commit

Permalink
Support CLI cancellation
Browse files: browse the repository at this point in the history
  • Loading branch information
russcam committed Nov 20, 2024
1 parent 8fe4bff commit 772082f
Show file tree
Hide file tree
Showing 17 changed files with 208 additions and 112 deletions.
12 changes: 8 additions & 4 deletions src/RankLib.Cli/EvaluateCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,8 @@ await evaluator.EvaluateAsync(
tvSplit,
kcvModelDir!.FullName,
kcvModelFile!,
rankerParameters).ConfigureAwait(false);
rankerParameters,
cancellationToken).ConfigureAwait(false);
}
else
{
Expand All @@ -481,7 +482,8 @@ await evaluator.EvaluateAsync(
featureDescriptionFile?.FullName,
ttSplit,
options.ModelOutputFile?.FullName,
rankerParameters).ConfigureAwait(false);
rankerParameters,
cancellationToken).ConfigureAwait(false);
}
else if (tvSplit > 0.0)
{
Expand All @@ -492,7 +494,8 @@ await evaluator.EvaluateAsync(
testFiles.LastOrDefault(),
featureDescriptionFile?.FullName,
options.ModelOutputFile?.FullName,
rankerParameters).ConfigureAwait(false);
rankerParameters,
cancellationToken).ConfigureAwait(false);
}
else
{
Expand All @@ -503,7 +506,8 @@ await evaluator.EvaluateAsync(
testFiles.LastOrDefault(),
featureDescriptionFile?.FullName,
options.ModelOutputFile?.FullName,
rankerParameters).ConfigureAwait(false);
rankerParameters,
cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
77 changes: 45 additions & 32 deletions src/RankLib/Eval/Evaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ public double Evaluate(IRanker? ranker, List<RankList> rankLists)
/// <param name="featureDefinitionFile">The feature definitions</param>
/// <param name="modelFile">A path to save the trained ranker to</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
/// <exception cref="ArgumentException">The ranker type is not an <see cref="IRanker"/></exception>
public async Task EvaluateAsync(
Type rankerType,
Expand All @@ -153,7 +154,8 @@ public async Task EvaluateAsync(
string? testFile,
string? featureDefinitionFile,
string? modelFile = null,
IRankerParameters? parameters = default)
IRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
{
if (!typeof(IRanker).IsAssignableFrom(rankerType))
throw new ArgumentException($"Ranker type {rankerType} is not a ranker");
Expand All @@ -172,7 +174,7 @@ public async Task EvaluateAsync(
Normalize(test, features);
}

var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters)
var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters, cancellationToken)
.ConfigureAwait(false);

if (test != null)
Expand All @@ -183,7 +185,7 @@ public async Task EvaluateAsync(

if (!string.IsNullOrEmpty(modelFile))
{
await ranker.SaveAsync(modelFile);
await ranker.SaveAsync(modelFile, cancellationToken).ConfigureAwait(false);
_logger.LogInformation("Model saved to: {ModelFile}", modelFile);
}
}
Expand All @@ -198,16 +200,18 @@ public async Task EvaluateAsync(
/// <param name="featureDefinitionFile">The feature definitions</param>
/// <param name="modelFile">A path to save the trained ranker to</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
public Task EvaluateAsync<TRanker, TRankerParameters>(
string trainFile,
string? validationFile,
string? testFile,
string? featureDefinitionFile,
string? modelFile = null,
TRankerParameters? parameters = default)
TRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
where TRanker : IRanker<TRankerParameters>
where TRankerParameters : IRankerParameters =>
EvaluateAsync(typeof(TRanker), trainFile, validationFile, testFile, featureDefinitionFile, modelFile, parameters);
EvaluateAsync(typeof(TRanker), trainFile, validationFile, testFile, featureDefinitionFile, modelFile, parameters, cancellationToken);

/// <summary>
/// Evaluates a new instance of <see cref="IRanker"/> specified by <paramref name="rankerType"/> and
Expand All @@ -220,6 +224,7 @@ public Task EvaluateAsync<TRanker, TRankerParameters>(
/// <param name="percentTrain">The percentage of <paramref name="sampleFile"/> to use for training data.</param>
/// <param name="modelFile">A path to save the trained ranker to</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
/// <exception cref="ArgumentException">The ranker type is not an <see cref="IRanker"/></exception>
public async Task EvaluateAsync(
Type rankerType,
Expand All @@ -228,7 +233,8 @@ public async Task EvaluateAsync(
string? featureDefinitionFile,
double percentTrain,
string? modelFile = null,
IRankerParameters? parameters = default)
IRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
{
var train = new List<RankList>();
var test = new List<RankList>();
Expand All @@ -238,15 +244,15 @@ public async Task EvaluateAsync(
if (_normalize && validation != null)
Normalize(validation, features);

var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters)
var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters, cancellationToken)
.ConfigureAwait(false);

var rankScore = Evaluate(ranker, test);
_logger.LogInformation($"{_testScorer.Name} on test data: {Math.Round(rankScore, 4)}");

if (!string.IsNullOrEmpty(modelFile))
{
await ranker.SaveAsync(modelFile);
await ranker.SaveAsync(modelFile, cancellationToken);
_logger.LogInformation("Model saved to: {ModelFile}", modelFile);
}
}
Expand All @@ -261,16 +267,18 @@ public async Task EvaluateAsync(
/// <param name="percentTrain">The percentage of <paramref name="sampleFile"/> to use for training data.</param>
/// <param name="modelFile">A path to save the trained ranker to</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
public Task EvaluateAsync<TRanker, TRankerParameters>(
string sampleFile,
string? validationFile,
string featureDefinitionFile,
double percentTrain,
string? modelFile = null,
TRankerParameters? parameters = default)
TRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
where TRanker : IRanker<TRankerParameters>
where TRankerParameters : IRankerParameters =>
EvaluateAsync(typeof(TRanker), sampleFile, validationFile, featureDefinitionFile, percentTrain, modelFile, parameters);
EvaluateAsync(typeof(TRanker), sampleFile, validationFile, featureDefinitionFile, percentTrain, modelFile, parameters, cancellationToken);

/// <summary>
/// Evaluates a new instance of <see cref="IRanker"/> specified by <paramref name="rankerType"/> and
Expand All @@ -283,6 +291,7 @@ public Task EvaluateAsync<TRanker, TRankerParameters>(
/// <param name="featureDefinitionFile">The feature definitions</param>
/// <param name="modelFile">A path to save the trained ranker to</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
/// <exception cref="ArgumentException">The ranker type is not an <see cref="IRanker"/></exception>
public async Task EvaluateAsync(
Type rankerType,
Expand All @@ -291,7 +300,8 @@ public async Task EvaluateAsync(
string? testFile,
string? featureDefinitionFile,
string? modelFile = null,
IRankerParameters? parameters = default)
IRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
{
var train = new List<RankList>();
var validation = new List<RankList>();
Expand All @@ -301,7 +311,7 @@ public async Task EvaluateAsync(
if (_normalize && test != null)
Normalize(test, features);

var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters)
var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters, cancellationToken)
.ConfigureAwait(false);

if (test != null)
Expand All @@ -312,7 +322,7 @@ public async Task EvaluateAsync(

if (!string.IsNullOrEmpty(modelFile))
{
await ranker.SaveAsync(modelFile);
await ranker.SaveAsync(modelFile, cancellationToken).ConfigureAwait(false);
_logger.LogInformation("Model saved to: {ModelFile}", modelFile);
}
}
Expand All @@ -327,16 +337,18 @@ public async Task EvaluateAsync(
/// <param name="featureDefinitionFile">The feature definitions</param>
/// <param name="modelFile">A path to save the trained ranker to</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
public Task EvaluateAsync<TRanker, TRankerParameters>(
string trainFile,
double percentTrain,
string? testFile,
string featureDefinitionFile,
string? modelFile = null,
TRankerParameters? parameters = default)
TRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
where TRanker : IRanker<TRankerParameters>
where TRankerParameters : IRankerParameters =>
EvaluateAsync(typeof(TRanker), trainFile, percentTrain, testFile, featureDefinitionFile, modelFile, parameters);
EvaluateAsync(typeof(TRanker), trainFile, percentTrain, testFile, featureDefinitionFile, modelFile, parameters, cancellationToken);

/// <summary>
/// Evaluates a new instance of <see cref="IRanker"/> specified by <typeparamref name="TRanker"/> and
Expand All @@ -348,16 +360,18 @@ public Task EvaluateAsync<TRanker, TRankerParameters>(
/// <param name="modelDir">The directory to save trained ranker models</param>
/// <param name="modelFile">The name prefix of trained ranker models</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
public Task EvaluateAsync<TRanker, TRankerParameters>(
string sampleFile,
string featureDefinitionFile,
int foldCount,
string modelDir,
string modelFile,
TRankerParameters? parameters = default)
TRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
where TRanker : IRanker<TRankerParameters>
where TRankerParameters : IRankerParameters =>
EvaluateAsync(typeof(TRanker), sampleFile, featureDefinitionFile, foldCount, -1, modelDir, modelFile, parameters);
EvaluateAsync(typeof(TRanker), sampleFile, featureDefinitionFile, foldCount, -1, modelDir, modelFile, parameters, cancellationToken);

/// <summary>
/// Evaluates a new instance of <see cref="IRanker"/> specified by <paramref name="rankerType"/> and
Expand All @@ -371,6 +385,7 @@ public Task EvaluateAsync<TRanker, TRankerParameters>(
/// <param name="modelDir">The directory to save trained ranker models</param>
/// <param name="modelFile">The name prefix of trained ranker models</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
public async Task EvaluateAsync(
Type rankerType,
string sampleFile,
Expand All @@ -379,7 +394,8 @@ public async Task EvaluateAsync(
float trainValidationSplit,
string modelDir,
string modelFile,
IRankerParameters? parameters = default)
IRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
{
var trainingData = new List<List<RankList>>();
var validationData = new List<List<RankList>>();
Expand Down Expand Up @@ -412,7 +428,7 @@ public async Task EvaluateAsync(
var validation = trainValidationSplit > 0 ? validationData[i] : null;
var test = testData[i];

var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters)
var ranker = await _trainer.TrainAsync(rankerType, train, validation, features, _trainScorer, parameters, cancellationToken)
.ConfigureAwait(false);

var testScore = Evaluate(ranker, test);
Expand All @@ -427,7 +443,7 @@ public async Task EvaluateAsync(
if (!string.IsNullOrEmpty(modelDir))
{
var foldModelFile = Path.Combine(modelDir, $"f{i + 1}.{modelFile}");
await ranker.SaveAsync(foldModelFile).ConfigureAwait(false);
await ranker.SaveAsync(foldModelFile, cancellationToken).ConfigureAwait(false);
_logger.LogInformation("Fold-{Fold} model saved to: {FoldModelFile}", i + 1, foldModelFile);
}
}
Expand All @@ -453,17 +469,19 @@ public async Task EvaluateAsync(
/// <param name="modelDir">The directory to save trained ranker models</param>
/// <param name="modelFile">The name prefix of trained ranker models</param>
/// <param name="parameters">The ranker parameters</param>
/// <param name="cancellationToken">Token that can be used to cancel the operation</param>
public Task EvaluateAsync<TRanker, TRankerParameters>(
string sampleFile,
string? featureDefinitionFile,
int foldCount,
float trainValidationSplit,
string modelDir,
string modelFile,
TRankerParameters? parameters = default)
TRankerParameters? parameters = default,
CancellationToken cancellationToken = default)
where TRanker : IRanker<TRankerParameters>
where TRankerParameters : IRankerParameters =>
EvaluateAsync(typeof(TRanker), sampleFile, featureDefinitionFile, foldCount, trainValidationSplit, modelDir, modelFile, parameters);
EvaluateAsync(typeof(TRanker), sampleFile, featureDefinitionFile, foldCount, trainValidationSplit, modelDir, modelFile, parameters, cancellationToken);

public void Test(string testFile)
{
Expand Down Expand Up @@ -776,8 +794,7 @@ public void Rank(string modelFile, string testFile, string indriRankingFile)
for (var j = 0; j < idx.Length; j++)
{
var k = idx[j];
var str = $"{l.Id} Q0 {l[k].Description.Replace("#", "").Trim()} {j + 1} {SimpleMath.Round(scores[k], 5)} indri";
outWriter.WriteLine(str);
outWriter.WriteLine($"{l.Id} Q0 {l[k].Description.AsSpan().Trim("#").Trim().ToString()} {j + 1} {SimpleMath.Round(scores[k], 5)} indri");
}
}
}
Expand All @@ -797,10 +814,7 @@ public void Rank(string testFile, string indriRankingFile)
foreach (var l in test)
{
for (var j = 0; j < l.Count; j++)
{
var str = $"{l.Id} Q0 {l[j].Description.Replace("#", "").Trim()} {j + 1} {SimpleMath.Round(1.0 - 0.0001 * j, 5)} indri";
outWriter.WriteLine(str);
}
outWriter.WriteLine($"{l.Id} Q0 {l[j].Description.AsSpan().Trim("#").Trim().ToString()} {j + 1} {SimpleMath.Round(1.0 - 0.0001 * j, 5)} indri");
}
}
catch (IOException ex)
Expand Down Expand Up @@ -848,7 +862,7 @@ public void Rank(List<string> modelFiles, string testFile, string indriRankingFi
}
catch (IOException ex)
{
throw RankLibException.Create("Error in Evaluator::Rank(): ", ex);
throw RankLibException.Create("Error ranking and writing indri ranking file", ex);
}
}

Expand Down Expand Up @@ -878,15 +892,14 @@ public void Rank(List<string> modelFiles, List<string> testFiles, string indriRa
for (var j = 0; j < idx.Length; j++)
{
var k = idx[j];
var str = $"{l.Id} Q0 {l[k].Description.Replace("#", "").Trim()} {j + 1} {SimpleMath.Round(scores[k], 5)} indri";
outWriter.WriteLine(str);
outWriter.WriteLine($"{l.Id} Q0 {l[k].Description.AsSpan().Trim("#").Trim().ToString()} {j + 1} {SimpleMath.Round(scores[k], 5)} indri");
}
}
}
}
catch (IOException ex)
{
throw RankLibException.Create("Error in Evaluator::Rank(): ", ex);
throw RankLibException.Create("Error ranking and writing indri ranking file", ex);
}
}

Expand Down
Loading

0 comments on commit 772082f

Please sign in to comment.