Skip to content

Commit

Permalink
Merge pull request #84 from Encamina/@lmarcos/semantic_text_chunker
Browse files Browse the repository at this point in the history
Added semantic text chunker
  • Loading branch information
LuisM000 authored Mar 7, 2024
2 parents 67660f9 + 560de83 commit 1f16fe4
Show file tree
Hide file tree
Showing 14 changed files with 506 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ Previous classification is not required if changes are simple or all belong to t
- Updated `xunit.analyzers` from `1.10.0` to `1.11.0`.
- Updated `xunit.extensibility.core` from `2.6.6` to `2.7.0`.
- Updated `xunit.runner.visualstudio` from `2.5.6` to `2.5.7`.
- Added new interface `Encamina.Enmarcha.AI.Abstractions.ISemanticTextSplitter` and its implementation `Encamina.Enmarcha.AI.TextSplitters.SemanticTextSplitter` to split a text into meaningful chunks based on embeddings.
- Added a new utility class for mathematical operations `Encamina.Enmarcha.Core.MathUtils`.

### Minor Changes

Expand Down
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

<PropertyGroup>
<VersionPrefix>8.1.5</VersionPrefix>
<VersionSuffix>preview-04</VersionSuffix>
<VersionSuffix>preview-05</VersionSuffix>
</PropertyGroup>

