Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(robot-server): Use a non-nullable indexed enum for protocol_kind #15822

Merged
merged 8 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions robot-server/robot_server/persistence/_migrations/v5_to_v6.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
will still be available as part of the completed analysis blob.
- Adds a new analysis_csv_rtp_table to store the CSV parameters' file IDs used in analysis
- Adds a new run_csv_rtp_table to store the CSV parameters' file IDs used in runs
- Converts protocol.protocol_kind to a constrained string (a SQL "enum"), makes it
non-nullable (NULL was semantically equivalent to "standard"), and adds an index.
"""

from pathlib import Path
Expand Down Expand Up @@ -63,12 +65,9 @@ def _migrate_db_with_changes(
source_transaction: sqlalchemy.engine.Connection,
dest_transaction: sqlalchemy.engine.Connection,
) -> None:
copy_rows_unmodified(
schema_5.protocol_table,
schema_6.protocol_table,
_migrate_protocol_table_with_new_protocol_kind_col(
source_transaction,
dest_transaction,
order_by_rowid=True,
)
_migrate_analysis_table_excluding_rtp_defaults_and_vals(
source_transaction,
Expand Down Expand Up @@ -97,6 +96,31 @@ def _migrate_db_with_changes(
)


def _migrate_protocol_table_with_new_protocol_kind_col(
    source_transaction: sqlalchemy.engine.Connection,
    dest_transaction: sqlalchemy.engine.Connection,
) -> None:
    """Copy every protocol row from the schema 5 table into the schema 6 table.

    Rows are read from the source in SQLite rowid order and inserted into the
    destination one at a time. `id`, `created_at` and `protocol_key` are
    carried over unchanged. `protocol_kind` is translated to a
    `schema_6.ProtocolKindSQLEnum` member: `"quick-transfer"` becomes
    `QUICK_TRANSFER`, and every other value, including NULL, becomes
    `STANDARD`.

    Args:
        source_transaction: Connection to read the schema 5 rows from.
        dest_transaction: Connection to insert the schema 6 rows into.
    """
    old_protocols_in_rowid_order = sqlalchemy.select(
        schema_5.protocol_table
    ).order_by(sqlite_rowid)
    insert_into_new_table = sqlalchemy.insert(schema_6.protocol_table)

    old_rows = source_transaction.execute(old_protocols_in_rowid_order).all()
    for old_row in old_rows:
        if old_row.protocol_kind == "quick-transfer":
            migrated_kind = schema_6.ProtocolKindSQLEnum.QUICK_TRANSFER
        else:
            # This branch also handles old_row.protocol_kind being NULL.
            migrated_kind = schema_6.ProtocolKindSQLEnum.STANDARD

        dest_transaction.execute(
            insert_into_new_table,
            id=old_row.id,
            created_at=old_row.created_at,
            protocol_key=old_row.protocol_key,
            protocol_kind=migrated_kind,
        )


def _migrate_analysis_table_excluding_rtp_defaults_and_vals(
source_transaction: sqlalchemy.engine.Connection,
dest_transaction: sqlalchemy.engine.Connection,
Expand Down
2 changes: 2 additions & 0 deletions robot-server/robot_server/persistence/tables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
action_table,
data_files_table,
PrimitiveParamSQLEnum,
ProtocolKindSQLEnum,
)


Expand All @@ -26,4 +27,5 @@
"action_table",
"data_files_table",
"PrimitiveParamSQLEnum",
"ProtocolKindSQLEnum",
]
18 changes: 17 additions & 1 deletion robot-server/robot_server/persistence/tables/schema_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ class PrimitiveParamSQLEnum(enum.Enum):
STR = "str"


class ProtocolKindSQLEnum(enum.Enum):
    """The kinds of protocol that can be stored.

    Each member's value is the string associated with that kind.
    """

    # An ordinary protocol.
    STANDARD = "standard"

    # A quick-transfer protocol.
    QUICK_TRANSFER = "quick-transfer"


protocol_table = sqlalchemy.Table(
"protocol",
metadata,
Expand All @@ -30,7 +37,16 @@ class PrimitiveParamSQLEnum(enum.Enum):
nullable=False,
),
sqlalchemy.Column("protocol_key", sqlalchemy.String, nullable=True),
sqlalchemy.Column("protocol_kind", sqlalchemy.String, nullable=True),
sqlalchemy.Column(
"protocol_kind",
sqlalchemy.Enum(
ProtocolKindSQLEnum,
values_callable=lambda obj: [e.value for e in obj],
create_constraint=True,
),
index=True,
nullable=False,
),
)

analysis_table = sqlalchemy.Table(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def make_room_for_new_protocol(self) -> None:
protocols = {
p.protocol_id
for p in self._protocol_store.get_all()
if p.protocol_kind == self._protocol_kind.value
if p.protocol_kind == self._protocol_kind
}

protocol_run_usage_info = [
Expand Down
8 changes: 0 additions & 8 deletions robot-server/robot_server/protocols/protocol_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,6 @@ class ProtocolKind(str, Enum):
STANDARD = "standard"
QUICK_TRANSFER = "quick-transfer"

@staticmethod
def from_string(name: Optional[str]) -> Optional["ProtocolKind"]:
"""Get the ProtocolKind from a string."""
for item in ProtocolKind:
if name == item.value:
return item
return None


class ProtocolFile(BaseModel):
"""A file in a protocol."""
Expand Down
34 changes: 26 additions & 8 deletions robot-server/robot_server/protocols/protocol_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
analysis_primitive_type_rtp_table,
analysis_csv_rtp_table,
data_files_table,
ProtocolKindSQLEnum,
)
from robot_server.protocols.protocol_models import ProtocolKind


_CACHE_ENTRIES = 32
Expand All @@ -41,7 +43,7 @@ class ProtocolResource:
created_at: datetime
source: ProtocolSource
protocol_key: Optional[str]
protocol_kind: Optional[str]
protocol_kind: ProtocolKind


@dataclass(frozen=True)
Expand Down Expand Up @@ -173,7 +175,7 @@ def insert(self, resource: ProtocolResource) -> None:
protocol_id=resource.protocol_id,
created_at=resource.created_at,
protocol_key=resource.protocol_key,
protocol_kind=resource.protocol_kind,
protocol_kind=_http_protocol_kind_to_sql(resource.protocol_kind),
)
)
self._sources_by_id[resource.protocol_id] = resource.source
Expand All @@ -191,7 +193,7 @@ def get(self, protocol_id: str) -> ProtocolResource:
protocol_id=sql_resource.protocol_id,
created_at=sql_resource.created_at,
protocol_key=sql_resource.protocol_key,
protocol_kind=sql_resource.protocol_kind,
protocol_kind=_sql_protocol_kind_to_http(sql_resource.protocol_kind),
source=self._sources_by_id[sql_resource.protocol_id],
)

Expand All @@ -207,7 +209,7 @@ def get_all(self) -> List[ProtocolResource]:
protocol_id=r.protocol_id,
created_at=r.created_at,
protocol_key=r.protocol_key,
protocol_kind=r.protocol_kind,
protocol_kind=_sql_protocol_kind_to_http(r.protocol_kind),
source=self._sources_by_id[r.protocol_id],
)
for r in all_sql_resources
Expand Down Expand Up @@ -525,7 +527,7 @@ class _DBProtocolResource:
protocol_id: str
created_at: datetime
protocol_key: Optional[str]
protocol_kind: Optional[str]
protocol_kind: ProtocolKindSQLEnum


def _convert_sql_row_to_dataclass(
Expand All @@ -540,9 +542,9 @@ def _convert_sql_row_to_dataclass(
assert protocol_key is None or isinstance(
protocol_key, str
), f"Protocol Key {protocol_key} not a string or None"
assert protocol_kind is None or isinstance(
protocol_kind, str
), f"Protocol Kind {protocol_kind} not a string or None"
assert isinstance(
protocol_kind, ProtocolKindSQLEnum
), f"Protocol Kind {protocol_kind} not the expected enum"

return _DBProtocolResource(
protocol_id=protocol_id,
Expand All @@ -561,3 +563,19 @@ def _convert_dataclass_to_sql_values(
"protocol_key": resource.protocol_key,
"protocol_kind": resource.protocol_kind,
}


def _http_protocol_kind_to_sql(http_enum: ProtocolKind) -> ProtocolKindSQLEnum:
match http_enum:
case ProtocolKind.STANDARD:
return ProtocolKindSQLEnum.STANDARD
case ProtocolKind.QUICK_TRANSFER:
return ProtocolKindSQLEnum.QUICK_TRANSFER


def _sql_protocol_kind_to_http(sql_enum: ProtocolKindSQLEnum) -> ProtocolKind:
match sql_enum:
case ProtocolKindSQLEnum.STANDARD:
return ProtocolKind.STANDARD
case ProtocolKindSQLEnum.QUICK_TRANSFER:
return ProtocolKind.QUICK_TRANSFER
57 changes: 31 additions & 26 deletions robot-server/robot_server/protocols/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ async def create_protocol( # noqa: C901
" always trigger an analysis (for now).",
alias="runTimeParameterValues",
),
protocol_kind: Optional[ProtocolKind] = Form(
default=None,
protocol_kind: ProtocolKind = Form(
# This default needs to be kept in sync with the function body.
# See todo comments.
default=ProtocolKind.STANDARD,
description=(
"Whether this is a `standard` protocol or a `quick-transfer` protocol."
"if omitted, the protocol will be `standard` by default."
Expand Down Expand Up @@ -277,12 +279,30 @@ async def create_protocol( # noqa: C901
created_at: Timestamp to attach to the new resource.
maximum_quick_transfer_protocols: Robot setting value limiting stored quick transfers protocols.
"""
kind = ProtocolKind.from_string(protocol_kind) or ProtocolKind.STANDARD
if kind == ProtocolKind.QUICK_TRANSFER:
# We have to do these isinstance checks because if `runTimeParameterValues` or
# `protocolKind` are not specified in the request, then they get assigned a
# Form(default) value instead of just the default value. \(O.o)/
# TODO: check if we can make our own "RTP multipart-form field" Pydantic type
# so we can validate the data contents and return a better error response.
# TODO: check if this is still necessary after converting FastAPI args to Annotated.
parsed_rtp_values = (
json.loads(run_time_parameter_values)
if isinstance(run_time_parameter_values, str)
else {}
)
parsed_rtp_files = (
json.loads(run_time_parameter_files)
if isinstance(run_time_parameter_files, str)
else {}
)
if not isinstance(protocol_kind, ProtocolKind):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that protocol_kind: Form default is ProtocolKind.STANDARD, do we need to set it here explicitly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird. We should not, but apparently we do. See the TODO comments above.

