Skip to content

Commit

Permalink
Make connection initialization more robust (#366)
Browse files Browse the repository at this point in the history
* Add additional connection messages

* Improve debugging output

* Retry connections on startup, pass back connection errors

* Fix oneof bug
  • Loading branch information
geoffxy authored Nov 14, 2023
1 parent 55f7bbe commit 59b55fd
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 177 deletions.
9 changes: 8 additions & 1 deletion proto/brad.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@ message StartSessionRequest {
}

message StartSessionResponse {
SessionId id = 1;
oneof result {
SessionId id = 1;
StartSessionError error = 2;
}
}

message StartSessionError {
string error_msg = 1;
}

message RunQueryRequest {
Expand Down
3 changes: 2 additions & 1 deletion src/brad/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ def is_connected(self) -> bool:

class ConnectionFailed(Exception):
"""
Used when
Used when an existing connection fails for any reason, or we failed to
establish a connection to an underlying engine.
"""
21 changes: 19 additions & 2 deletions src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from brad.blueprint.manager import BlueprintManager
from brad.config.engine import Engine
from brad.config.file import ConfigFile
from brad.connection.connection import ConnectionFailed
from brad.daemon.monitor import Monitor
from brad.daemon.messages import (
ShutdownFrontEnd,
Expand Down Expand Up @@ -262,8 +263,24 @@ async def _run_teardown(self):
self._estimator = None

async def start_session(self) -> SessionId:
session_id, _ = await self._sessions.create_new_session()
return session_id
rand_backoff = None
while True:
try:
session_id, _ = await self._sessions.create_new_session()
return session_id
except ConnectionFailed:
if rand_backoff is None:
rand_backoff = RandomizedExponentialBackoff(
max_retries=10, base_delay_s=0.5, max_delay_s=10.0
)
time_to_wait = rand_backoff.wait_time_s()
if time_to_wait is None:
logger.exception(
"Failed to start a new session due to a repeated "
"connection failure (10 retries)."
)
raise
await asyncio.sleep(time_to_wait)

async def end_session(self, session_id: SessionId) -> None:
await self._sessions.end_session(session_id)
Expand Down
10 changes: 8 additions & 2 deletions src/brad/front_end/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import brad.proto_gen.brad_pb2_grpc as rpc
from brad.config.engine import Engine
from brad.config.session import SessionId
from brad.connection.connection import ConnectionFailed
from brad.front_end.brad_interface import BradInterface
from brad.front_end.errors import QueryError

Expand All @@ -24,8 +25,13 @@ def __init__(self, brad: BradInterface):
async def StartSession(
self, _request: b.StartSessionRequest, _context
) -> b.StartSessionResponse:
new_session_id = await self._brad.start_session()
return b.StartSessionResponse(id=b.SessionId(id_value=new_session_id.value()))
try:
new_session_id = await self._brad.start_session()
return b.StartSessionResponse(
id=b.SessionId(id_value=new_session_id.value())
)
except ConnectionFailed as ex:
return b.StartSessionResponse(error=b.StartSessionError(error_msg=repr(ex)))

async def RunQuery(
self, request: b.RunQueryRequest, _context
Expand Down
14 changes: 13 additions & 1 deletion src/brad/grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,19 @@ def close(self) -> None:
def start_session(self) -> SessionId:
assert self._stub is not None
result = self._stub.StartSession(b.StartSessionRequest())
return SessionId(result.id.id_value)
msg_kind = result.WhichOneof("result")
if msg_kind is None:
raise BradClientError(
message="BRAD RPC error: Unspecified start session result."
)
elif msg_kind == "id":
return SessionId(result.id.id_value)
elif msg_kind == "error":
raise BradClientError(message=result.error.error_msg)
else:
raise BradClientError(
message="BRAD RPC error: Unknown start session result."
)

def end_session(self, session_id: SessionId) -> None:
assert self._stub is not None
Expand Down
40 changes: 20 additions & 20 deletions src/brad/proto_gen/blueprint_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

84 changes: 44 additions & 40 deletions src/brad/proto_gen/blueprint_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,68 +4,55 @@ from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union

ATHENA: Engine
AURORA: Engine
DESCRIPTOR: _descriptor.FileDescriptor
REDSHIFT: Engine

class Engine(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = [] # type: ignore
UNKNOWN: _ClassVar[Engine]
AURORA: _ClassVar[Engine]
REDSHIFT: _ClassVar[Engine]
ATHENA: _ClassVar[Engine]
UNKNOWN: Engine
AURORA: Engine
REDSHIFT: Engine
ATHENA: Engine

class Blueprint(_message.Message):
__slots__ = ["aurora", "policy", "redshift", "schema_name", "tables"]
AURORA_FIELD_NUMBER: _ClassVar[int]
POLICY_FIELD_NUMBER: _ClassVar[int]
REDSHIFT_FIELD_NUMBER: _ClassVar[int]
__slots__ = ["schema_name", "tables", "aurora", "redshift", "policy"]
SCHEMA_NAME_FIELD_NUMBER: _ClassVar[int]
TABLES_FIELD_NUMBER: _ClassVar[int]
aurora: Provisioning
policy: RoutingPolicy
redshift: Provisioning
AURORA_FIELD_NUMBER: _ClassVar[int]
REDSHIFT_FIELD_NUMBER: _ClassVar[int]
POLICY_FIELD_NUMBER: _ClassVar[int]
schema_name: str
tables: _containers.RepeatedCompositeFieldContainer[Table]
aurora: Provisioning
redshift: Provisioning
policy: RoutingPolicy
def __init__(self, schema_name: _Optional[str] = ..., tables: _Optional[_Iterable[_Union[Table, _Mapping]]] = ..., aurora: _Optional[_Union[Provisioning, _Mapping]] = ..., redshift: _Optional[_Union[Provisioning, _Mapping]] = ..., policy: _Optional[_Union[RoutingPolicy, _Mapping]] = ...) -> None: ...

class Index(_message.Message):
__slots__ = ["column_name"]
COLUMN_NAME_FIELD_NUMBER: _ClassVar[int]
column_name: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, column_name: _Optional[_Iterable[str]] = ...) -> None: ...

class Provisioning(_message.Message):
__slots__ = ["instance_type", "num_nodes"]
INSTANCE_TYPE_FIELD_NUMBER: _ClassVar[int]
NUM_NODES_FIELD_NUMBER: _ClassVar[int]
instance_type: str
num_nodes: int
def __init__(self, instance_type: _Optional[str] = ..., num_nodes: _Optional[int] = ...) -> None: ...

class RoutingPolicy(_message.Message):
__slots__ = ["policy"]
POLICY_FIELD_NUMBER: _ClassVar[int]
policy: bytes
def __init__(self, policy: _Optional[bytes] = ...) -> None: ...

class Table(_message.Message):
__slots__ = ["columns", "dependencies", "indexes", "locations", "table_name"]
__slots__ = ["table_name", "columns", "locations", "dependencies", "indexes"]
TABLE_NAME_FIELD_NUMBER: _ClassVar[int]
COLUMNS_FIELD_NUMBER: _ClassVar[int]
LOCATIONS_FIELD_NUMBER: _ClassVar[int]
DEPENDENCIES_FIELD_NUMBER: _ClassVar[int]
INDEXES_FIELD_NUMBER: _ClassVar[int]
LOCATIONS_FIELD_NUMBER: _ClassVar[int]
TABLE_NAME_FIELD_NUMBER: _ClassVar[int]
table_name: str
columns: _containers.RepeatedCompositeFieldContainer[TableColumn]
locations: _containers.RepeatedScalarFieldContainer[Engine]
dependencies: TableDependency
indexes: _containers.RepeatedCompositeFieldContainer[Index]
locations: _containers.RepeatedScalarFieldContainer[Engine]
table_name: str
def __init__(self, table_name: _Optional[str] = ..., columns: _Optional[_Iterable[_Union[TableColumn, _Mapping]]] = ..., locations: _Optional[_Iterable[_Union[Engine, str]]] = ..., dependencies: _Optional[_Union[TableDependency, _Mapping]] = ..., indexes: _Optional[_Iterable[_Union[Index, _Mapping]]] = ...) -> None: ...

class TableColumn(_message.Message):
__slots__ = ["data_type", "is_primary", "name"]
__slots__ = ["name", "data_type", "is_primary"]
NAME_FIELD_NUMBER: _ClassVar[int]
DATA_TYPE_FIELD_NUMBER: _ClassVar[int]
IS_PRIMARY_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
name: str
data_type: str
is_primary: bool
name: str
def __init__(self, name: _Optional[str] = ..., data_type: _Optional[str] = ..., is_primary: bool = ...) -> None: ...

class TableDependency(_message.Message):
Expand All @@ -76,5 +63,22 @@ class TableDependency(_message.Message):
transform: str
def __init__(self, source_table_names: _Optional[_Iterable[str]] = ..., transform: _Optional[str] = ...) -> None: ...

class Engine(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): # type: ignore
__slots__ = [] # type: ignore
class Provisioning(_message.Message):
__slots__ = ["instance_type", "num_nodes"]
INSTANCE_TYPE_FIELD_NUMBER: _ClassVar[int]
NUM_NODES_FIELD_NUMBER: _ClassVar[int]
instance_type: str
num_nodes: int
def __init__(self, instance_type: _Optional[str] = ..., num_nodes: _Optional[int] = ...) -> None: ...

class RoutingPolicy(_message.Message):
__slots__ = ["policy"]
POLICY_FIELD_NUMBER: _ClassVar[int]
policy: bytes
def __init__(self, policy: _Optional[bytes] = ...) -> None: ...

class Index(_message.Message):
__slots__ = ["column_name"]
COLUMN_NAME_FIELD_NUMBER: _ClassVar[int]
column_name: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, column_name: _Optional[_Iterable[str]] = ...) -> None: ...
Loading

0 comments on commit 59b55fd

Please sign in to comment.