<!--
Expand Down
7 changes: 7 additions & 0 deletions Enmarcha.sln
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Encamina.Enmarcha.Data.Azur
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Encamina.Enmarcha.AspNet.OpenApi", "src\Encamina.Enmarcha.AspNet.OpenApi\Encamina.Enmarcha.AspNet.OpenApi.csproj", "{0EFAA5CF-7106-40E0-A427-1CFBFFAEA3EC}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Encamina.Enmarcha.Core.Tests", "tst\Encamina.Enmarcha.Core.Tests\Encamina.Enmarcha.Core.Tests.csproj", "{0516ADAE-C543-4B48-94EE-AC535DEFED0E}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -370,6 +372,10 @@ Global
{0EFAA5CF-7106-40E0-A427-1CFBFFAEA3EC}.Debug|Any CPU.Build.0 = Debug|Any CPU
{0EFAA5CF-7106-40E0-A427-1CFBFFAEA3EC}.Release|Any CPU.ActiveCfg = Release|Any CPU
{0EFAA5CF-7106-40E0-A427-1CFBFFAEA3EC}.Release|Any CPU.Build.0 = Release|Any CPU
{0516ADAE-C543-4B48-94EE-AC535DEFED0E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{0516ADAE-C543-4B48-94EE-AC535DEFED0E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{0516ADAE-C543-4B48-94EE-AC535DEFED0E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{0516ADAE-C543-4B48-94EE-AC535DEFED0E}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -387,6 +393,7 @@ Global
{AA1E5E93-FE02-4395-9260-C7C869F22785} = {43252034-27E2-4981-AC2D-EA986B287863}
{7F3ECD81-28E6-4000-9005-1B2ABA8EC1C5} = {CBD50B5F-AFB8-4DA1-9FD7-17D98EB3ED78}
{7B6F4DC4-74E2-4013-8DBA-12B7AAAD5278} = {CBD50B5F-AFB8-4DA1-9FD7-17D98EB3ED78}
{0516ADAE-C543-4B48-94EE-AC535DEFED0E} = {CBD50B5F-AFB8-4DA1-9FD7-17D98EB3ED78}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {F30DF47A-541C-4383-BCEB-E4108D06A70E}
Expand Down
22 changes: 22 additions & 0 deletions src/Encamina.Enmarcha.AI.Abstractions/BreakpointThresholdType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
namespace Encamina.Enmarcha.AI.Abstractions;

/// <summary>
/// Specifies how the distance threshold that marks a breakpoint (a boundary between two chunks) is calculated
/// by an <see cref="ISemanticTextSplitter"/>.
/// </summary>
/// <remarks>
/// The threshold is computed from the distances between the embeddings of consecutive sentences. The configured
/// threshold amount is interpreted differently for each member of this enumeration.
/// </remarks>
public enum BreakpointThresholdType
{
/// <summary>
/// The threshold is a percentile of the distances; the threshold amount is the percentile to use (for example, 95).
/// </summary>
Percentile,

/// <summary>
/// The threshold is based on the standard deviation of the distances; the threshold amount is the number of
/// standard deviations added to the mean distance (for example, 3).
/// </summary>
StandardDeviation,

/// <summary>
/// The threshold is based on the interquartile range of the distances; the threshold amount is the multiplier
/// applied to the interquartile range before it is added to the mean distance (for example, 1.5).
/// </summary>
Interquartile,
}
16 changes: 16 additions & 0 deletions src/Encamina.Enmarcha.AI.Abstractions/ISemanticTextSplitter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
namespace Encamina.Enmarcha.AI.Abstractions;

/// <summary>
/// Represents a semantic text splitter, which splits a text into chunks of semantically related content
/// by comparing embeddings.
/// </summary>
public interface ISemanticTextSplitter
{
/// <summary>
/// Splits the input text into chunks based on its semantic content.
/// </summary>
/// <param name="text">The input text to be split.</param>
/// <param name="embeddingsGenerator">
/// A function that receives a list of strings and a <see cref="CancellationToken"/>, and asynchronously returns
/// the embeddings generated for those strings. Implementations may rely on it returning exactly one embedding
/// per input string, in the same order as the input.
/// </param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
/// <returns>
/// A task that represents the asynchronous operation. Its result is the collection of text chunks.
/// </returns>
Task<IEnumerable<string>> SplitAsync(string text, Func<IList<string>, CancellationToken, Task<IList<ReadOnlyMemory<float>>>> embeddingsGenerator, CancellationToken cancellationToken);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using System.ComponentModel.DataAnnotations;

namespace Encamina.Enmarcha.AI.Abstractions;

/// <summary>
/// Configuration options for semantic text splitters.
/// </summary>
public class SemanticTextSplitterOptions
{
/// <summary>
/// Gets the size of the buffer used in semantic text splitting, i.e., the number of neighboring sentences to include
/// on each side of the current sentence when building the text that is embedded for that sentence.
/// Must be zero or greater. The default value is <c>1</c>.
/// </summary>
[Required]
[Range(0, int.MaxValue)]
public int BufferSize { get; init; } = 1;

/// <summary>
/// Gets the type of threshold used to identify breakpoints in the text. It can be based on percentiles,
/// standard deviations, or the interquartile range. The default value is <see cref="BreakpointThresholdType.Percentile"/>.
/// </summary>
[Required]
public BreakpointThresholdType BreakpointThresholdType { get; init; } = BreakpointThresholdType.Percentile;

/// <summary>
/// Gets the amount used in the threshold calculation to identify breakpoints. Its interpretation depends on the
/// selected <see cref="BreakpointThresholdType"/>. The default value is <c>95</c>, which matches the default
/// <see cref="BreakpointThresholdType.Percentile"/> threshold type.
/// </summary>
/// <remarks>
/// Examples of valid values for each threshold type:
/// <list type="bullet">
/// <item>
/// For <see cref="BreakpointThresholdType.Percentile"/>, a valid value is <c>95</c>.
/// </item>
/// <item>
/// For <see cref="BreakpointThresholdType.StandardDeviation"/>, a valid value is <c>3</c>.
/// </item>
/// <item>
/// For <see cref="BreakpointThresholdType.Interquartile"/>, a valid value is <c>1.5</c>.
/// </item>
/// </list>
/// </remarks>
[Required]
public float BreakpointThresholdAmount { get; init; } = 95;
}
1 change: 1 addition & 0 deletions src/Encamina.Enmarcha.AI/Encamina.Enmarcha.AI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Options" Version="8.0.2" />
<PackageReference Include="System.Numerics.Tensors" Version="8.0.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Extensions.DependencyInjection;
public static class IServiceCollectionExtensions
{
/// <summary>
/// Adds a defult cognitive service provider to the <see cref="IServiceCollection"/> as singleton.
/// Adds a default cognitive service provider to the <see cref="IServiceCollection"/> as singleton.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> to add services to.</param>
/// <returns>The <see cref="IServiceCollection"/> so that additional calls can be chained.</returns>
Expand Down Expand Up @@ -45,4 +45,14 @@ public static IServiceCollection AddRecursiveCharacterTextSplitter(this IService
{
return services.AddSingleton<ITextSplitter, RecursiveCharacterTextSplitter>();
}

/// <summary>
/// Registers <c>SemanticTextSplitter</c> in the <see cref="IServiceCollection"/> as the singleton implementation of <see cref="ISemanticTextSplitter"/>.
/// </summary>
/// <remarks>
/// Only the splitter itself is registered here. The splitter's constructor requires an options monitor for
/// <c>SemanticTextSplitterOptions</c>, which must be configured separately.
/// </remarks>
/// <param name="services">The <see cref="IServiceCollection"/> to add services to.</param>
/// <returns>The same <see cref="IServiceCollection"/> instance, so that additional calls can be chained.</returns>
public static IServiceCollection AddSemanticTextSplitter(this IServiceCollection services)
    => services.AddSingleton<ISemanticTextSplitter, SemanticTextSplitter>();
}
198 changes: 198 additions & 0 deletions src/Encamina.Enmarcha.AI/TextSplitters/SemanticTextSplitter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
using System.Numerics.Tensors;
using System.Text;
using System.Text.RegularExpressions;

using Encamina.Enmarcha.AI.Abstractions;
using Encamina.Enmarcha.Core;

using Microsoft.Extensions.Options;

namespace Encamina.Enmarcha.AI.TextSplitters;

/// <summary>
/// Implementation of the <see cref="ISemanticTextSplitter"/> interface that splits a given text into semantically cohesive chunks.
/// The text is split into sentences, each sentence is embedded together with its neighbors, and a chunk boundary (breakpoint)
/// is placed wherever the cosine distance between two consecutive embeddings exceeds a calculated threshold.
/// </summary>
public class SemanticTextSplitter : ISemanticTextSplitter
{
    // Matches the whitespace that follows a sentence terminator ('.', '?' or '!'). The lookbehind keeps the
    // terminator attached to the sentence it ends. A match timeout of 30 seconds is set for the expression.
    private static readonly Regex SentenceSplitRegex = new(@"(?<=[.?!])\s+", RegexOptions.Compiled, TimeSpan.FromSeconds(30));

    // Intentionally not read-only: it is replaced every time the options monitor reports a change.
    private SemanticTextSplitterOptions options;

    /// <summary>
    /// Initializes a new instance of the <see cref="SemanticTextSplitter"/> class.
    /// </summary>
    /// <param name="options">The options monitor that provides the current (and any updated) configuration of the semantic text splitter.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="options"/> is <see langword="null"/>.</exception>
    public SemanticTextSplitter(IOptionsMonitor<SemanticTextSplitterOptions> options)
    {
        ArgumentNullException.ThrowIfNull(options);

        this.options = options.CurrentValue;

        options.OnChange(newOptions => this.options = newOptions);
    }

    /// <inheritdoc/>
    /// <remarks>
    /// A text with no sentences (e.g., an empty text) or with a single sentence is returned as is, without generating any embeddings.
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="text"/> or <paramref name="embeddingsGenerator"/> is <see langword="null"/>.</exception>
    /// <exception cref="InvalidOperationException">Thrown when <paramref name="embeddingsGenerator"/> does not return exactly one embedding per text it receives.</exception>
    public async Task<IEnumerable<string>> SplitAsync(string text, Func<IList<string>, CancellationToken, Task<IList<ReadOnlyMemory<float>>>> embeddingsGenerator, CancellationToken cancellationToken)
    {
        // Code inspired by
        // https://github.com/run-llama/llama_index/blob/8ed753df970f068f6afc8a83fd51a1f40880de9e/llama-index-packs/llama-index-packs-node-parser-semantic-chunking/llama_index/packs/node_parser_semantic_chunking/base.py
        // https://github.com/langchain-ai/langchain/blob/ced5e7bae790cd9ec4e5374f5d070d9f23d6457b/libs/experimental/langchain_experimental/text_splitter.py
        ArgumentNullException.ThrowIfNull(text);
        ArgumentNullException.ThrowIfNull(embeddingsGenerator);

        // Take a snapshot of the options, so a configuration change that happens while awaiting the embeddings
        // cannot make this call mix values from two different configurations.
        var currentOptions = options;

        // Splitting the text on '.', '?', and '!'
        var sentences = SentenceSplitRegex.Split(text).Where(t => !string.IsNullOrEmpty(t)).ToList();

        // With zero or one sentence there are no consecutive sentences to compare, so there is nothing to split.
        if (sentences.Count <= 1)
        {
            return sentences;
        }

        // Combine each sentence with its neighbors, based on the buffer size
        var combinedSentences = CreateCombinedSentences(sentences, currentOptions.BufferSize);

        // Generate embeddings for the combined sentences
        var combinedSentencesEmbeddings = await embeddingsGenerator(combinedSentences, cancellationToken).ConfigureAwait(false);

        // Breakpoint indexes are later used as sentence indexes, which is only valid if embeddings and sentences are aligned one to one.
        if (combinedSentencesEmbeddings is null || combinedSentencesEmbeddings.Count != combinedSentences.Count)
        {
            throw new InvalidOperationException($"The embeddings generator must return exactly one embedding per input text. Expected {combinedSentences.Count} embeddings, but got {combinedSentencesEmbeddings?.Count ?? 0}.");
        }

        // Calculate cosine distances between consecutive sentence embeddings
        var distancesToNextSentence = CalculateDistancesToNextSentence(combinedSentencesEmbeddings);

        // Calculate threshold for identifying breakpoints
        var breakpointDistanceThreshold = CalculateBreakpointThreshold(distancesToNextSentence, currentOptions.BreakpointThresholdType, currentOptions.BreakpointThresholdAmount);

        // Identify indexes above the threshold as breakpoints
        var indexesAboveThreshold = distancesToNextSentence
            .Select((distance, index) => new { Index = index, Distance = distance })
            .Where(item => item.Distance > breakpointDistanceThreshold)
            .Select(item => item.Index)
            .ToList();

        // Slice sentences based on identified breakpoints
        return SliceSentences(sentences, indexesAboveThreshold);
    }

    /// <summary>
    /// Creates, for each sentence, a combined text made of the sentence itself plus up to <paramref name="bufferSize"/>
    /// neighboring sentences before and after it, joined by a single space.
    /// </summary>
    /// <param name="sentences">The list of sentences to be combined.</param>
    /// <param name="bufferSize">The number of sentences to include on each side of the current sentence. Values lower than one add no neighbors.</param>
    /// <returns>A list with one combined text per sentence, in the same order as <paramref name="sentences"/>.</returns>
    private static List<string> CreateCombinedSentences(IList<string> sentences, int bufferSize)
    {
        var combinedSentences = new List<string>(sentences.Count);

        // A window larger than the number of sentences adds nothing. Clamping it keeps the work proportional to the
        // sentences that actually exist and prevents integer overflows when calculating the window boundaries.
        var window = Math.Clamp(bufferSize, 0, sentences.Count);

        var combinedSentenceBuilder = new StringBuilder();

        for (var i = 0; i < sentences.Count; i++)
        {
            var first = Math.Max(0, i - window);
            var last = Math.Min(sentences.Count - 1, i + window);

            combinedSentenceBuilder.Clear();

            // Append the sentences before the current one, the current one, and the sentences after it.
            for (var j = first; j <= last; j++)
            {
                if (j > first)
                {
                    combinedSentenceBuilder.Append(' ');
                }

                combinedSentenceBuilder.Append(sentences[j]);
            }

            combinedSentences.Add(combinedSentenceBuilder.ToString());
        }

        return combinedSentences;
    }

    /// <summary>
    /// Calculates the cosine distance between each sentence embedding and the next one.
    /// </summary>
    /// <param name="embeddings">The list of sentence embeddings.</param>
    /// <returns>
    /// A list of cosine distances, where the item at position <c>i</c> is the distance between the embeddings at positions <c>i</c> and <c>i + 1</c>.
    /// It has one item less than <paramref name="embeddings"/>, or none if there are fewer than two embeddings.
    /// </returns>
    private static List<double> CalculateDistancesToNextSentence(IList<ReadOnlyMemory<float>> embeddings)
    {
        // Never request a negative capacity, which would throw if the list of embeddings is empty.
        var distances = new List<double>(Math.Max(0, embeddings.Count - 1));

        for (var i = 1; i < embeddings.Count; i++)
        {
            // Calculate cosine similarity between the previous and the current embedding
            var similarity = TensorPrimitives.CosineSimilarity(embeddings[i - 1].Span, embeddings[i].Span);

            // Convert to cosine distance
            distances.Add(1 - similarity);
        }

        return distances;
    }

    /// <summary>
    /// Calculates the distance threshold above which the gap between two consecutive sentences is considered a breakpoint.
    /// </summary>
    /// <remarks>
    /// Depending on <paramref name="breakpointThresholdType"/>, the threshold is the given percentile of the distances,
    /// the mean distance plus a number of standard deviations, or the mean distance plus a multiple of the interquartile range.
    /// </remarks>
    /// <param name="distances">The list of cosine distances between consecutive sentence embeddings.</param>
    /// <param name="breakpointThresholdType">The type of threshold calculation to be applied.</param>
    /// <param name="breakpointThresholdAmount">The amount used in the threshold calculation. Its meaning depends on <paramref name="breakpointThresholdType"/>.</param>
    /// <returns>The calculated threshold for identifying breakpoints.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="breakpointThresholdType"/> is not a supported value.</exception>
    private static double CalculateBreakpointThreshold(IList<double> distances, BreakpointThresholdType breakpointThresholdType, float breakpointThresholdAmount)
    {
        return breakpointThresholdType switch
        {
            BreakpointThresholdType.Percentile => MathUtils.Percentile(distances, breakpointThresholdAmount),
            BreakpointThresholdType.StandardDeviation => (MathUtils.StandardDeviation(distances) * breakpointThresholdAmount) + distances.Average(),
            BreakpointThresholdType.Interquartile => distances.Average() + (breakpointThresholdAmount * MathUtils.InterquartileRange(distances)),
            _ => throw new ArgumentOutOfRangeException(nameof(breakpointThresholdType), breakpointThresholdType, null),
        };
    }

    /// <summary>
    /// Groups the sentences into chunks, ending a chunk after each sentence whose index is a breakpoint.
    /// </summary>
    /// <param name="sentences">The list of sentences to be sliced.</param>
    /// <param name="indexes">The indexes, in ascending order, of the sentences after which a new chunk starts.</param>
    /// <returns>A list of chunks, each one made of consecutive sentences joined by a single space.</returns>
    private static IEnumerable<string> SliceSentences(IList<string> sentences, List<int> indexes)
    {
        var chunks = new List<string>();
        var startIndex = 0;

        foreach (var breakpointIndex in indexes)
        {
            // The chunk goes from the current start index up to (and including) the breakpoint sentence.
            var group = sentences.Skip(startIndex).Take(breakpointIndex - startIndex + 1);

            chunks.Add(string.Join(" ", group));

            // The next chunk starts right after the breakpoint sentence.
            startIndex = breakpointIndex + 1;
        }

        // The last chunk, if any sentences remain after the last breakpoint.
        if (startIndex < sentences.Count)
        {
            chunks.Add(string.Join(" ", sentences.Skip(startIndex)));
        }

        return chunks;
    }
}
Loading

0 comments on commit 1f16fe4

Please sign in to comment.