Skip to content

Commit

Permalink
Fix latest comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kbeaugrand committed Jun 14, 2024
1 parent eb8dc1c commit c508659
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public async Task CreateCollectionAsync(string collectionName, CancellationToken
using var command = this._connection.CreateCommand();

command.CommandText = $@"
IF OBJECT_ID(N'{this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}', N'U') IS NULL
CREATE TABLE {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}
IF OBJECT_ID(N'{this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}', N'U') IS NULL
CREATE TABLE {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}
( [id] UNIQUEIDENTIFIER NOT NULL,
[key] NVARCHAR(256) NOT NULL,
[metadata] TEXT,
Expand All @@ -61,17 +61,17 @@ [key] NVARCHAR(256) NOT NULL,
PRIMARY KEY ([id])
);
IF OBJECT_ID(N'{this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}', N'U') IS NULL
CREATE TABLE {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}
IF OBJECT_ID(N'{this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}', N'U') IS NULL
CREATE TABLE {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}
(
[memory_id] UNIQUEIDENTIFIER NOT NULL,
[vector_value_id] [int] NOT NULL,
[vector_value] [float] NOT NULL
);
IF OBJECT_ID(N'{NormalizeSQLObjectName(this._configuration.Schema)}.IXC_{$"{NormalizeSQLObjectName(this._configuration.EmbeddingsTableName)}_{collectionName}"}', N'U') IS NULL
CREATE CLUSTERED COLUMNSTORE INDEX [IXC_{$"{NormalizeSQLObjectName(this._configuration.EmbeddingsTableName)}_{collectionName}]"}
ON {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")};";
IF OBJECT_ID(N'{NormalizeSQLObjectName(this._configuration.Schema)}.IXC_{$"{NormalizeSQLObjectName(this._configuration.EmbeddingsTableNamePrefix)}_{collectionName}"}', N'U') IS NULL
CREATE CLUSTERED COLUMNSTORE INDEX [IXC_{$"{NormalizeSQLObjectName(this._configuration.EmbeddingsTableNamePrefix)}_{collectionName}]"}
ON {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")};";

command.Parameters.AddWithValue("@collectionName", collectionName);

Expand All @@ -85,15 +85,26 @@ public async Task<bool> DoesCollectionExistsAsync(string collectionName,
{
collectionName = NormalizeIndexName(collectionName);

var collections = this.GetCollectionsAsync(cancellationToken)
.WithCancellation(cancellationToken)
.ConfigureAwait(false);

await foreach (var item in collections)
using (await this.OpenConnectionAsync(cancellationToken).ConfigureAwait(false))
{
if (item.Equals(collectionName, StringComparison.OrdinalIgnoreCase))
using var command = this._connection.CreateCommand();

command.CommandText = """
SELECT 1
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema = @schema
AND table_name = @tableName
""";

command.Parameters.AddWithValue("@schema", this._configuration.Schema);
command.Parameters.AddWithValue("@tableName", $"{this._configuration.CollectionTableNamePrefix}_{collectionName}");

using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);

while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
return true;
return Convert.ToBoolean(dataReader.GetInt32(0));
}
}

Expand All @@ -116,14 +127,14 @@ AND table_name LIKE @tableName
""";

command.Parameters.AddWithValue("@schema", this._configuration.Schema);
command.Parameters.AddWithValue("@tableName", $"{this._configuration.MemoryTableName}_%");
command.Parameters.AddWithValue("@tableName", $"{this._configuration.CollectionTableNamePrefix}_%");

using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);

while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
yield return dataReader.GetString(dataReader.GetOrdinal("table_name"))
.Replace($"{this._configuration.MemoryTableName}_", string.Empty);
.Replace($"{this._configuration.CollectionTableNamePrefix}_", string.Empty);
}
}
}
Expand All @@ -143,8 +154,8 @@ public async Task DeleteCollectionAsync(string collectionName, CancellationToken
{
using var command = this._connection.CreateCommand();

command.CommandText = $@"DROP TABLE {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")};
DROP TABLE {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")};";
command.CommandText = $@"DROP TABLE {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")};
DROP TABLE {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")};";

await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
Expand All @@ -168,7 +179,7 @@ public async Task DeleteCollectionAsync(string collectionName, CancellationToken

command.CommandText = $@"
SELECT {queryColumns}
FROM {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}
FROM {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}
WHERE [key] = @key";

command.Parameters.AddWithValue("@key", key);
Expand Down Expand Up @@ -210,7 +221,7 @@ public async IAsyncEnumerable<SqlServerMemoryEntry> ReadBatchAsync(string collec

command.CommandText = $@"
SELECT {queryColumns}
FROM {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}
FROM {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}
WHERE [key] IN ({string.Join(",", Enumerable.Range(0, keysArray.Length).Select(c => $"@key{c}"))})";

for (int i = 0; i < keysArray.Length; i++)
Expand All @@ -236,7 +247,7 @@ public async Task DeleteAsync(string collectionName, string key, CancellationTok
{
using var command = this._connection.CreateCommand();

command.CommandText = $"DELETE FROM {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")} WHERE [key] = @key";
command.CommandText = $"DELETE FROM {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")} WHERE [key] = @key";
command.Parameters.AddWithValue("@key", key);

await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
Expand All @@ -261,7 +272,7 @@ public async Task DeleteBatchAsync(string collectionName, IEnumerable<string> ke

command.CommandText = $@"
DELETE
FROM {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}
FROM {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}
WHERE [key] IN ({string.Join(",", Enumerable.Range(0, keysArray.Length).Select(c => $"@key{c}"))})";

for (int i = 0; i < keysArray.Length; i++)
Expand Down Expand Up @@ -302,42 +313,42 @@ public async Task DeleteBatchAsync(string collectionName, IEnumerable<string> ke
[similarity] AS
(
SELECT TOP (@limit)
{this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}.[memory_id],
SUM([embedding].[vector_value] * {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}.[vector_value]) /
{this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}.[memory_id],
SUM([embedding].[vector_value] * {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}.[vector_value]) /
(
SQRT(SUM([embedding].[vector_value] * [embedding].[vector_value]))
*
SQRT(SUM({this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}.[vector_value] * {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}.[vector_value]))
SQRT(SUM({this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}.[vector_value] * {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}.[vector_value]))
) AS cosine_similarity
-- sum([embedding].[vector_value] * {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}.[vector_value]) as cosine_distance -- Optimized as per https://platform.openai.com/docs/guides/embeddings/which-distance-function-should-i-use
-- sum([embedding].[vector_value] * {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}.[vector_value]) as cosine_distance -- Optimized as per https://platform.openai.com/docs/guides/embeddings/which-distance-function-should-i-use
FROM
[embedding]
INNER JOIN
{this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")} ON [embedding].vector_value_id = {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}.vector_value_id
{this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")} ON [embedding].vector_value_id = {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}.vector_value_id
GROUP BY
{this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}.[memory_id]
{this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}.[memory_id]
ORDER BY
cosine_similarity DESC
)
SELECT
{this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[id],
{this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[key],
{this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[metadata],
{this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[timestamp],
{this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[embedding],
{this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[id],
{this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[key],
{this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[metadata],
{this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[timestamp],
{this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[embedding],
(
SELECT
[vector_value]
FROM {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}
WHERE {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[id] = {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")}.[memory_id]
FROM {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}
WHERE {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[id] = {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")}.[memory_id]
ORDER BY vector_value_id
FOR JSON AUTO
) AS [embeddings],
[similarity].[cosine_similarity]
FROM
[similarity]
INNER JOIN
{this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")} ON [similarity].[memory_id] = {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[id]
{this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")} ON [similarity].[memory_id] = {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[id]
WHERE [cosine_similarity] >= @min_relevance_score
ORDER BY [cosine_similarity] desc";

Expand Down Expand Up @@ -370,25 +381,25 @@ public async Task UpsertAsync(string collectionName,
using var command = this._connection.CreateCommand();

command.CommandText = $@"
MERGE INTO {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}
MERGE INTO {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}
USING (SELECT @key) as [src]([key])
ON {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[key] = [src].[key]
ON {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[key] = [src].[key]
WHEN MATCHED THEN
UPDATE SET metadata = @metadata, embedding = @embedding, timestamp = @timestamp
WHEN NOT MATCHED THEN
INSERT ([id], [key], [metadata], [timestamp], [embedding])
VALUES (NEWID(), @key, @metadata, @timestamp, @embedding);
MERGE {this.GetFullTableName($"{this._configuration.EmbeddingsTableName}_{collectionName}")} AS [tgt]
MERGE {this.GetFullTableName($"{this._configuration.EmbeddingsTableNamePrefix}_{collectionName}")} AS [tgt]
USING (
SELECT
{this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[id],
{this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[id],
cast([vector].[key] AS INT) AS [vector_value_id],
cast([vector].[value] AS FLOAT) AS [vector_value]
FROM {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}
FROM {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}
CROSS APPLY
openjson(@embedding) [vector]
WHERE {this.GetFullTableName($"{this._configuration.MemoryTableName}_{collectionName}")}.[key] = @key
WHERE {this.GetFullTableName($"{this._configuration.CollectionTableNamePrefix}_{collectionName}")}.[key] = @key
) AS [src]
ON [tgt].[memory_id] = [src].[id] AND [tgt].[vector_value_id] = [src].[vector_value_id]
WHEN MATCHED THEN
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ public class SqlServerConfig
/// <summary>
/// The default SQL Server memories table name.
/// </summary>
internal const string DefaultMemoryTableName = "SKMemories";
internal const string DefaultMemoryTableNamePrefix = "SKMemories";

/// <summary>
/// The default SQL Server embeddings table name.
/// </summary>
internal const string DefaultEmbeddingsTableName = "SKEmbeddings";
internal const string DefaultEmbeddingsTableNamePrefix = "SKEmbeddings";

/// <summary>
/// The default schema used by the SQL Server memory store.
Expand All @@ -33,12 +33,14 @@ public class SqlServerConfig
public string Schema { get; set; } = DefaultSchema;

/// <summary>
/// The SQL Server memories table name.
/// The SQL Server memories table name prefix.
/// When creating a collection, real table name will be '{CollectionTableNamePrefix}_{CollectionName}'.
/// </summary>
public string MemoryTableName { get; set; } = DefaultMemoryTableName;
public string CollectionTableNamePrefix { get; set; } = DefaultMemoryTableNamePrefix;

/// <summary>
/// The SQL Server embeddings table name.
/// The SQL Server embeddings table name prefix.
/// When creating a collection, real table name will be '{EmbeddingsTableNamePrefix}_{CollectionName}'.
/// </summary>
public string EmbeddingsTableName { get; set; } = DefaultEmbeddingsTableName;
public string EmbeddingsTableNamePrefix { get; set; } = DefaultEmbeddingsTableNamePrefix;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel.Connectors.SqlServer.Classic;
using Microsoft.SemanticKernel.Memory;
using MongoDB.Driver;
using SharpYaml.Schemas;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.SqlServer;
Expand All @@ -21,9 +19,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.SqlServer;
public class ClassicSqlServerMemoryStoreTests : IAsyncLifetime
{
// If null, all tests will be enabled
//private const string? SkipReason = "Requires SqlServer server up and running";

private const string? SkipReason = null;
private const string? SkipReason = "Requires SqlServer server up and running";

public async Task InitializeAsync()
{
Expand All @@ -46,7 +42,7 @@ public async Task InitializeAsync()
this._config = new SqlServerConfig
{
Schema = "sk_it",
MemoryTableName = "SKMemories",
CollectionTableNamePrefix = "SKMemories",
};

await this.CleanupDatabaseAsync();
Expand Down

0 comments on commit c508659

Please sign in to comment.