diff --git a/src/mria_config.erl b/src/mria_config.erl index 61624e2..445f1b5 100644 --- a/src/mria_config.erl +++ b/src/mria_config.erl @@ -1,5 +1,5 @@ %%-------------------------------------------------------------------- -%% Copyright (c) 2021-2023 EMQ Technologies Co., Ltd. All Rights Reserved. +%% Copyright (c) 2021-2024 EMQ Technologies Co., Ltd. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -70,7 +70,9 @@ -type callback() :: start | stop | {start | stop, mria_rlog:shard()} - | core_node_discovery. + | core_node_discovery + | lb_custom_info + | lb_custom_info_check. -type callback_function() :: fun(() -> term()) | fun((term()) -> term()). diff --git a/src/mria_lb.erl b/src/mria_lb.erl index efd4582..213e886 100644 --- a/src/mria_lb.erl +++ b/src/mria_lb.erl @@ -1,5 +1,5 @@ %%-------------------------------------------------------------------- -%% Copyright (c) 2021-2023 EMQ Technologies Co., Ltd. All Rights Reserved. +%% Copyright (c) 2021-2024 EMQ Technologies Co., Ltd. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -60,6 +60,7 @@ , protocol_version := non_neg_integer() , db_nodes => [node()] , shard_badness => [{mria_rlog:shard(), float()}] + , custom_info => _ }. -define(update, update). @@ -192,29 +193,55 @@ find_clusters(NodeInfo) -> -spec shard_badness(#{node() => node_info()}) -> #{mria_rlog:shard() => {node(), Badness}} when Badness :: float(). shard_badness(NodeInfo) -> - MyProtoVersion = mria_rlog:get_protocol_version(), maps:fold( - fun(Node, #{shard_badness := Shards, protocol_version := ProtoVsn}, Acc) - when ProtoVsn =:= MyProtoVersion -> - lists:foldl( - fun({Shard, Badness}, Acc1) -> - maps:update_with(Shard, - fun({_OldNode, OldBadness}) when OldBadness > Badness -> - {Node, Badness}; - (Old) -> - Old - end, - {Node, Badness}, - Acc1) - end, - Acc, - Shards); - (_Node, _NodeInfo, Acc) -> - Acc + fun(Node, LbInfo = #{shard_badness := Shards}, Acc) -> + case verify_node_compatibility(LbInfo) of + true -> + lists:foldl( + fun({Shard, Badness}, Acc1) -> + maps:update_with(Shard, + fun({_OldNode, OldBadness}) when OldBadness > Badness -> + {Node, Badness}; + (Old) -> + Old + end, + {Node, Badness}, + Acc1) + end, + Acc, + Shards); + false -> + Acc + end end, #{}, NodeInfo). +verify_node_compatibility(LbInfo = #{protocol_version := ProtoVsn}) -> + case mria_config:callback(lb_custom_info_check) of + {ok, CustomCheckFun} -> + ok; + undefined -> + CustomCheckFun = fun(_) -> true end + end, + CustomInfo = maps:get(custom_info, LbInfo, undefined), + MyProtoVersion = mria_rlog:get_protocol_version(), + %% Actual check: + IsCustomCompat = try + Result = CustomCheckFun(CustomInfo), + is_boolean(Result) orelse + error({non_boolean_result, Result}), + Result + catch + %% TODO: this can get spammy: + EC:Err:Stack -> + ?tp(error, mria_failed_to_check_upstream_compatibility, + #{lb_info => LbInfo, EC => Err, stacktrace => Stack}), + false + end, + ProtoVsn =:= MyProtoVersion andalso + IsCustomCompat. + start_timer(LastUpdateTime) -> %% Leave at least 100 ms between updates to leave some time to %% process other events: @@ -287,11 +314,16 @@ lb_callback() -> {ok, Vsn} -> Vsn; undefined -> undefined end, + CustomInfo = case mria_config:callback(lb_custom_info) of + {ok, CB} -> CB(); + undefined -> undefined + end, BasicInfo = #{ running => IsRunning , version => Version , whoami => Whoami , protocol_version => mria_rlog:get_protocol_version() + , custom_info => CustomInfo }, MoreInfo = case Whoami of diff --git a/test/mria_lb_SUITE.erl b/test/mria_lb_SUITE.erl index 5ee0241..f76d0f4 100644 --- a/test/mria_lb_SUITE.erl +++ b/test/mria_lb_SUITE.erl @@ -1,5 +1,5 @@ %%-------------------------------------------------------------------- -%% Copyright (c) 2019-2021, 2023 EMQ Technologies Co., Ltd. All Rights Reserved. +%% Copyright (c) 2019-2021, 2023-2024 EMQ Technologies Co., Ltd. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ -include_lib("eunit/include/eunit.hrl"). -include_lib("snabbkaffe/include/snabbkaffe.hrl"). +-include("mria_rlog.hrl"). all() -> mria_ct:all(?MODULE). @@ -233,6 +234,29 @@ t_core_node_leave(_Config) -> mria_ct:teardown_cluster(Cluster) end, []). +t_custom_compat_check(_Config) -> + Env = [ {mria, {callback, lb_custom_info_check}, fun(Val) -> Val =:= chosen_one end} + | mria_mnesia_test_util:common_env()], + Cluster = mria_ct:cluster([ core + , core + , {core, [{mria, {callback, lb_custom_info}, + fun() -> chosen_one end}]} + , replicant + ], Env), + ?check_trace( + #{timetrap => 15000}, + try + [_C1, _C2, C3, R1] = mria_ct:start_cluster(mria, Cluster), + ?assertEqual({ok, C3}, + erpc:call( R1 + , mria_status, get_core_node, [?mria_meta_shard, infinity] + , infinity + )) + after + mria_ct:teardown_cluster(Cluster) + end, + []). + clear_core_node_list(Replicant) -> MaybeOldCallback = erpc:call(Replicant, mria_config, callback, [core_node_discovery]), try