From 9fcc1b647657fbe491e35c1d8f3f54a150884cd5 Mon Sep 17 00:00:00 2001 From: Badrish Chandramouli Date: Fri, 28 Jun 2024 17:23:46 -0700 Subject: [PATCH] RESP3 versions of ZRANGE and HGETALL (#503) * RESP3 versions of ZRANGE and HGETALL --- libs/common/RespWriteUtils.cs | 60 +++++++++++++ libs/server/Objects/Hash/HashObject.cs | 8 +- libs/server/Objects/Hash/HashObjectImpl.cs | 15 +++- .../Objects/SortedSet/SortedSetObjectImpl.cs | 86 +++++++++---------- libs/server/Resp/Objects/HashCommands.cs | 1 + libs/server/Resp/Objects/SortedSetCommands.cs | 1 + libs/server/Resp/RespServerSession.cs | 15 +++- .../Storage/Session/ObjectStore/Common.cs | 18 +++- 8 files changed, 149 insertions(+), 55 deletions(-) diff --git a/libs/common/RespWriteUtils.cs b/libs/common/RespWriteUtils.cs index 31fc06fd07..3650823d6c 100644 --- a/libs/common/RespWriteUtils.cs +++ b/libs/common/RespWriteUtils.cs @@ -488,6 +488,37 @@ public static bool TryWriteDoubleBulkString(double value, ref byte* curr, byte* return true; } + /// + /// Try to write a double-precision floating-point as bulk string. + /// + /// if the could be written to ; otherwise. + [SkipLocalsInit] + public static bool TryWriteDoubleNumeric(double value, ref byte* curr, byte* end) + { + if (double.IsNaN(value)) + { + return TryWriteNaN_Numeric(value, ref curr, end); + } + else if (double.IsInfinity(value)) + { + return TryWriteInfinity_Numeric(value, ref curr, end); + } + + Span buffer = stackalloc byte[32]; + if (!Utf8Formatter.TryFormat(value, buffer, out var bytesWritten, format: default)) + return false; + + var itemDigits = NumUtils.NumDigits(bytesWritten); + int totalLen = 1 + bytesWritten + 2; + if (totalLen > (int)(end - curr)) + return false; + + *curr++ = (byte)','; + buffer.Slice(0, bytesWritten).CopyTo(new Span(curr, bytesWritten)); + curr += bytesWritten; + WriteNewline(ref curr); + return true; + } [MethodImpl(MethodImplOptions.NoInlining)] private static bool TryWriteInfinity(double value, ref byte* curr, byte* end) @@ -508,6 +539,25 @@ private static bool TryWriteInfinity(double value, ref byte* curr, byte* end) return true; } + [MethodImpl(MethodImplOptions.NoInlining)] + private static bool TryWriteInfinity_Numeric(double value, ref byte* curr, byte* end) + { + var buffer = new Span(curr, (int)(end - curr)); + if (double.IsPositiveInfinity(value)) + { + if (!",+inf\r\n"u8.TryCopyTo(buffer)) + return false; + } + else + { + if (!",-inf\r\n"u8.TryCopyTo(buffer)) + return false; + } + + curr += 7; + return true; + } + [MethodImpl(MethodImplOptions.NoInlining)] private static bool TryWriteNaN(double value, ref byte* curr, byte* end) { @@ -518,6 +568,16 @@ private static bool TryWriteNaN(double value, ref byte* curr, byte* end) return true; } + [MethodImpl(MethodImplOptions.NoInlining)] + private static bool TryWriteNaN_Numeric(double value, ref byte* curr, byte* end) + { + var buffer = new Span(curr, (int)(end - curr)); + if (!",nan\r\n"u8.TryCopyTo(buffer)) + return false; + curr += 6; + return true; + } + /// /// Create header for *Scan output /// *scan commands have an array of two elements diff --git a/libs/server/Objects/Hash/HashObject.cs b/libs/server/Objects/Hash/HashObject.cs index db24c96f5c..cf712f603e 100644 --- a/libs/server/Objects/Hash/HashObject.cs +++ b/libs/server/Objects/Hash/HashObject.cs @@ -115,8 +115,8 @@ public override unsafe bool Operate(ref SpanByte input, ref SpanByteAndMemory ou fixed (byte* _input = input.AsSpan()) fixed (byte* _output = output.SpanByte.AsSpan()) { - var header = (RespInputHeader*)_input; - if (header->type != GarnetObjectType.Hash) + var header = (ObjectInputHeader*)_input; + if (header->header.type != GarnetObjectType.Hash) { //Indicates when there is an incorrect type output.Length = 0; @@ -125,7 +125,7 @@ public override unsafe bool Operate(ref SpanByte input, ref SpanByteAndMemory ou } var previousSize = this.Size; - switch (header->HashOp) + switch (header->header.HashOp) { case HashOperation.HSET: HashSet(_input, input.Length, _output); @@ -140,7 +140,7 @@ public override unsafe bool Operate(ref SpanByte input, ref SpanByteAndMemory ou HashMultipleGet(_input, input.Length, ref output); break; case HashOperation.HGETALL: - HashGetAll(ref output); + HashGetAll(respProtocolVersion: header->arg1, ref output); break; case HashOperation.HDEL: HashDelete(_input, input.Length, _output); diff --git a/libs/server/Objects/Hash/HashObjectImpl.cs b/libs/server/Objects/Hash/HashObjectImpl.cs index 6619a83e3e..87a7dd65ab 100644 --- a/libs/server/Objects/Hash/HashObjectImpl.cs +++ b/libs/server/Objects/Hash/HashObjectImpl.cs @@ -7,7 +7,6 @@ using System.Diagnostics; using System.Globalization; using System.Linq; -using System.Security.Cryptography; using System.Text; using Garnet.common; using Tsavorite.core; @@ -123,7 +122,7 @@ private void HashMultipleGet(byte* input, int length, ref SpanByteAndMemory outp } } - private void HashGetAll(ref SpanByteAndMemory output) + private void HashGetAll(int respProtocolVersion, ref SpanByteAndMemory output) { var isMemory = false; MemoryHandle ptrHandle = default; @@ -135,8 +134,16 @@ private void HashGetAll(ref SpanByteAndMemory output) ObjectOutputHeader _output = default; try { - while (!RespWriteUtils.WriteArrayLength(hash.Count * 2, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + if (respProtocolVersion < 3) + { + while (!RespWriteUtils.WriteArrayLength(hash.Count * 2, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + } + else + { + while (!RespWriteUtils.WriteMapLength(hash.Count, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + } foreach (var item in hash) { diff --git a/libs/server/Objects/SortedSet/SortedSetObjectImpl.cs b/libs/server/Objects/SortedSet/SortedSetObjectImpl.cs index 3dedf296f2..39db516e25 100644 --- a/libs/server/Objects/SortedSet/SortedSetObjectImpl.cs +++ b/libs/server/Objects/SortedSet/SortedSetObjectImpl.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using System.Security.Cryptography; using System.Text; using Garnet.common; using Tsavorite.core; @@ -323,6 +322,7 @@ private void SortedSetRange(byte* input, int length, ref SpanByteAndMemory outpu //ZRANGEBYSCORE key min max [WITHSCORES] [LIMIT offset count] var _input = (ObjectInputHeader*)input; int count = _input->arg1; + int respProtocolVersion = _input->arg2; byte* input_startptr = input + sizeof(ObjectInputHeader); byte* input_currptr = input_startptr; @@ -410,23 +410,9 @@ private void SortedSetRange(byte* input, int length, ref SpanByteAndMemory outpu if (options.ByScore) { - var scoredElements = GetElementsInRangeByScore(minValue, maxValue, minExclusive, maxExclusive, options.WithScores, options.Reverse, options.ValidLimit, false, options.Limit); - // write the size of the array reply - while (!RespWriteUtils.WriteArrayLength(options.WithScores ? scoredElements.Count * 2 : scoredElements.Count, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - - foreach (var (score, element) in scoredElements) - { - while (!RespWriteUtils.WriteBulkString(element, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - if (options.WithScores) - { - while (!RespWriteUtils.TryWriteDoubleBulkString(score, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - } - } + WriteSortedSetResult(options.WithScores, scoredElements.Count, respProtocolVersion, scoredElements, ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); countDone = _input->arg1; } else @@ -469,20 +455,7 @@ private void SortedSetRange(byte* input, int length, ref SpanByteAndMemory outpu var iterator = options.Reverse ? sortedSet.Reverse() : sortedSet; iterator = iterator.Skip(minIndex).Take(n); - // write the size of the array reply - while (!RespWriteUtils.WriteArrayLength(options.WithScores ? n * 2 : n, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - - foreach (var (score, element) in iterator) - { - while (!RespWriteUtils.WriteBulkString(element, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - if (options.WithScores) - { - while (!RespWriteUtils.TryWriteDoubleBulkString(score, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - } - } + WriteSortedSetResult(options.WithScores, n, respProtocolVersion, iterator, ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); countDone = _input->arg1; } } @@ -502,20 +475,7 @@ private void SortedSetRange(byte* input, int length, ref SpanByteAndMemory outpu } else { - //write the size of the array reply - while (!RespWriteUtils.WriteArrayLength(options.WithScores ? elementsInLex.Count * 2 : elementsInLex.Count, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - - foreach (var (score, element) in elementsInLex) - { - while (!RespWriteUtils.WriteBulkString(element, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - if (options.WithScores) - { - while (!RespWriteUtils.TryWriteDoubleBulkString(score, ref curr, end)) - ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); - } - } + WriteSortedSetResult(options.WithScores, elementsInLex.Count, respProtocolVersion, elementsInLex, ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); countDone = _input->arg1; } } @@ -531,6 +491,44 @@ private void SortedSetRange(byte* input, int length, ref SpanByteAndMemory outpu } } + void WriteSortedSetResult(bool withScores, int count, int respProtocolVersion, IEnumerable<(double, byte[])> iterator, ref SpanByteAndMemory output, ref bool isMemory, ref byte* ptr, ref MemoryHandle ptrHandle, ref byte* curr, ref byte* end) + { + if (withScores && respProtocolVersion >= 3) + { + // write the size of the array reply + while (!RespWriteUtils.WriteArrayLength(count, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + + foreach (var (score, element) in iterator) + { + while (!RespWriteUtils.WriteArrayLength(2, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + + while (!RespWriteUtils.WriteBulkString(element, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + while (!RespWriteUtils.TryWriteDoubleNumeric(score, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + } + } + else + { + // write the size of the array reply + while (!RespWriteUtils.WriteArrayLength(withScores ? count * 2 : count, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + + foreach (var (score, element) in iterator) + { + while (!RespWriteUtils.WriteBulkString(element, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + if (withScores) + { + while (!RespWriteUtils.TryWriteDoubleBulkString(score, ref curr, end)) + ObjectUtils.ReallocateOutput(ref output, ref isMemory, ref ptr, ref ptrHandle, ref curr, ref end); + } + } + } + } + private void SortedSetRangeByScore(byte* input, int length, ref SpanByteAndMemory output) { SortedSetRange(input, length, ref output); diff --git a/libs/server/Resp/Objects/HashCommands.cs b/libs/server/Resp/Objects/HashCommands.cs index a97ee7046d..41e6e7c3ec 100644 --- a/libs/server/Resp/Objects/HashCommands.cs +++ b/libs/server/Resp/Objects/HashCommands.cs @@ -197,6 +197,7 @@ private bool HashGetAll(RespCommand command, int count, byte* ptr, r inputPtr->header.type = GarnetObjectType.Hash; inputPtr->header.flags = 0; inputPtr->header.HashOp = HashOperation.HGETALL; + inputPtr->arg1 = respProtocolVersion; // Prepare GarnetObjectStore output var outputFooter = new GarnetObjectStoreOutput { spanByteAndMemory = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)) }; diff --git a/libs/server/Resp/Objects/SortedSetCommands.cs b/libs/server/Resp/Objects/SortedSetCommands.cs index f8709ddcf4..533de7d2ec 100644 --- a/libs/server/Resp/Objects/SortedSetCommands.cs +++ b/libs/server/Resp/Objects/SortedSetCommands.cs @@ -267,6 +267,7 @@ private unsafe bool SortedSetRange(RespCommand command, int count, b inputPtr->header.flags = 0; inputPtr->header.SortedSetOp = op; inputPtr->arg1 = count - 1; + inputPtr->arg2 = respProtocolVersion; var outputFooter = new GarnetObjectStoreOutput { spanByteAndMemory = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)) }; diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index 6652f054e1..059de67dfa 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -270,14 +270,25 @@ public override int TryConsumeMessages(byte* reqBuffer, int bytesReceived) // Send message and dispose the network sender to end the session if (dcurr > networkSender.GetResponseObjectHead()) Send(networkSender.GetResponseObjectHead()); - networkSender.Dispose(); + + // The session is no longer usable, dispose it + networkSender.DisposeNetworkSender(true); } catch (GarnetException ex) { sessionMetrics?.incr_total_number_resp_server_session_exceptions(1); logger?.Log(ex.LogLevel, ex, "ProcessMessages threw a GarnetException:"); + + // Forward Garnet error as RESP error + while (!RespWriteUtils.WriteError($"ERR Garnet Exception: {ex.Message}", ref dcurr, dend)) + SendAndReset(); + + // Send message and dispose the network sender to end the session + if (dcurr > networkSender.GetResponseObjectHead()) + Send(networkSender.GetResponseObjectHead()); + // The session is no longer usable, dispose it - networkSender.Dispose(); + networkSender.DisposeNetworkSender(true); } catch (Exception ex) { diff --git a/libs/server/Storage/Session/ObjectStore/Common.cs b/libs/server/Storage/Session/ObjectStore/Common.cs index bca5c80adb..43f0addf33 100644 --- a/libs/server/Storage/Session/ObjectStore/Common.cs +++ b/libs/server/Storage/Session/ObjectStore/Common.cs @@ -16,6 +16,9 @@ sealed partial class StorageSession : IDisposable unsafe GarnetStatus RMWObjectStoreOperation(byte[] key, ArgSlice input, out ObjectOutputHeader output, ref TObjectContext objectStoreContext) where TObjectContext : ITsavoriteContext { + if (objectStoreContext.Session is null) + StorageSession.ThrowObjectStoreUninitializedException(); + var _input = input.SpanByte; output = new(); @@ -48,6 +51,9 @@ unsafe GarnetStatus RMWObjectStoreOperation(byte[] key, ArgSlice GarnetStatus RMWObjectStoreOperationWithOutput(byte[] key, ArgSlice input, ref TObjectContext objectStoreContext, ref GarnetObjectStoreOutput outputFooter) where TObjectContext : ITsavoriteContext { + if (objectStoreContext.Session is null) + StorageSession.ThrowObjectStoreUninitializedException(); + var _input = input.SpanByte; // Perform RMW on object store @@ -75,6 +81,9 @@ GarnetStatus RMWObjectStoreOperationWithOutput(byte[] key, ArgSl GarnetStatus ReadObjectStoreOperationWithOutput(byte[] key, ArgSlice input, ref TObjectContext objectStoreContext, ref GarnetObjectStoreOutput outputFooter) where TObjectContext : ITsavoriteContext { + if (objectStoreContext.Session is null) + StorageSession.ThrowObjectStoreUninitializedException(); + var _input = input.SpanByte; // Perform read on object store @@ -221,8 +230,11 @@ unsafe ArgSlice ProcessRespSingleTokenOutput(GarnetObjectStoreOutput outputFoote /// /// unsafe GarnetStatus ReadObjectStoreOperation(byte[] key, ArgSlice input, out ObjectOutputHeader output, ref TObjectContext objectStoreContext) - where TObjectContext : ITsavoriteContext + where TObjectContext : ITsavoriteContext { + if (objectStoreContext.Session is null) + StorageSession.ThrowObjectStoreUninitializedException(); + var _input = input.SpanByte; output = new(); @@ -256,6 +268,10 @@ public GarnetStatus ObjectScan(byte[] key, ArgSlice input, ref G where TObjectContext : ITsavoriteContext => ReadObjectStoreOperationWithOutput(key, input, ref objectStoreContext, ref outputFooter); + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowObjectStoreUninitializedException() + => throw new GarnetException("Object store is disabled"); + #endregion } } \ No newline at end of file