diff --git a/csharp/lib/AsyncClient.cs b/csharp/lib/AsyncClient.cs index 83e3d4c39b..38ee7984f9 100644 --- a/csharp/lib/AsyncClient.cs +++ b/csharp/lib/AsyncClient.cs @@ -2,6 +2,7 @@ * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ +using System.Buffers; using System.Runtime.InteropServices; namespace Glide; @@ -22,18 +23,35 @@ public AsyncClient(string host, UInt32 port, bool useTLS) } } - public async Task SetAsync(string key, string value) + private async Task command(IntPtr[] args, int argsCount, RequestType requestType) { - var message = messageContainer.GetMessageForCall(key, value); - SetFfi(clientPointer, (ulong)message.Index, message.KeyPtr, message.ValuePtr); - await message; + // We need to pin the array in place, in order to ensure that the GC doesn't move it while the operation is running. + GCHandle pinnedArray = GCHandle.Alloc(args, GCHandleType.Pinned); + IntPtr pointer = pinnedArray.AddrOfPinnedObject(); + var message = messageContainer.GetMessageForCall(args, argsCount); + CommandFfi(clientPointer, (ulong)message.Index, (int)requestType, pointer, (uint)argsCount); + var result = await message; + pinnedArray.Free(); + return result; + } + + public async Task SetAsync(string key, string value) + { + var args = this.arrayPool.Rent(2); + args[0] = Marshal.StringToHGlobalAnsi(key); + args[1] = Marshal.StringToHGlobalAnsi(value); + var result = await command(args, 2, RequestType.SetString); + this.arrayPool.Return(args); + return result; } public async Task GetAsync(string key) { - var message = messageContainer.GetMessageForCall(key, null); - GetFfi(clientPointer, (ulong)message.Index, message.KeyPtr); - return await message; + var args = this.arrayPool.Rent(1); + args[0] = Marshal.StringToHGlobalAnsi(key); + var result = await command(args, 1, RequestType.GetString); + this.arrayPool.Return(args); + return result; } public void Dispose() @@ -89,6 +107,7 @@ private void FailureCallback(ulong index) private IntPtr clientPointer; private readonly MessageContainer messageContainer = new(); + private readonly ArrayPool arrayPool = ArrayPool.Shared; #endregion private fields @@ -96,11 +115,8 @@ private void FailureCallback(ulong index) private delegate void StringAction(ulong index, IntPtr str); private delegate void FailureAction(ulong index); - [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "get")] - private static extern void GetFfi(IntPtr client, ulong index, IntPtr key); - - [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "set")] - private static extern void SetFfi(IntPtr client, ulong index, IntPtr key, IntPtr value); + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "command")] + private static extern void CommandFfi(IntPtr client, ulong index, Int32 requestType, IntPtr args, UInt32 argCount); private delegate void IntAction(IntPtr arg); [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "create_client")] @@ -110,4 +126,114 @@ private void FailureCallback(ulong index) private static extern void CloseClientFfi(IntPtr client); #endregion + + #region RequestType + + // TODO: generate this with a bindings generator + private enum RequestType + { + InvalidRequest = 0, + CustomCommand = 1, + GetString = 2, + SetString = 3, + Ping = 4, + Info = 5, + Del = 6, + Select = 7, + ConfigGet = 8, + ConfigSet = 9, + ConfigResetStat = 10, + ConfigRewrite = 11, + ClientGetName = 12, + ClientGetRedir = 13, + ClientId = 14, + ClientInfo = 15, + ClientKill = 16, + ClientList = 17, + ClientNoEvict = 18, + ClientNoTouch = 19, + ClientPause = 20, + ClientReply = 21, + ClientSetInfo = 22, + ClientSetName = 23, + ClientUnblock = 24, + ClientUnpause = 25, + Expire = 26, + HashSet = 27, + HashGet = 28, + HashDel = 29, + HashExists = 30, + MGet = 31, + MSet = 32, + Incr = 33, + IncrBy = 34, + Decr = 35, + IncrByFloat = 36, + DecrBy = 37, + HashGetAll = 38, + HashMSet = 39, + HashMGet = 40, + HashIncrBy = 41, + HashIncrByFloat = 42, + LPush = 43, + LPop = 44, + RPush = 45, + RPop = 46, + LLen = 47, + LRem = 48, + LRange = 49, + LTrim = 50, + SAdd = 51, + SRem = 52, + SMembers = 53, + SCard = 54, + PExpireAt = 55, + PExpire = 56, + ExpireAt = 57, + Exists = 58, + Unlink = 59, + TTL = 60, + Zadd = 61, + Zrem = 62, + Zrange = 63, + Zcard = 64, + Zcount = 65, + ZIncrBy = 66, + ZScore = 67, + Type = 68, + HLen = 69, + Echo = 70, + ZPopMin = 71, + Strlen = 72, + Lindex = 73, + ZPopMax = 74, + XRead = 75, + XAdd = 76, + XReadGroup = 77, + XAck = 78, + XTrim = 79, + XGroupCreate = 80, + XGroupDestroy = 81, + HSetNX = 82, + SIsMember = 83, + Hvals = 84, + PTTL = 85, + ZRemRangeByRank = 86, + Persist = 87, + ZRemRangeByScore = 88, + Time = 89, + Zrank = 90, + Rename = 91, + DBSize = 92, + Brpop = 93, + Hkeys = 94, + PfAdd = 96, + PfCount = 97, + PfMerge = 98, + Blpop = 100, + RPushX = 102, + LPushX = 103, + } + + #endregion } diff --git a/csharp/lib/Message.cs b/csharp/lib/Message.cs index c0d4c7f07b..5ecb994d0b 100644 --- a/csharp/lib/Message.cs +++ b/csharp/lib/Message.cs @@ -17,11 +17,10 @@ internal class Message : INotifyCompletion /// know how to find the message and set its result. public int Index { get; } - /// The pointer to the unmanaged memory that contains the operation's key. - public IntPtr KeyPtr { get; private set; } - - /// The pointer to the unmanaged memory that contains the operation's key. - public IntPtr ValuePtr { get; private set; } + /// The array holding the pointers to the unmanaged memory that contains the operation's arguments. + public IntPtr[]? args { get; private set; } + // We need to save the args count, because sometimes we get arrays that are larger than they need to be. We can't rely on `this.args.Length`, due to it coming from an array pool. + private int argsCount; private readonly MessageContainer container; public Message(int index, MessageContainer container) @@ -84,30 +83,29 @@ private void CheckRaceAndCallContinuation() /// This returns a task that will complete once SetException / SetResult are called, /// and ensures that the internal state of the message is set-up before the task is created, /// and cleaned once it is complete. - public void StartTask(string? key, string? value, object client) + public void SetupTask(IntPtr[] args, int argsCount, object client) { continuation = null; this.completionState = COMPLETION_STAGE_STARTED; this.result = default(T); this.exception = null; this.client = client; - this.KeyPtr = key is null ? IntPtr.Zero : Marshal.StringToHGlobalAnsi(key); - this.ValuePtr = value is null ? IntPtr.Zero : Marshal.StringToHGlobalAnsi(value); + this.args = args; + this.argsCount = argsCount; } // This function isn't thread-safe. Access to it should be from a single thread, and only once per operation. // For the sake of performance, this responsibility is on the caller, and the function doesn't contain any safety measures. private void FreePointers() { - if (KeyPtr != IntPtr.Zero) - { - Marshal.FreeHGlobal(KeyPtr); - KeyPtr = IntPtr.Zero; - } - if (ValuePtr != IntPtr.Zero) + if (this.args is not null) { - Marshal.FreeHGlobal(ValuePtr); - ValuePtr = IntPtr.Zero; + for (var i = 0; i < this.argsCount; i++) + { + Marshal.FreeHGlobal(this.args[i]); + } + this.args = null; + this.argsCount = 0; } client = null; } diff --git a/csharp/lib/MessageContainer.cs b/csharp/lib/MessageContainer.cs index faa1b5a277..14d4267723 100644 --- a/csharp/lib/MessageContainer.cs +++ b/csharp/lib/MessageContainer.cs @@ -11,10 +11,10 @@ internal class MessageContainer { internal Message GetMessage(int index) => messages[index]; - internal Message GetMessageForCall(string? key, string? value) + internal Message GetMessageForCall(IntPtr[] args, int argsCount) { var message = GetFreeMessage(); - message.StartTask(key, value, this); + message.SetupTask(args, argsCount, this); return message; } diff --git a/csharp/lib/src/lib.rs b/csharp/lib/src/lib.rs index 8baa6d0155..fce015a376 100644 --- a/csharp/lib/src/lib.rs +++ b/csharp/lib/src/lib.rs @@ -3,7 +3,8 @@ */ use glide_core::client; use glide_core::client::Client as GlideClient; -use redis::{Cmd, FromRedisValue, RedisResult}; +use glide_core::request_type::RequestType; +use redis::{FromRedisValue, RedisResult}; use std::{ ffi::{c_void, CStr, CString}, os::raw::c_char, @@ -91,61 +92,43 @@ pub extern "C" fn close_client(client_ptr: *const c_void) { /// Expects that key and value will be kept valid until the callback is called. #[no_mangle] -pub extern "C" fn set( +pub extern "C" fn command( client_ptr: *const c_void, callback_index: usize, - key: *const c_char, - value: *const c_char, + request_type: RequestType, + args: *const *mut c_char, + arg_count: u32, ) { let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) }; - // The safety of this needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. - let ptr_address = client_ptr as usize; - - let key_cstring = unsafe { CStr::from_ptr(key as *mut c_char) }; - let value_cstring = unsafe { CStr::from_ptr(value as *mut c_char) }; - let mut client_clone = client.client.clone(); - client.runtime.spawn(async move { - let key_bytes = key_cstring.to_bytes(); - let value_bytes = value_cstring.to_bytes(); - let mut cmd = Cmd::new(); - cmd.arg("SET").arg(key_bytes).arg(value_bytes); - let result = client_clone.send_command(&cmd, None).await; - unsafe { - let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); - match result { - Ok(_) => (client.success_callback)(callback_index, std::ptr::null()), // TODO - should return "OK" string. - Err(_) => (client.failure_callback)(callback_index), // TODO - report errors - }; - } - }); -} -/// Expects that key will be kept valid until the callback is called. If the callback is called with a string pointer, the pointer must -/// be used synchronously, because the string will be dropped after the callback. -#[no_mangle] -pub extern "C" fn get(client_ptr: *const c_void, callback_index: usize, key: *const c_char) { - let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) }; - // The safety of this needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. + // The safety of these needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. let ptr_address = client_ptr as usize; + let args_address = args as usize; - let key_cstring = unsafe { CStr::from_ptr(key as *mut c_char) }; let mut client_clone = client.client.clone(); client.runtime.spawn(async move { - let key_bytes = key_cstring.to_bytes(); - let mut cmd = Cmd::new(); - cmd.arg("GET").arg(key_bytes); - let result = client_clone.send_command(&cmd, None).await; - let client = unsafe { Box::leak(Box::from_raw(ptr_address as *mut Client)) }; - let value = match result { - Ok(value) => value, - Err(_) => { - unsafe { (client.failure_callback)(callback_index) }; // TODO - report errors, + let Some(mut cmd) = request_type.get_command() else { + unsafe { + let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); + (client.failure_callback)(callback_index); // TODO - report errors return; } }; - let result = Option::::from_owned_redis_value(value); + let args_slice = unsafe { + std::slice::from_raw_parts(args_address as *const *mut c_char, arg_count as usize) + }; + for arg in args_slice { + let c_str = unsafe { CStr::from_ptr(*arg as *mut c_char) }; + cmd.arg(c_str.to_bytes()); + } + + let result = client_clone + .send_command(&cmd, None) + .await + .and_then(Option::::from_owned_redis_value); unsafe { + let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); match result { Ok(None) => (client.success_callback)(callback_index, std::ptr::null()), Ok(Some(c_str)) => (client.success_callback)(callback_index, c_str.as_ptr()), diff --git a/csharp/tests/Integration/GetAndSet.cs b/csharp/tests/Integration/GetAndSet.cs index 98ac38beaa..17362e3cc1 100644 --- a/csharp/tests/Integration/GetAndSet.cs +++ b/csharp/tests/Integration/GetAndSet.cs @@ -12,13 +12,19 @@ namespace tests.Integration; public class GetAndSet { + private async Task GetAndSetValues(AsyncClient client, string key, string value) + { + var setResult = await client.SetAsync(key, value); + Assert.That(setResult, Is.EqualTo("OK")); + var result = await client.GetAsync(key); + Assert.That(result, Is.EqualTo(value)); + } + private async Task GetAndSetRandomValues(AsyncClient client) { var key = Guid.NewGuid().ToString(); var value = Guid.NewGuid().ToString(); - await client.SetAsync(key, value); - var result = await client.GetAsync(key); - Assert.That(result, Is.EqualTo(value)); + await GetAndSetValues(client, key, value); } [Test] @@ -37,9 +43,7 @@ public async Task GetAndSetCanHandleNonASCIIUnicode() { var key = Guid.NewGuid().ToString(); var value = "שלום hello 汉字"; - await client.SetAsync(key, value); - var result = await client.GetAsync(key); - Assert.That(result, Is.EqualTo(value)); + await GetAndSetValues(client, key, value); } } @@ -60,9 +64,7 @@ public async Task GetReturnsEmptyString() { var key = Guid.NewGuid().ToString(); var value = ""; - await client.SetAsync(key, value); - var result = await client.GetAsync(key); - Assert.That(result, Is.EqualTo(value)); + await GetAndSetValues(client, key, value); } } @@ -81,9 +83,7 @@ public async Task HandleVeryLargeInput() { value += value; } - await client.SetAsync(key, value); - var result = await client.GetAsync(key); - Assert.That(result, Is.EqualTo(value)); + await GetAndSetValues(client, key, value); } } diff --git a/glide-core/src/lib.rs b/glide-core/src/lib.rs index bd194f008f..f904928be1 100644 --- a/glide-core/src/lib.rs +++ b/glide-core/src/lib.rs @@ -15,3 +15,4 @@ pub use socket_listener::*; pub mod errors; pub mod scripts_container; pub use client::ConnectionRequest; +pub mod request_type; diff --git a/glide-core/src/request_type.rs b/glide-core/src/request_type.rs new file mode 100644 index 0000000000..46a49fd20f --- /dev/null +++ b/glide-core/src/request_type.rs @@ -0,0 +1,337 @@ +/** + * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 + */ +use redis::{cmd, Cmd}; + +#[cfg(feature = "socket-layer")] +use crate::redis_request::RequestType as ProtobufRequestType; + +#[repr(C)] +#[derive(Debug)] +pub enum RequestType { + InvalidRequest = 0, + CustomCommand = 1, + GetString = 2, + SetString = 3, + Ping = 4, + Info = 5, + Del = 6, + Select = 7, + ConfigGet = 8, + ConfigSet = 9, + ConfigResetStat = 10, + ConfigRewrite = 11, + ClientGetName = 12, + ClientGetRedir = 13, + ClientId = 14, + ClientInfo = 15, + ClientKill = 16, + ClientList = 17, + ClientNoEvict = 18, + ClientNoTouch = 19, + ClientPause = 20, + ClientReply = 21, + ClientSetInfo = 22, + ClientSetName = 23, + ClientUnblock = 24, + ClientUnpause = 25, + Expire = 26, + HashSet = 27, + HashGet = 28, + HashDel = 29, + HashExists = 30, + MGet = 31, + MSet = 32, + Incr = 33, + IncrBy = 34, + Decr = 35, + IncrByFloat = 36, + DecrBy = 37, + HashGetAll = 38, + HashMSet = 39, + HashMGet = 40, + HashIncrBy = 41, + HashIncrByFloat = 42, + LPush = 43, + LPop = 44, + RPush = 45, + RPop = 46, + LLen = 47, + LRem = 48, + LRange = 49, + LTrim = 50, + SAdd = 51, + SRem = 52, + SMembers = 53, + SCard = 54, + PExpireAt = 55, + PExpire = 56, + ExpireAt = 57, + Exists = 58, + Unlink = 59, + TTL = 60, + Zadd = 61, + Zrem = 62, + Zrange = 63, + Zcard = 64, + Zcount = 65, + ZIncrBy = 66, + ZScore = 67, + Type = 68, + HLen = 69, + Echo = 70, + ZPopMin = 71, + Strlen = 72, + Lindex = 73, + ZPopMax = 74, + XRead = 75, + XAdd = 76, + XReadGroup = 77, + XAck = 78, + XTrim = 79, + XGroupCreate = 80, + XGroupDestroy = 81, + HSetNX = 82, + SIsMember = 83, + Hvals = 84, + PTTL = 85, + ZRemRangeByRank = 86, + Persist = 87, + ZRemRangeByScore = 88, + Time = 89, + Zrank = 90, + Rename = 91, + DBSize = 92, + Brpop = 93, + Hkeys = 94, + PfAdd = 96, + PfCount = 97, + PfMerge = 98, + Blpop = 100, + RPushX = 102, + LPushX = 103, +} + +fn get_two_word_command(first: &str, second: &str) -> Cmd { + let mut cmd = cmd(first); + cmd.arg(second); + cmd +} + +#[cfg(feature = "socket-layer")] +impl From<::protobuf::EnumOrUnknown> for RequestType { + fn from(value: ::protobuf::EnumOrUnknown) -> Self { + match value.enum_value_or(ProtobufRequestType::InvalidRequest) { + ProtobufRequestType::InvalidRequest => RequestType::InvalidRequest, + ProtobufRequestType::CustomCommand => RequestType::CustomCommand, + ProtobufRequestType::GetString => RequestType::GetString, + ProtobufRequestType::SetString => RequestType::SetString, + ProtobufRequestType::Ping => RequestType::Ping, + ProtobufRequestType::Info => RequestType::Info, + ProtobufRequestType::Del => RequestType::Del, + ProtobufRequestType::Select => RequestType::Select, + ProtobufRequestType::ConfigGet => RequestType::ConfigGet, + ProtobufRequestType::ConfigSet => RequestType::ConfigSet, + ProtobufRequestType::ConfigResetStat => RequestType::ConfigResetStat, + ProtobufRequestType::ConfigRewrite => RequestType::ConfigRewrite, + ProtobufRequestType::ClientGetName => RequestType::ClientGetName, + ProtobufRequestType::ClientGetRedir => RequestType::ClientGetRedir, + ProtobufRequestType::ClientId => RequestType::ClientId, + ProtobufRequestType::ClientInfo => RequestType::ClientInfo, + ProtobufRequestType::ClientKill => RequestType::ClientKill, + ProtobufRequestType::ClientList => RequestType::ClientList, + ProtobufRequestType::ClientNoEvict => RequestType::ClientNoEvict, + ProtobufRequestType::ClientNoTouch => RequestType::ClientNoTouch, + ProtobufRequestType::ClientPause => RequestType::ClientPause, + ProtobufRequestType::ClientReply => RequestType::ClientReply, + ProtobufRequestType::ClientSetInfo => RequestType::ClientSetInfo, + ProtobufRequestType::ClientSetName => RequestType::ClientSetName, + ProtobufRequestType::ClientUnblock => RequestType::ClientUnblock, + ProtobufRequestType::ClientUnpause => RequestType::ClientUnpause, + ProtobufRequestType::Expire => RequestType::Expire, + ProtobufRequestType::HashSet => RequestType::HashSet, + ProtobufRequestType::HashGet => RequestType::HashGet, + ProtobufRequestType::HashDel => RequestType::HashDel, + ProtobufRequestType::HashExists => RequestType::HashExists, + ProtobufRequestType::MSet => RequestType::MSet, + ProtobufRequestType::MGet => RequestType::MGet, + ProtobufRequestType::Incr => RequestType::Incr, + ProtobufRequestType::IncrBy => RequestType::IncrBy, + ProtobufRequestType::IncrByFloat => RequestType::IncrByFloat, + ProtobufRequestType::Decr => RequestType::Decr, + ProtobufRequestType::DecrBy => RequestType::DecrBy, + ProtobufRequestType::HashGetAll => RequestType::HashGetAll, + ProtobufRequestType::HashMSet => RequestType::HashMSet, + ProtobufRequestType::HashMGet => RequestType::HashMGet, + ProtobufRequestType::HashIncrBy => RequestType::HashIncrBy, + ProtobufRequestType::HashIncrByFloat => RequestType::HashIncrByFloat, + ProtobufRequestType::LPush => RequestType::LPush, + ProtobufRequestType::LPop => RequestType::LPop, + ProtobufRequestType::RPush => RequestType::RPush, + ProtobufRequestType::RPop => RequestType::RPop, + ProtobufRequestType::LLen => RequestType::LLen, + ProtobufRequestType::LRem => RequestType::LRem, + ProtobufRequestType::LRange => RequestType::LRange, + ProtobufRequestType::LTrim => RequestType::LTrim, + ProtobufRequestType::SAdd => RequestType::SAdd, + ProtobufRequestType::SRem => RequestType::SRem, + ProtobufRequestType::SMembers => RequestType::SMembers, + ProtobufRequestType::SCard => RequestType::SCard, + ProtobufRequestType::PExpireAt => RequestType::PExpireAt, + ProtobufRequestType::PExpire => RequestType::PExpire, + ProtobufRequestType::ExpireAt => RequestType::ExpireAt, + ProtobufRequestType::Exists => RequestType::Exists, + ProtobufRequestType::Unlink => RequestType::Unlink, + ProtobufRequestType::TTL => RequestType::TTL, + ProtobufRequestType::Zadd => RequestType::Zadd, + ProtobufRequestType::Zrem => RequestType::Zrem, + ProtobufRequestType::Zrange => RequestType::Zrange, + ProtobufRequestType::Zcard => RequestType::Zcard, + ProtobufRequestType::Zcount => RequestType::Zcount, + ProtobufRequestType::ZIncrBy => RequestType::ZIncrBy, + ProtobufRequestType::ZScore => RequestType::ZScore, + ProtobufRequestType::Type => RequestType::Type, + ProtobufRequestType::HLen => RequestType::HLen, + ProtobufRequestType::Echo => RequestType::Echo, + ProtobufRequestType::ZPopMin => RequestType::ZPopMin, + ProtobufRequestType::Strlen => RequestType::Strlen, + ProtobufRequestType::Lindex => RequestType::Lindex, + ProtobufRequestType::ZPopMax => RequestType::ZPopMax, + ProtobufRequestType::XAck => RequestType::XAck, + ProtobufRequestType::XAdd => RequestType::XAdd, + ProtobufRequestType::XReadGroup => RequestType::XReadGroup, + ProtobufRequestType::XRead => RequestType::XRead, + ProtobufRequestType::XGroupCreate => RequestType::XGroupCreate, + ProtobufRequestType::XGroupDestroy => RequestType::XGroupDestroy, + ProtobufRequestType::XTrim => RequestType::XTrim, + ProtobufRequestType::HSetNX => RequestType::HSetNX, + ProtobufRequestType::SIsMember => RequestType::SIsMember, + ProtobufRequestType::Hvals => RequestType::Hvals, + ProtobufRequestType::PTTL => RequestType::PTTL, + ProtobufRequestType::ZRemRangeByRank => RequestType::ZRemRangeByRank, + ProtobufRequestType::Persist => RequestType::Persist, + ProtobufRequestType::ZRemRangeByScore => RequestType::ZRemRangeByScore, + ProtobufRequestType::Time => RequestType::Time, + ProtobufRequestType::Zrank => RequestType::Zrank, + ProtobufRequestType::Rename => RequestType::Rename, + ProtobufRequestType::DBSize => RequestType::DBSize, + ProtobufRequestType::Brpop => RequestType::Brpop, + ProtobufRequestType::Hkeys => RequestType::Hkeys, + ProtobufRequestType::PfAdd => RequestType::PfAdd, + ProtobufRequestType::PfCount => RequestType::PfCount, + ProtobufRequestType::PfMerge => RequestType::PfMerge, + ProtobufRequestType::RPushX => RequestType::RPushX, + ProtobufRequestType::LPushX => RequestType::LPushX, + ProtobufRequestType::Blpop => RequestType::Blpop, + } + } +} + +impl RequestType { + /// Returns a `Cmd` set with the command name matching the request. + pub fn get_command(&self) -> Option { + match self { + RequestType::InvalidRequest => None, + RequestType::CustomCommand => Some(Cmd::new()), + RequestType::GetString => Some(cmd("GET")), + RequestType::SetString => Some(cmd("SET")), + RequestType::Ping => Some(cmd("PING")), + RequestType::Info => Some(cmd("INFO")), + RequestType::Del => Some(cmd("DEL")), + RequestType::Select => Some(cmd("SELECT")), + RequestType::ConfigGet => Some(get_two_word_command("CONFIG", "GET")), + RequestType::ConfigSet => Some(get_two_word_command("CONFIG", "SET")), + RequestType::ConfigResetStat => Some(get_two_word_command("CONFIG", "RESETSTAT")), + RequestType::ConfigRewrite => Some(get_two_word_command("CONFIG", "REWRITE")), + RequestType::ClientGetName => Some(get_two_word_command("CLIENT", "GETNAME")), + RequestType::ClientGetRedir => Some(get_two_word_command("CLIENT", "GETREDIR")), + RequestType::ClientId => Some(get_two_word_command("CLIENT", "ID")), + RequestType::ClientInfo => Some(get_two_word_command("CLIENT", "INFO")), + RequestType::ClientKill => Some(get_two_word_command("CLIENT", "KILL")), + RequestType::ClientList => Some(get_two_word_command("CLIENT", "LIST")), + RequestType::ClientNoEvict => Some(get_two_word_command("CLIENT", "NO-EVICT")), + RequestType::ClientNoTouch => Some(get_two_word_command("CLIENT", "NO-TOUCH")), + RequestType::ClientPause => Some(get_two_word_command("CLIENT", "PAUSE")), + RequestType::ClientReply => Some(get_two_word_command("CLIENT", "REPLY")), + RequestType::ClientSetInfo => Some(get_two_word_command("CLIENT", "SETINFO")), + RequestType::ClientSetName => Some(get_two_word_command("CLIENT", "SETNAME")), + RequestType::ClientUnblock => Some(get_two_word_command("CLIENT", "UNBLOCK")), + RequestType::ClientUnpause => Some(get_two_word_command("CLIENT", "UNPAUSE")), + RequestType::Expire => Some(cmd("EXPIRE")), + RequestType::HashSet => Some(cmd("HSET")), + RequestType::HashGet => Some(cmd("HGET")), + RequestType::HashDel => Some(cmd("HDEL")), + RequestType::HashExists => Some(cmd("HEXISTS")), + RequestType::MSet => Some(cmd("MSET")), + RequestType::MGet => Some(cmd("MGET")), + RequestType::Incr => Some(cmd("INCR")), + RequestType::IncrBy => Some(cmd("INCRBY")), + RequestType::IncrByFloat => Some(cmd("INCRBYFLOAT")), + RequestType::Decr => Some(cmd("DECR")), + RequestType::DecrBy => Some(cmd("DECRBY")), + RequestType::HashGetAll => Some(cmd("HGETALL")), + RequestType::HashMSet => Some(cmd("HMSET")), + RequestType::HashMGet => Some(cmd("HMGET")), + RequestType::HashIncrBy => Some(cmd("HINCRBY")), + RequestType::HashIncrByFloat => Some(cmd("HINCRBYFLOAT")), + RequestType::LPush => Some(cmd("LPUSH")), + RequestType::LPop => Some(cmd("LPOP")), + RequestType::RPush => Some(cmd("RPUSH")), + RequestType::RPop => Some(cmd("RPOP")), + RequestType::LLen => Some(cmd("LLEN")), + RequestType::LRem => Some(cmd("LREM")), + RequestType::LRange => Some(cmd("LRANGE")), + RequestType::LTrim => Some(cmd("LTRIM")), + RequestType::SAdd => Some(cmd("SADD")), + RequestType::SRem => Some(cmd("SREM")), + RequestType::SMembers => Some(cmd("SMEMBERS")), + RequestType::SCard => Some(cmd("SCARD")), + RequestType::PExpireAt => Some(cmd("PEXPIREAT")), + RequestType::PExpire => Some(cmd("PEXPIRE")), + RequestType::ExpireAt => Some(cmd("EXPIREAT")), + RequestType::Exists => Some(cmd("EXISTS")), + RequestType::Unlink => Some(cmd("UNLINK")), + RequestType::TTL => Some(cmd("TTL")), + RequestType::Zadd => Some(cmd("ZADD")), + RequestType::Zrem => Some(cmd("ZREM")), + RequestType::Zrange => Some(cmd("ZRANGE")), + RequestType::Zcard => Some(cmd("ZCARD")), + RequestType::Zcount => Some(cmd("ZCOUNT")), + RequestType::ZIncrBy => Some(cmd("ZINCRBY")), + RequestType::ZScore => Some(cmd("ZSCORE")), + RequestType::Type => Some(cmd("TYPE")), + RequestType::HLen => Some(cmd("HLEN")), + RequestType::Echo => Some(cmd("ECHO")), + RequestType::ZPopMin => Some(cmd("ZPOPMIN")), + RequestType::Strlen => Some(cmd("STRLEN")), + RequestType::Lindex => Some(cmd("LINDEX")), + RequestType::ZPopMax => Some(cmd("ZPOPMAX")), + RequestType::XAck => Some(cmd("XACK")), + RequestType::XAdd => Some(cmd("XADD")), + RequestType::XReadGroup => Some(cmd("XREADGROUP")), + RequestType::XRead => Some(cmd("XREAD")), + RequestType::XGroupCreate => Some(get_two_word_command("XGROUP", "CREATE")), + RequestType::XGroupDestroy => Some(get_two_word_command("XGROUP", "DESTROY")), + RequestType::XTrim => Some(cmd("XTRIM")), + RequestType::HSetNX => Some(cmd("HSETNX")), + RequestType::SIsMember => Some(cmd("SISMEMBER")), + RequestType::Hvals => Some(cmd("HVALS")), + RequestType::PTTL => Some(cmd("PTTL")), + RequestType::ZRemRangeByRank => Some(cmd("ZREMRANGEBYRANK")), + RequestType::Persist => Some(cmd("PERSIST")), + RequestType::ZRemRangeByScore => Some(cmd("ZREMRANGEBYSCORE")), + RequestType::Time => Some(cmd("TIME")), + RequestType::Zrank => Some(cmd("ZRANK")), + RequestType::Rename => Some(cmd("RENAME")), + RequestType::DBSize => Some(cmd("DBSIZE")), + RequestType::Brpop => Some(cmd("BRPOP")), + RequestType::Hkeys => Some(cmd("HKEYS")), + RequestType::PfAdd => Some(cmd("PFADD")), + RequestType::PfCount => Some(cmd("PFCOUNT")), + RequestType::PfMerge => Some(cmd("PFMERGE")), + RequestType::RPushX => Some(cmd("RPUSHX")), + RequestType::LPushX => Some(cmd("LPUSHX")), + RequestType::Blpop => Some(cmd("BLPOP")), + } + } +} diff --git a/glide-core/src/socket_listener.rs b/glide-core/src/socket_listener.rs index 2ab4792f79..fc72b49a46 100644 --- a/glide-core/src/socket_listener.rs +++ b/glide-core/src/socket_listener.rs @@ -6,8 +6,7 @@ use crate::client::Client; use crate::connection_request::ConnectionRequest; use crate::errors::{error_message, error_type, RequestErrorType}; use crate::redis_request::{ - command, redis_request, Command, RedisRequest, RequestType, Routes, ScriptInvocation, - SlotTypes, Transaction, + command, redis_request, Command, RedisRequest, Routes, ScriptInvocation, SlotTypes, Transaction, }; use crate::response; use crate::response::Response; @@ -21,7 +20,7 @@ use redis::cluster_routing::{ }; use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::RedisError; -use redis::{cmd, Cmd, Value}; +use redis::{Cmd, Value}; use std::cell::Cell; use std::rc::Rc; use std::{env, str}; @@ -257,119 +256,9 @@ async fn write_to_writer(response: Response, writer: &Rc) -> Result<(), } } -fn get_two_word_command(first: &str, second: &str) -> Cmd { - let mut cmd = cmd(first); - cmd.arg(second); - cmd -} - fn get_command(request: &Command) -> Option { - let request_enum = request - .request_type - .enum_value_or(RequestType::InvalidRequest); - match request_enum { - RequestType::InvalidRequest => None, - RequestType::CustomCommand => Some(Cmd::new()), - RequestType::GetString => Some(cmd("GET")), - RequestType::SetString => Some(cmd("SET")), - RequestType::Ping => Some(cmd("PING")), - RequestType::Info => Some(cmd("INFO")), - RequestType::Del => Some(cmd("DEL")), - RequestType::Select => Some(cmd("SELECT")), - RequestType::ConfigGet => Some(get_two_word_command("CONFIG", "GET")), - RequestType::ConfigSet => Some(get_two_word_command("CONFIG", "SET")), - RequestType::ConfigResetStat => Some(get_two_word_command("CONFIG", "RESETSTAT")), - RequestType::ConfigRewrite => Some(get_two_word_command("CONFIG", "REWRITE")), - RequestType::ClientGetName => Some(get_two_word_command("CLIENT", "GETNAME")), - RequestType::ClientGetRedir => Some(get_two_word_command("CLIENT", "GETREDIR")), - RequestType::ClientId => Some(get_two_word_command("CLIENT", "ID")), - RequestType::ClientInfo => Some(get_two_word_command("CLIENT", "INFO")), - RequestType::ClientKill => Some(get_two_word_command("CLIENT", "KILL")), - RequestType::ClientList => Some(get_two_word_command("CLIENT", "LIST")), - RequestType::ClientNoEvict => Some(get_two_word_command("CLIENT", "NO-EVICT")), - RequestType::ClientNoTouch => Some(get_two_word_command("CLIENT", "NO-TOUCH")), - RequestType::ClientPause => Some(get_two_word_command("CLIENT", "PAUSE")), - RequestType::ClientReply => Some(get_two_word_command("CLIENT", "REPLY")), - RequestType::ClientSetInfo => Some(get_two_word_command("CLIENT", "SETINFO")), - RequestType::ClientSetName => Some(get_two_word_command("CLIENT", "SETNAME")), - RequestType::ClientUnblock => Some(get_two_word_command("CLIENT", "UNBLOCK")), - RequestType::ClientUnpause => Some(get_two_word_command("CLIENT", "UNPAUSE")), - RequestType::Expire => Some(cmd("EXPIRE")), - RequestType::HashSet => Some(cmd("HSET")), - RequestType::HashGet => Some(cmd("HGET")), - RequestType::HashDel => Some(cmd("HDEL")), - RequestType::HashExists => Some(cmd("HEXISTS")), - RequestType::MSet => Some(cmd("MSET")), - RequestType::MGet => Some(cmd("MGET")), - RequestType::Incr => Some(cmd("INCR")), - RequestType::IncrBy => Some(cmd("INCRBY")), - RequestType::IncrByFloat => Some(cmd("INCRBYFLOAT")), - RequestType::Decr => Some(cmd("DECR")), - RequestType::DecrBy => Some(cmd("DECRBY")), - RequestType::HashGetAll => Some(cmd("HGETALL")), - RequestType::HashMSet => Some(cmd("HMSET")), - RequestType::HashMGet => Some(cmd("HMGET")), - RequestType::HashIncrBy => Some(cmd("HINCRBY")), - RequestType::HashIncrByFloat => Some(cmd("HINCRBYFLOAT")), - RequestType::LPush => Some(cmd("LPUSH")), - RequestType::LPop => Some(cmd("LPOP")), - RequestType::RPush => Some(cmd("RPUSH")), - RequestType::RPop => Some(cmd("RPOP")), - RequestType::LLen => Some(cmd("LLEN")), - RequestType::LRem => Some(cmd("LREM")), - RequestType::LRange => Some(cmd("LRANGE")), - RequestType::LTrim => Some(cmd("LTRIM")), - RequestType::SAdd => Some(cmd("SADD")), - RequestType::SRem => Some(cmd("SREM")), - RequestType::SMembers => Some(cmd("SMEMBERS")), - RequestType::SCard => Some(cmd("SCARD")), - RequestType::PExpireAt => Some(cmd("PEXPIREAT")), - RequestType::PExpire => Some(cmd("PEXPIRE")), - RequestType::ExpireAt => Some(cmd("EXPIREAT")), - RequestType::Exists => Some(cmd("EXISTS")), - RequestType::Unlink => Some(cmd("UNLINK")), - RequestType::TTL => Some(cmd("TTL")), - RequestType::Zadd => Some(cmd("ZADD")), - RequestType::Zrem => Some(cmd("ZREM")), - RequestType::Zrange => Some(cmd("ZRANGE")), - RequestType::Zcard => Some(cmd("ZCARD")), - RequestType::Zcount => Some(cmd("ZCOUNT")), - RequestType::ZIncrBy => Some(cmd("ZINCRBY")), - RequestType::ZScore => Some(cmd("ZSCORE")), - RequestType::Type => Some(cmd("TYPE")), - RequestType::HLen => Some(cmd("HLEN")), - RequestType::Echo => Some(cmd("ECHO")), - RequestType::ZPopMin => Some(cmd("ZPOPMIN")), - RequestType::Strlen => Some(cmd("STRLEN")), - RequestType::Lindex => Some(cmd("LINDEX")), - RequestType::ZPopMax => Some(cmd("ZPOPMAX")), - RequestType::XAck => Some(cmd("XACK")), - RequestType::XAdd => Some(cmd("XADD")), - RequestType::XReadGroup => Some(cmd("XREADGROUP")), - RequestType::XRead => Some(cmd("XREAD")), - RequestType::XGroupCreate => Some(get_two_word_command("XGROUP", "CREATE")), - RequestType::XGroupDestroy => Some(get_two_word_command("XGROUP", "DESTROY")), - RequestType::XTrim => Some(cmd("XTRIM")), - RequestType::HSetNX => Some(cmd("HSETNX")), - RequestType::SIsMember => Some(cmd("SISMEMBER")), - RequestType::Hvals => Some(cmd("HVALS")), - RequestType::PTTL => Some(cmd("PTTL")), - RequestType::ZRemRangeByRank => Some(cmd("ZREMRANGEBYRANK")), - RequestType::Persist => Some(cmd("PERSIST")), - RequestType::ZRemRangeByScore => Some(cmd("ZREMRANGEBYSCORE")), - RequestType::Time => Some(cmd("TIME")), - RequestType::Zrank => Some(cmd("ZRANK")), - RequestType::Rename => Some(cmd("RENAME")), - RequestType::DBSize => Some(cmd("DBSIZE")), - RequestType::Brpop => Some(cmd("BRPOP")), - RequestType::Hkeys => Some(cmd("HKEYS")), - RequestType::PfAdd => Some(cmd("PFADD")), - RequestType::PfCount => Some(cmd("PFCOUNT")), - RequestType::PfMerge => Some(cmd("PFMERGE")), - RequestType::RPushX => Some(cmd("RPUSHX")), - RequestType::LPushX => Some(cmd("LPUSHX")), - RequestType::Blpop => Some(cmd("BLPOP")), - } + let request_type: crate::request_type::RequestType = request.request_type.into(); + request_type.get_command() } fn get_redis_command(command: &Command) -> Result {