diff --git a/src/Confluent.SchemaRegistry/AsyncSerde.cs b/src/Confluent.SchemaRegistry/AsyncSerde.cs index 70934d291..698e6873b 100644 --- a/src/Confluent.SchemaRegistry/AsyncSerde.cs +++ b/src/Confluent.SchemaRegistry/AsyncSerde.cs @@ -359,7 +359,7 @@ protected async Task ExecuteRules( for (int i = 0; i < rules.Count; i++) { Rule rule = rules[i]; - if (rule.Disabled) + if (IsDisabled(rule)) { continue; } @@ -406,21 +406,21 @@ protected async Task ExecuteRules( default: throw new ArgumentException("Unsupported rule kind " + rule.Kind); } - await RunAction(ctx, ruleMode, rule, message != null ? rule.OnSuccess : rule.OnFailure, + await RunAction(ctx, ruleMode, rule, message != null ? GetOnSuccess(rule) : GetOnFailure(rule), message, null, message != null ? null : ErrorAction.ActionType, ruleRegistry) .ConfigureAwait(continueOnCapturedContext: false); } catch (RuleException ex) { - await RunAction(ctx, ruleMode, rule, rule.OnFailure, message, + await RunAction(ctx, ruleMode, rule, GetOnFailure(rule), message, ex, ErrorAction.ActionType, ruleRegistry) .ConfigureAwait(continueOnCapturedContext: false); } } else { - await RunAction(ctx, ruleMode, rule, rule.OnFailure, message, + await RunAction(ctx, ruleMode, rule, GetOnFailure(rule), message, new RuleException("Could not find rule executor of type " + rule.Type), ErrorAction.ActionType, ruleRegistry) .ConfigureAwait(continueOnCapturedContext: false); @@ -429,6 +429,45 @@ await RunAction(ctx, ruleMode, rule, rule.OnFailure, message, return message; } + private string GetOnSuccess(Rule rule) + { + if (ruleRegistry.TryGetOverride(rule.Type, out RuleOverride ruleOverride)) + { + if (ruleOverride.OnSuccess != null) + { + return ruleOverride.OnSuccess; + } + } + + return rule.OnSuccess; + } + + private string GetOnFailure(Rule rule) + { + if (ruleRegistry.TryGetOverride(rule.Type, out RuleOverride ruleOverride)) + { + if (ruleOverride.OnFailure != null) + { + return ruleOverride.OnFailure; + } + } + + return rule.OnFailure; + } + + private bool IsDisabled(Rule rule) + { + if (ruleRegistry.TryGetOverride(rule.Type, out RuleOverride ruleOverride)) + { + if (ruleOverride.Disabled.HasValue) + { + return ruleOverride.Disabled.Value; + } + } + + return rule.Disabled; + } + private static IRuleExecutor GetRuleExecutor(RuleRegistry ruleRegistry, string type) { if (ruleRegistry.TryGetExecutor(type, out IRuleExecutor result)) diff --git a/src/Confluent.SchemaRegistry/RuleOverride.cs b/src/Confluent.SchemaRegistry/RuleOverride.cs new file mode 100644 index 000000000..cc00e4316 --- /dev/null +++ b/src/Confluent.SchemaRegistry/RuleOverride.cs @@ -0,0 +1,40 @@ +// Copyright 2024 Confluent Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Refer to LICENSE for more information. + +namespace Confluent.SchemaRegistry +{ + /// + /// A rule override. + /// + public class RuleOverride + { + public string Type { get; set; } + + public string OnSuccess { get; set; } + + public string OnFailure { get; set; } + + public bool? Disabled { get; set; } + + public RuleOverride(string type, string onSuccess, string onFailure, bool? disabled) + { + Type = type; + OnSuccess = onSuccess; + OnFailure = onFailure; + Disabled = disabled; + } + } +} \ No newline at end of file diff --git a/src/Confluent.SchemaRegistry/RuleRegistry.cs b/src/Confluent.SchemaRegistry/RuleRegistry.cs index af71421a0..16ae348c7 100644 --- a/src/Confluent.SchemaRegistry/RuleRegistry.cs +++ b/src/Confluent.SchemaRegistry/RuleRegistry.cs @@ -26,19 +26,16 @@ public class RuleRegistry { private readonly SemaphoreSlim ruleExecutorsMutex = new SemaphoreSlim(1); private readonly SemaphoreSlim ruleActionsMutex = new SemaphoreSlim(1); + private readonly SemaphoreSlim ruleOverridesMutex = new SemaphoreSlim(1); private IDictionary ruleExecutors = new Dictionary(); private IDictionary ruleActions = new Dictionary(); + private IDictionary ruleOverrides = new Dictionary(); private static readonly RuleRegistry GLOBAL_INSTANCE = new RuleRegistry(); public static RuleRegistry GlobalInstance => GLOBAL_INSTANCE; - public static List GetRuleActions() - { - return GlobalInstance.GetActions(); - } - public void RegisterExecutor(IRuleExecutor executor) { ruleExecutorsMutex.Wait(); @@ -123,6 +120,48 @@ public List GetActions() } } + public void RegisterOverride(RuleOverride ruleOverride) + { + ruleOverridesMutex.Wait(); + try + { + if (!ruleOverrides.ContainsKey(ruleOverride.Type)) + { + ruleOverrides.Add(ruleOverride.Type, ruleOverride); + } + } + finally + { + ruleOverridesMutex.Release(); + } + } + + public bool TryGetOverride(string name, out RuleOverride ruleOverride) + { + ruleOverridesMutex.Wait(); + try + { + return ruleOverrides.TryGetValue(name, out ruleOverride); + } + finally + { + ruleOverridesMutex.Release(); + } + } + + public List GetOverrides() + { + ruleOverridesMutex.Wait(); + try + { + return new List(ruleOverrides.Values); + } + finally + { + ruleOverridesMutex.Release(); + } + } + public static void RegisterRuleExecutor(IRuleExecutor executor) { GlobalInstance.RegisterExecutor(executor); @@ -132,5 +171,10 @@ public static void RegisterRuleAction(IRuleAction action) { GlobalInstance.RegisterAction(action); } + + public static void RegisterRuleOverride(RuleOverride ruleOverride) + { + GlobalInstance.RegisterOverride(ruleOverride); + } } } diff --git a/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs b/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs index 8843f882d..db021d651 100644 --- a/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs +++ b/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs @@ -275,12 +275,12 @@ public void ISpecificRecordCELFieldTransform() schema.RuleSet = new RuleSet(new List(), new List { - new Rule("testCEL", RuleKind.Transform, RuleMode.Write, "CEL_FIELD", null, null, + new Rule("testCEL", RuleKind.Transform, RuleMode.Write, "CEL_FIELD", null, null, "typeName == 'STRING' ; value + '-suffix'", null, null, false) } ); store[schemaStr] = 1; - subjectStore["topic-value"] = new List { schema }; + subjectStore["topic-value"] = new List { schema }; var config = new AvroSerializerConfig { AutoRegisterSchemas = false, @@ -305,6 +305,46 @@ public void ISpecificRecordCELFieldTransform() Assert.Equal(user.favorite_number, result.favorite_number); } + [Fact] + public void ISpecificRecordCELFieldTransformDisable() + { + var schemaStr = User._SCHEMA.ToString(); + var schema = new RegisteredSchema("topic-value", 1, 1, schemaStr, SchemaType.Avro, null); + schema.RuleSet = new RuleSet(new List(), + new List + { + new Rule("testCEL", RuleKind.Transform, RuleMode.Write, "CEL_FIELD", null, null, + "typeName == 'STRING' ; value + '-suffix'", null, null, false) + } + ); + store[schemaStr] = 1; + subjectStore["topic-value"] = new List { schema }; + var config = new AvroSerializerConfig + { + AutoRegisterSchemas = false, + UseLatestVersion = true + }; + RuleRegistry registry = new RuleRegistry(); + registry.RegisterOverride(new RuleOverride("CEL_FIELD", null, null, true)); + var serializer = new AvroSerializer(schemaRegistryClient, config, registry); + var deserializer = new AvroDeserializer(schemaRegistryClient, null); + + var user = new User + { + favorite_color = "blue", + favorite_number = 100, + name = "awesome" + }; + + Headers headers = new Headers(); + var bytes = serializer.SerializeAsync(user, new SerializationContext(MessageComponentType.Value, testTopic, headers)).Result; + var result = deserializer.DeserializeAsync(bytes, false, new SerializationContext(MessageComponentType.Value, testTopic, headers)).Result; + + Assert.Equal("awesome", result.name); + Assert.Equal("blue", result.favorite_color); + Assert.Equal(user.favorite_number, result.favorite_number); + } + [Fact] public void ISpecificRecordCELFieldCondition() {