protocol_kind = ProtocolKind.STANDARD
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that previously, if the protocol kind specified in the request was not one of the defined kinds, it would raise an error. Are we intentionally ignoring it now?

Copy link
Contributor Author

@SyntaxColoring SyntaxColoring Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the actual running server, FastAPI will still reject invalid protocolKinds.

        {
            "id": "InvalidRequest",
            "title": "Invalid Request",
            "detail": "value is not a valid enumeration member; permitted: 'standard', 'quick-transfer'",
            "source": {
                "pointer": "/protocolKind"
            },
            "errorCode": "4000"
        }

This is a weird unit-test-only conditional.

I think this can be fixed by adopting FastAPI's newer Annotated[...] syntax instead of the current = Depends(...) syntax, which I'm working on separately, on the side.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right right!

Ya, switching to Annotated would be nice.


if protocol_kind == ProtocolKind.QUICK_TRANSFER:
quick_transfer_protocols = [
protocol
for protocol in protocol_store.get_all()
if protocol.protocol_kind == ProtocolKind.QUICK_TRANSFER.value
if protocol.protocol_kind == ProtocolKind.QUICK_TRANSFER
]
if len(quick_transfer_protocols) >= maximum_quick_transfer_protocols:
raise HTTPException(
Expand All @@ -294,21 +314,6 @@ async def create_protocol( # noqa: C901
assert file.filename is not None
buffered_files = await file_reader_writer.read(files=files) # type: ignore[arg-type]

# We have to do this isinstance check because if `runTimeParameterValues` is
# not specified in the request, then it gets assigned a Form(None) value
# instead of just a None. \(O.o)/
# TODO: check if we can make our own "RTP multipart-form field" Pydantic type
# so we can validate the data contents and return a better error response.
parsed_rtp_values = (
json.loads(run_time_parameter_values)
if isinstance(run_time_parameter_values, str)
else {}
)
parsed_rtp_files = (
json.loads(run_time_parameter_files)
if isinstance(run_time_parameter_files, str)
else {}
)
content_hash = await file_hasher.hash(buffered_files)
cached_protocol_id = protocol_store.get_id_by_hash(content_hash)

Expand Down Expand Up @@ -340,7 +345,7 @@ async def _get_cached_protocol_analysis() -> PydanticResponse[
data = Protocol.construct(
id=cached_protocol_id,
createdAt=resource.created_at,
protocolKind=ProtocolKind.from_string(resource.protocol_kind),
protocolKind=resource.protocol_kind,
protocolType=resource.source.config.protocol_type,
robotType=resource.source.robot_type,
metadata=Metadata.parse_obj(resource.source.metadata),
Expand Down Expand Up @@ -390,11 +395,11 @@ async def _get_cached_protocol_analysis() -> PydanticResponse[
created_at=created_at,
source=source,
protocol_key=key,
protocol_kind=kind.value,
protocol_kind=protocol_kind,
)

protocol_deleter: ProtocolAutoDeleter = protocol_auto_deleter
if kind == ProtocolKind.QUICK_TRANSFER:
if protocol_kind == ProtocolKind.QUICK_TRANSFER:
protocol_deleter = quick_transfer_protocol_auto_deleter
protocol_deleter.make_room_for_new_protocol()
protocol_store.insert(protocol_resource)
Expand All @@ -413,7 +418,7 @@ async def _get_cached_protocol_analysis() -> PydanticResponse[
data = Protocol(
id=protocol_id,
createdAt=created_at,
protocolKind=kind,
protocolKind=protocol_kind,
protocolType=source.config.protocol_type,
robotType=source.robot_type,
metadata=Metadata.parse_obj(source.metadata),
Expand Down Expand Up @@ -523,7 +528,7 @@ async def get_protocols(
Protocol.construct(
id=r.protocol_id,
createdAt=r.created_at,
protocolKind=ProtocolKind.from_string(r.protocol_kind),
protocolKind=r.protocol_kind,
protocolType=r.source.config.protocol_type,
robotType=r.source.robot_type,
metadata=Metadata.parse_obj(r.source.metadata),
Expand Down Expand Up @@ -604,7 +609,7 @@ async def get_protocol_by_id(
data = Protocol.construct(
id=protocolId,
createdAt=resource.created_at,
protocolKind=ProtocolKind.from_string(resource.protocol_kind),
protocolKind=resource.protocol_kind,
protocolType=resource.source.config.protocol_type,
robotType=resource.source.robot_type,
metadata=Metadata.parse_obj(resource.source.metadata),
Expand Down
4 changes: 1 addition & 3 deletions robot-server/robot_server/runs/run_auto_deleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def make_room_for_new_run(self) -> None: # noqa: D102
protocols = self._protocol_store.get_all()
protocol_ids = [p.protocol_id for p in protocols]
filtered_protocol_ids = [
p.protocol_id
for p in protocols
if p.protocol_kind == self._protocol_kind.value
p.protocol_id for p in protocols if p.protocol_kind == self._protocol_kind
]

# runs with no protocols first, then oldest to newest.
Expand Down
8 changes: 6 additions & 2 deletions robot-server/tests/persistence/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
id VARCHAR NOT NULL,
created_at DATETIME NOT NULL,
protocol_key VARCHAR,
protocol_kind VARCHAR,
PRIMARY KEY (id)
protocol_kind VARCHAR(14) NOT NULL,
PRIMARY KEY (id),
CONSTRAINT protocolkindsqlenum CHECK (protocol_kind IN ('standard', 'quick-transfer'))
)
""",
"""
Expand Down Expand Up @@ -114,6 +115,9 @@
CREATE UNIQUE INDEX ix_run_run_id_index_in_run ON run_command (run_id, index_in_run)
""",
"""
CREATE INDEX ix_protocol_protocol_kind ON protocol (protocol_kind)
""",
"""
CREATE TABLE data_files (
id VARCHAR NOT NULL,
name VARCHAR NOT NULL,
Expand Down
6 changes: 3 additions & 3 deletions robot-server/tests/protocols/test_analyses_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def test_initialize_analyzer(
content_hash="abc123",
),
protocol_key="dummy-data-111",
protocol_kind=ProtocolKind.STANDARD.value,
protocol_kind=ProtocolKind.STANDARD,
)
analyzer = decoy.mock(cls=protocol_analyzer.ProtocolAnalyzer)
decoy.when(
Expand Down Expand Up @@ -128,7 +128,7 @@ async def test_raises_error_and_saves_result_if_initialization_errors(
content_hash="abc123",
),
protocol_key="dummy-data-111",
protocol_kind=ProtocolKind.STANDARD.value,
protocol_kind=ProtocolKind.STANDARD,
)
raised_exception = Exception("Oh noooo!")
enumerated_error = EnumeratedError(
Expand Down Expand Up @@ -197,7 +197,7 @@ async def test_start_analysis(
content_hash="abc123",
),
protocol_key="dummy-data-111",
protocol_kind=ProtocolKind.STANDARD.value,
protocol_kind=ProtocolKind.STANDARD,
)
bool_parameter = BooleanParameter(
displayName="Foo", variableName="Bar", default=True, value=False
Expand Down
Loading
Loading