From c9c1c28241b9b4801c7decb121207fd8ea0e7734 Mon Sep 17 00:00:00 2001
From: LingKa
Date: Tue, 16 Jan 2024 09:20:35 +0800
Subject: [PATCH] feat: error-handling

Signed-off-by: LingKa
---
 client/error.py    |  18 ++++
 client/protocol.py | 202 +++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 204 insertions(+), 16 deletions(-)

diff --git a/client/error.py b/client/error.py
index e66794f..972775f 100644
--- a/client/error.py
+++ b/client/error.py
@@ -12,3 +12,21 @@ class ExecuteError(Exception):
 
     def __init__(self, err: _ExecuteError) -> None:
         self.inner = err
+
+
+class ShuttingDownError(Exception):
+    """Server is shutting down"""
+
+    pass
+
+
+class WrongClusterVersionError(Exception):
+    """Wrong cluster version"""
+
+    pass
+
+
+class InternalError(Exception):
+    """Internal error in the client"""
+
+    pass
diff --git a/client/protocol.py b/client/protocol.py
index 732739b..566bd34 100644
--- a/client/protocol.py
+++ b/client/protocol.py
@@ -16,9 +16,15 @@
     ProposeId,
     FetchClusterResponse,
     FetchClusterRequest,
+    CurpError,
 )
 from api.curp.curp_command_pb2_grpc import ProtocolStub
-from client.error import ExecuteError
+from client.error import (
+    ExecuteError,
+    WrongClusterVersionError,
+    ShuttingDownError,
+    InternalError,
+)
 from api.xline.xline_command_pb2 import Command, CommandResponse, SyncResponse
 from api.xline.xline_error_pb2 import ExecuteError as _ExecuteError
@@ -36,7 +42,6 @@ class ProtocolClient:
     state: State
     connects: dict[int, grpc.Channel]
     cluster_version: int
-    # TODO config
 
     def __init__(
         self,
@@ -82,25 +87,75 @@ async def fast_fetch_cluster(addrs: list[str]) -> FetchClusterResponse:
         msg = "fetch cluster error"
         raise Exception(msg)
 
+    async def fetch_cluster(self, linearizable: bool) -> FetchClusterResponse:
+        """
+        Send fetch cluster requests to all servers.
+        Note: the fetched cluster may still be outdated if `linearizable` is false.
+        """
+        connects = self.all_connects()
+        rpcs: list[grpc.Future[FetchClusterResponse]] = []
+
+        for channel in connects:
+            stub = ProtocolStub(channel)
+            rpcs.append(stub.FetchCluster(FetchClusterRequest(linearizable=linearizable)))
+
+        max_term = 0
+        resp: FetchClusterResponse | None = None
+        ok_cnt = 0
+        majority_cnt = len(connects) // 2 + 1
+
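+        # Collect the responses as they complete: accept a cluster view once a
+        # majority of servers agree on the highest term seen, and only keep a
+        # response that actually carries the member list as the result.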
propose """ fast_round = self.fast_round(propose_id, cmd) - slow_round = self.slow_round(propose_id) + slow_round = self.slow_round(propose_id, cmd) # Wait for the fast and slow round at the same time for futures in asyncio.as_completed([fast_round, slow_round]): @@ -111,10 +166,8 @@ async def fast_path(self, propose_id: ProposeId, cmd: Command) -> tuple[CommandR continue if isinstance(first, CommandResponse) and second: - # TODO: error handling return (first, None) if isinstance(second, CommandResponse) and isinstance(first, SyncResponse): - # TODO: error handling return (second, first) msg = "fast path error" @@ -124,7 +177,7 @@ async def slow_path(self, propose_id: ProposeId, cmd: Command) -> tuple[CommandR """ Slow path of propose """ - results = await asyncio.gather(self.fast_round(propose_id, cmd), self.slow_round(propose_id)) + results = await asyncio.gather(self.fast_round(propose_id, cmd), self.slow_round(propose_id, cmd)) for result in results: if isinstance(result[0], SyncResponse) and isinstance(result[1], CommandResponse): return (result[1], result[0]) @@ -153,9 +206,17 @@ async def fast_round(self, propose_id: ProposeId, cmd: Command) -> tuple[Command res: ProposeResponse = ProposeResponse() try: res = await future - except Exception as e: + except grpc.RpcError as e: logging.warning(e) - continue + curp_err = CurpError() + dtl = e.details() + curp_err.ParseFromString(dtl) + if curp_err.HasField("ShuttingDown"): + raise ShuttingDownError from e + elif curp_err.HasField("WrongClusterVersion"): + raise WrongClusterVersionError from e + else: + continue ok_cnt += 1 if not res.HasField("result"): @@ -178,7 +239,7 @@ async def fast_round(self, propose_id: ProposeId, cmd: Command) -> tuple[Command logging.info("fast round failed. propose id: %s", propose_id) return (cmd_res, False) - async def slow_round(self, propose_id: ProposeId) -> tuple[SyncResponse, CommandResponse]: + async def slow_round(self, propose_id: ProposeId, cmd: Command) -> tuple[SyncResponse, CommandResponse]: """ The slow round of Curp protocol """ @@ -190,9 +251,33 @@ async def slow_round(self, propose_id: ProposeId) -> tuple[SyncResponse, Command channel = self.connects[self.state.leader] stub = ProtocolStub(channel) - res: WaitSyncedResponse = await stub.WaitSynced( - WaitSyncedRequest(propose_id=propose_id, cluster_version=self.cluster_version) - ) + + res = WaitSyncedResponse() + try: + res: WaitSyncedResponse = await stub.WaitSynced( + WaitSyncedRequest(propose_id=propose_id, cluster_version=self.cluster_version) + ) + except grpc.RpcError as e: + logging.warning("wait synced rpc error: %s", e) + curp_err = CurpError() + details = e.details() + curp_err.ParseFromString(details) + if curp_err.HasField("ShuttingDown"): + raise ShuttingDownError from e + elif curp_err.HasField("WrongClusterVersion"): + raise WrongClusterVersionError from e + elif curp_err.HasField("RpcTransport"): + # it's quite likely that the leader has crashed, + # then we should wait for some time and fetch the leader again + self.resend_propose(propose_id, cmd, None) + elif curp_err.HasField("redirect"): + new_leader = curp_err.redirect.leader_id + term = curp_err.redirect.term + self.state.check_and_update(new_leader, term) + # resend the propose to the new leader + self.resend_propose(propose_id, cmd, None) + else: + raise InternalError from e if res.after_sync_result.ok: asr.ParseFromString(res.after_sync_result.ok) @@ -208,6 +293,44 @@ async def slow_round(self, propose_id: ProposeId) -> tuple[SyncResponse, Command return (asr, er) + 
+    async def resend_propose(self, propose_id: ProposeId, cmd: Command, new_leader: ServerId | None) -> bool | None:
+        """
+        Resend the propose only to the leader.
+        This is used when the leader changes and we need to ensure that the propose is received by the new leader.
+        """
+        leader_id: ServerId | None = None
+        if new_leader is not None:
+            leader_id = new_leader
+        else:
+            try:
+                leader_id = await self.fetch_leader()
+            except Exception as e:
+                logging.warning("failed to fetch leader, %s", e)
+                return None
+        logging.debug("resend propose to %s", leader_id)
+
+        stub = ProtocolStub(self.get_connect(leader_id))
+
+        try:
+            await stub.Propose(
+                ProposeRequest(
+                    propose_id=propose_id, command=cmd.SerializeToString(), cluster_version=self.cluster_version
+                )
+            )
+        except grpc.RpcError as e:
+            # if the propose fails again, the caller needs to fetch the leader and retry
+            logging.warning("failed to resend propose, %s", e)
+            curp_err = CurpError()
+            dtl = e.details()
+            curp_err.ParseFromString(dtl)
+            if curp_err.HasField("ShuttingDown"):
+                raise ShuttingDownError from e
+            elif curp_err.HasField("WrongClusterVersion"):
+                raise WrongClusterVersionError from e
+            elif curp_err.HasField("Duplicated"):
+                return True
+            else:
+                return None
+
     def gen_propose_id(self) -> ProposeId:
         """
         Generate a propose id
@@ -232,6 +355,18 @@ def get_client_id(self) -> int:
         """
         return random.randint(0, 2**64 - 1)
 
+    def all_connects(self) -> list[grpc.Channel]:
+        """
+        Get all connects
+        """
+        return list(self.connects.values())
+
+    def get_connect(self, _id: ServerId) -> grpc.Channel:
+        """
+        Get the connect for a server id
+        """
+        return self.connects[_id]
+
     @staticmethod
     def super_quorum(nodes: int) -> int:
         """
@@ -258,9 +393,44 @@ class State:
         term: Current term
     """
 
-    leader: int
+    leader: int | None
     term: int
 
-    def __init__(self, leader: int, term: int) -> None:
+    def __init__(self, leader: int | None, term: int) -> None:
         self.leader = leader
         self.term = term
+
+    def check_and_update(self, leader_id: int | None, term: int) -> None:
+        """
+        Check the term and leader id, and update the state if needed
+        """
+        if self.term < term:
+            # Reset the term only when the response carries a leader id. Otherwise,
+            # a server that lost contact with its leader bumps its term for an
+            # election that will not succeed (the other servers are fine), and a
+            # client that adopted that higher term would never learn the true leader.
+            if leader_id is not None:
+                new_leader_id = leader_id
+                self.update_to_term(term)
+                self.set_leader(new_leader_id)
+        elif self.term == term:
+            if leader_id is not None:
+                new_leader_id = leader_id
+                if self.leader is None:
+                    self.set_leader(new_leader_id)
+        else:
+            pass  # smaller term: the response is outdated, ignore it
+
+    def update_to_term(self, term: int) -> None:
+        """
+        Update to the newest term and reset the local leader cache
+        """
+        self.term = term
+        self.leader = None
+
+    def set_leader(self, _id: ServerId) -> None:
+        """
+        Set the leader
+        """
+        logging.debug("client updates its leader to %s", _id)
+        self.leader = _id
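
Reviewer note (not part of the patch): a sketch of how a caller might react to the
new error types. `propose_with_refresh` and its retry policy are hypothetical, not
an existing py-xline API; only `propose` and `fetch_cluster` come from the code
above:

    from client.error import InternalError, ShuttingDownError, WrongClusterVersionError

    async def propose_with_refresh(client, cmd, retries: int = 3):
        """Hypothetical helper: refresh a stale cluster view and retry the propose."""
        for _ in range(retries):
            try:
                return await client.propose(cmd, use_fast_path=True)
            except WrongClusterVersionError:
                # the local topology is stale: refetch it linearizably and retry
                await client.fetch_cluster(True)
            except ShuttingDownError:
                # the cluster is going away; surface this to the application
                raise
        raise InternalError("propose still failing after refreshing the cluster view")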