feat: error-handling
Signed-off-by: LingKa <[email protected]>
LingKa28 committed Jan 16, 2024
1 parent 7b78699 commit c9c1c28
Showing 2 changed files with 204 additions and 16 deletions.
18 changes: 18 additions & 0 deletions client/error.py
@@ -12,3 +12,21 @@ class ExecuteError(Exception):

def __init__(self, err: _ExecuteError) -> None:
self.inner = err


class ShuttingDownError(Exception):
"""Server is shutting down"""

pass


class WrongClusterVersionError(Exception):
"""Wrong cluster version"""

pass


class InternalError(Exception):
"""Internal Error in client"""

pass
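
For context, a minimal sketch (not part of this commit) of how a caller might react to the new exception classes; `client` is a ProtocolClient and `cmd` a Command from client/protocol.py below, and the retry policy is only an assumption:

import asyncio
import logging

from client.error import InternalError, ShuttingDownError, WrongClusterVersionError


async def propose_with_retry(client, cmd, retries: int = 3):
    for _ in range(retries):
        try:
            # fast path returns early, before the command is synced
            return await client.propose(cmd, use_fast_path=True)
        except ShuttingDownError:
            await asyncio.sleep(1)  # the server is going away: back off, then retry
        except WrongClusterVersionError:
            await client.fetch_cluster(False)  # refresh the cluster view, then retry
        except InternalError:
            logging.exception("client-side bug, not retryable")
            raise
    raise TimeoutError("propose did not succeed within the retry budget")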
202 changes: 186 additions & 16 deletions client/protocol.py
@@ -16,9 +16,15 @@
ProposeId,
FetchClusterResponse,
FetchClusterRequest,
CurpError,
)
from api.curp.curp_command_pb2_grpc import ProtocolStub
-from client.error import ExecuteError
from client.error import (
ExecuteError,
WrongClusterVersionError,
ShuttingDownError,
InternalError,
)
from api.xline.xline_command_pb2 import Command, CommandResponse, SyncResponse
from api.xline.xline_error_pb2 import ExecuteError as _ExecuteError

@@ -36,7 +42,6 @@ class ProtocolClient:
state: State
connects: dict[int, grpc.Channel]
cluster_version: int
-# TODO config

def __init__(
self,
@@ -82,25 +87,75 @@ async def fast_fetch_cluster(addrs: list[str]) -> FetchClusterResponse:
msg = "fetch cluster error"
raise Exception(msg)

async def fetch_cluster(self, linearizable: bool) -> FetchClusterResponse:
"""
Send fetch cluster requests to all servers
Note: The fetched cluster may still be outdated if `linearizable` is false
"""
connects = self.all_connects()
rpcs: list[grpc.Future[FetchClusterResponse]] = []

for channel in connects:
stub = ProtocolStub(channel)
rpcs.append(stub.FetchCluster(FetchClusterRequest(linearizable=linearizable)))

max_term = 0
resp: FetchClusterResponse | None = None
ok_cnt = 0
majority_cnt = len(connects) // 2 + 1
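# e.g. a 3-member cluster needs 3 // 2 + 1 == 2 replies; a 5-member cluster needs 3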

for rpc in asyncio.as_completed(rpcs):
try:
res: FetchClusterResponse = await rpc
except grpc.RpcError as e:
logging.warning(e)
continue

if max_term < res.term:
    max_term = res.term
    # only a response that actually carries the member list is usable
    if len(res.members) != 0:
        resp = res
    ok_cnt = 1
elif max_term == res.term:
    if len(res.members) != 0:
        resp = res
    ok_cnt += 1
else:
    pass

if ok_cnt >= majority_cnt:
break

if resp is not None:
    logging.debug("Fetch cluster succeeded, result: %s", resp)
    self.state.check_and_update(resp.leader_id, resp.term)
    return resp

msg = "fetch cluster error"
raise Exception(msg)

async def fetch_leader(self) -> ServerId:
"""
Send fetch leader requests to all servers until there is a leader
Note: The fetched leader may still be outdated
"""
res = await self.fetch_cluster(False)
return res.leader_id

async def propose(self, cmd: Command, use_fast_path: bool = False) -> tuple[CommandResponse, SyncResponse | None]:
"""
Propose the request to servers. If use_fast_path is false, it will wait for the synced index
"""
propose_id = self.gen_propose_id()

# TODO: retry
if use_fast_path:
return await self.fast_path(propose_id, cmd)
else:
return await self.slow_path(propose_id, cmd)
-# TODO: error handling

async def fast_path(self, propose_id: ProposeId, cmd: Command) -> tuple[CommandResponse, SyncResponse | None]:
"""
Fast path of propose
"""
fast_round = self.fast_round(propose_id, cmd)
-slow_round = self.slow_round(propose_id)
slow_round = self.slow_round(propose_id, cmd)

# Wait for the fast and slow round at the same time
for futures in asyncio.as_completed([fast_round, slow_round]):
@@ -111,10 +166,8 @@ async def fast_path(self, propose_id: ProposeId, cmd: Command) -> tuple[CommandR
continue

if isinstance(first, CommandResponse) and second:
-# TODO: error handling
return (first, None)
if isinstance(second, CommandResponse) and isinstance(first, SyncResponse):
-# TODO: error handling
return (second, first)

msg = "fast path error"
@@ -124,7 +177,7 @@ async def slow_path(self, propose_id: ProposeId, cmd: Command) -> tuple[CommandR
"""
Slow path of propose
"""
-results = await asyncio.gather(self.fast_round(propose_id, cmd), self.slow_round(propose_id))
results = await asyncio.gather(self.fast_round(propose_id, cmd), self.slow_round(propose_id, cmd))
for result in results:
if isinstance(result[0], SyncResponse) and isinstance(result[1], CommandResponse):
return (result[1], result[0])
@@ -153,9 +206,17 @@ async def fast_round(self, propose_id: ProposeId, cmd: Command) -> tuple[Command
res: ProposeResponse = ProposeResponse()
try:
res = await future
-except Exception as e:
except grpc.RpcError as e:
logging.warning(e)
-continue
curp_err = CurpError()
dtl = e.details()
curp_err.ParseFromString(dtl)
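# NOTE: assumes the server returns the serialized CurpError in the gRPC status details;
# ShuttingDown / WrongClusterVersion abort the whole round, any other variant just skips this replica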
if curp_err.HasField("ShuttingDown"):
raise ShuttingDownError from e
elif curp_err.HasField("WrongClusterVersion"):
raise WrongClusterVersionError from e
else:
continue

ok_cnt += 1
if not res.HasField("result"):
@@ -178,7 +239,7 @@ async def fast_round(self, propose_id: ProposeId, cmd: Command) -> tuple[Command
logging.info("fast round failed. propose id: %s", propose_id)
return (cmd_res, False)

-async def slow_round(self, propose_id: ProposeId) -> tuple[SyncResponse, CommandResponse]:
async def slow_round(self, propose_id: ProposeId, cmd: Command) -> tuple[SyncResponse, CommandResponse]:
"""
The slow round of Curp protocol
"""
@@ -190,9 +251,33 @@ async def slow_round(self, propose_id: ProposeId) -> tuple[SyncResponse, Command

channel = self.connects[self.state.leader]
stub = ProtocolStub(channel)
-res: WaitSyncedResponse = await stub.WaitSynced(
-    WaitSyncedRequest(propose_id=propose_id, cluster_version=self.cluster_version)
-)

res = WaitSyncedResponse()
try:
res: WaitSyncedResponse = await stub.WaitSynced(
WaitSyncedRequest(propose_id=propose_id, cluster_version=self.cluster_version)
)
except grpc.RpcError as e:
logging.warning("wait synced rpc error: %s", e)
curp_err = CurpError()
details = e.details()
curp_err.ParseFromString(details)
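# same assumption as in fast_round: the serialized CurpError rides in the gRPC status details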
if curp_err.HasField("ShuttingDown"):
raise ShuttingDownError from e
elif curp_err.HasField("WrongClusterVersion"):
raise WrongClusterVersionError from e
elif curp_err.HasField("RpcTransport"):
# the leader has quite likely crashed, so wait for some time,
# fetch the leader again and resend the propose
await self.resend_propose(propose_id, cmd, None)
elif curp_err.HasField("redirect"):
new_leader = curp_err.redirect.leader_id
term = curp_err.redirect.term
self.state.check_and_update(new_leader, term)
# resend the propose to the new leader
await self.resend_propose(propose_id, cmd, new_leader)
else:
raise InternalError from e

if res.after_sync_result.ok:
asr.ParseFromString(res.after_sync_result.ok)
@@ -208,6 +293,44 @@ async def slow_round(self, propose_id: ProposeId) -> tuple[SyncResponse, Command

return (asr, er)

async def resend_propose(
    self, propose_id: ProposeId, cmd: Command, new_leader: ServerId | None
) -> bool | None:
    """
    Resend the propose only to the leader.
    This is used when the leader changes and we need to ensure that the propose is received by the new leader.
    """
    leader_id: ServerId | None = None
    if new_leader is not None:
        leader_id = new_leader
    else:
        try:
            leader_id = await self.fetch_leader()
        except Exception as e:
            logging.warning("failed to fetch leader, %s", e)
    if leader_id is None:
        return None
    logging.debug("resend propose to %s", leader_id)

stub = ProtocolStub(self.get_connect(leader_id))

try:
await stub.Propose(
ProposeRequest(
propose_id=propose_id, command=cmd.SerializeToString(), cluster_version=self.cluster_version
)
)
except grpc.RpcError as e:
# if the propose fails again, need to fetch the leader and try again
logging.warning("failed to resend propose, %s", e)
curp_err = CurpError()
dtl = e.details()
curp_err.ParseFromString(dtl)
if curp_err.HasField("ShuttingDown"):
raise ShuttingDownError from e
elif curp_err.HasField("WrongClusterVersion"):
raise WrongClusterVersionError from e
elif curp_err.HasField("Duplicated"):
return True
else:
return None

def gen_propose_id(self) -> ProposeId:
"""
Generate a propose id
@@ -232,6 +355,18 @@ def get_client_id(self) -> int:
"""
return random.randint(0, 2**64 - 1)

def all_connects(self) -> list[grpc.Channel]:
"""
Get all connections
"""
return list(self.connects.values())

def get_connect(self, _id: ServerId) -> grpc.Channel:
"""
Get the connection to the given server
"""
return self.connects[_id]

@staticmethod
def super_quorum(nodes: int) -> int:
"""
@@ -258,9 +393,44 @@ class State:
term: Current term
"""

-leader: int
leader: int | None
term: int

-def __init__(self, leader: int, term: int) -> None:
def __init__(self, leader: int | None, term: int) -> None:
self.leader = leader
self.term = term

def check_and_update(self, leader_id: int | None, term: int):
"""
Check the term and leader id, update the state if needed
"""
if self.term < term:
    # adopt the new term only when the response carries a leader id, to prevent:
    # a server that loses contact with its leader bumps its term for an election;
    # since the other servers are fine, that election will not succeed,
    # and a client that adopted the bare new term would never learn the true leader.
    if leader_id is not None:
        self.update_to_term(term)
        self.set_leader(leader_id)
elif self.term == term:
    # same term: only fill in a missing leader, never overwrite a known one
    if leader_id is not None and self.leader is None:
        self.set_leader(leader_id)
else:
    pass

def update_to_term(self, term: int) -> None:
"""
Update to the newest term and reset local cache
"""
self.term = term
self.leader = None

def set_leader(self, _id: ServerId) -> None:
"""
Set the leader and notify all the waiters
"""
logging.debug("client update its leader to %s", _id)
self.leader = _id
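
To make the update rules concrete, a short illustration (not part of this commit) of how a State instance reacts to fetched responses:

state = State(1, 2)

# a higher term without a leader id is ignored; adopting a bare
# mid-election term would strand the client with no leader to learn
state.check_and_update(None, 3)
assert state.term == 2 and state.leader == 1

# a higher term with a leader id: adopt both
state.check_and_update(3, 3)
assert state.term == 3 and state.leader == 3

# the same term only fills in a missing leader, never overwrites one
state.check_and_update(4, 3)
assert state.leader == 3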
