diff --git a/nucleus/deploy/cli/__init__.py b/nucleus/deploy/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nucleus/deploy/cli/bin.py b/nucleus/deploy/cli/bin.py new file mode 100644 index 00000000..43a0cfab --- /dev/null +++ b/nucleus/deploy/cli/bin.py @@ -0,0 +1,27 @@ +import click + +from nucleus.deploy.cli.bundles import bundles +from nucleus.deploy.cli.endpoints import endpoints + + +@click.group("cli") +def entry_point(): + """Launch CLI + + \b + ██╗ █████╗ ██╗ ██╗███╗ ██╗ ██████╗██╗ ██╗ + ██║ ██╔══██╗██║ ██║████╗ ██║██╔════╝██║ ██║ + ██║ ███████║██║ ██║██╔██╗ ██║██║ ███████║ + ██║ ██╔══██║██║ ██║██║╚██╗██║██║ ██╔══██║ + ███████╗██║ ██║╚██████╔╝██║ ╚████║╚██████╗██║ ██║ + ╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝ + + `scale-launch` is a command line interface to interact with Scale Launch + """ + + +entry_point.add_command(bundles) # type: ignore +entry_point.add_command(endpoints) # type: ignore + +if __name__ == "__main__": + entry_point() diff --git a/nucleus/deploy/cli/bundles.py b/nucleus/deploy/cli/bundles.py new file mode 100644 index 00000000..60d1826f --- /dev/null +++ b/nucleus/deploy/cli/bundles.py @@ -0,0 +1,72 @@ +import click +from rich.console import Console +from rich.syntax import Syntax +from rich.table import Column, Table + +from nucleus.deploy.cli.client import init_client + + +@click.group("bundles") +def bundles(): + """Bundles is a wrapper around model bundles in Scale Launch""" + + +@bundles.command("list") +def list_bundles(): + """List all of your Bundles""" + client = init_client() + + table = Table( + Column("Bundle Id", overflow="fold", min_width=24), + "Bundle name", + "Location", + "Packaging type", + title="Bundles", + title_justify="left", + ) + + for model_bundle in client.list_model_bundles(): + table.add_row( + model_bundle.bundle_id, + model_bundle.bundle_name, + model_bundle.location, + model_bundle.packaging_type, + ) + console = Console() + console.print(table) + + +@bundles.command("get") +@click.argument("bundle_name") +def get_bundle(bundle_name): + """Print bundle info""" + client = init_client() + + model_bundle = client.get_model_bundle(bundle_name) + + console = Console() + console.print(f"bundle_id: {model_bundle.bundle_id}") + console.print(f"bundle_name: {model_bundle.bundle_name}") + console.print(f"location: {model_bundle.location}") + console.print(f"packaging_type: {model_bundle.packaging_type}") + console.print(f"env_params: {model_bundle.env_params}") + console.print(f"requirements: {model_bundle.requirements}") + + console.print("metadata:") + for meta_name, meta_value in model_bundle.metadata.items(): + # TODO print non-code metadata differently + console.print(f"{meta_name}:", style="yellow") + syntax = Syntax(meta_value, "python") + console.print(syntax) + + +@bundles.command("delete") +@click.argument("bundle_name") +def delete_bundle(bundle_name): + """Delete a model bundle""" + client = init_client() + + console = Console() + model_bundle = client.get_model_bundle(bundle_name) + res = client.delete_model_bundle(model_bundle) + console.print(res) diff --git a/nucleus/deploy/cli/client.py b/nucleus/deploy/cli/client.py new file mode 100644 index 00000000..64dc40c6 --- /dev/null +++ b/nucleus/deploy/cli/client.py @@ -0,0 +1,14 @@ +import functools +import os + +import nucleus + + +@functools.lru_cache() +def init_client(): + api_key = os.environ.get("LAUNCH_API_KEY", None) + if api_key: + client = nucleus.deploy.DeployClient(api_key) + else: + raise RuntimeError("No LAUNCH_API_KEY set") + return client diff --git a/nucleus/deploy/cli/endpoints.py b/nucleus/deploy/cli/endpoints.py new file mode 100644 index 00000000..1a7d4511 --- /dev/null +++ b/nucleus/deploy/cli/endpoints.py @@ -0,0 +1,48 @@ +import click +from rich.console import Console +from rich.table import Table + +from nucleus.deploy.cli.client import init_client +from nucleus.deploy.model_endpoint import AsyncModelEndpoint, Endpoint + + +@click.group("endpoints") +def endpoints(): + """Endpoints is a wrapper around model bundles in Scale Launch""" + + +@endpoints.command("list") +def list_endpoints(): + """List all of your Bundles""" + client = init_client() + + table = Table( + "Endpoint name", + "Metadata", + "Endpoint type", + title="Endpoints", + title_justify="left", + ) + + for endpoint_sync_async in client.list_model_endpoints(): + endpoint = endpoint_sync_async.endpoint + table.add_row( + endpoint.name, + endpoint.metadata, + endpoint.endpoint_type, + ) + console = Console() + console.print(table) + + +@endpoints.command("delete") +@click.argument("endpoint_name") +def delete_bundle(endpoint_name): + """Delete a model bundle""" + client = init_client() + + console = Console() + endpoint = Endpoint(name=endpoint_name) + dummy_endpoint = AsyncModelEndpoint(endpoint=endpoint, client=client) + res = client.delete_model_endpoint(dummy_endpoint) + console.print(res) diff --git a/nucleus/deploy/client.py b/nucleus/deploy/client.py index b83b03e1..a53f59f6 100644 --- a/nucleus/deploy/client.py +++ b/nucleus/deploy/client.py @@ -22,7 +22,11 @@ get_imports, ) from nucleus.deploy.model_bundle import ModelBundle -from nucleus.deploy.model_endpoint import AsyncModelEndpoint, SyncModelEndpoint +from nucleus.deploy.model_endpoint import ( + AsyncModelEndpoint, + Endpoint, + SyncModelEndpoint, +) from nucleus.deploy.request_validation import validate_task_request DEFAULT_NETWORK_TIMEOUT_SEC = 120 @@ -384,7 +388,7 @@ def create_model_endpoint( """ payload = dict( endpoint_name=endpoint_name, - bundle_name=model_bundle.name, + bundle_name=model_bundle.bundle_name, cpus=cpus, memory=memory, gpus=gpus, @@ -411,10 +415,11 @@ def create_model_endpoint( logger.info( "Endpoint creation task id is %s", endpoint_creation_task_id ) + endpoint = Endpoint(name=endpoint_name) if endpoint_type == "async": - return AsyncModelEndpoint(endpoint_id=endpoint_name, client=self) + return AsyncModelEndpoint(endpoint=endpoint, client=self) elif endpoint_type == "sync": - return SyncModelEndpoint(endpoint_id=endpoint_name, client=self) + return SyncModelEndpoint(endpoint=endpoint, client=self) else: raise ValueError( "Endpoint should be one of the types 'sync' or 'async'" @@ -431,10 +436,23 @@ def list_model_bundles(self) -> List[ModelBundle]: """ resp = self.connection.get("model_bundle") model_bundles = [ - ModelBundle(name=item["bundle_name"]) for item in resp["bundles"] + ModelBundle.from_dict(item) for item in resp["bundles"] # type: ignore ] return model_bundles + def get_model_bundle(self, bundle_name: str) -> ModelBundle: + """ + Returns a Model Bundle object specified by `bundle_name`. + + Returns: + A ModelBundle object + """ + resp = self.connection.get(f"model_bundle/{bundle_name}") + assert ( + len(resp["bundles"]) == 1 + ), f"Bundle with name `{bundle_name}` not found" + return ModelBundle.from_dict(resp["bundles"][0]) # type: ignore + def list_model_endpoints( self, ) -> List[Union[AsyncModelEndpoint, SyncModelEndpoint]]: @@ -447,12 +465,16 @@ def list_model_endpoints( """ resp = self.connection.get(ENDPOINT_PATH) async_endpoints: List[Union[AsyncModelEndpoint, SyncModelEndpoint]] = [ - AsyncModelEndpoint(endpoint_id=endpoint["name"], client=self) + AsyncModelEndpoint( + endpoint=Endpoint.from_dict(endpoint), client=self # type: ignore + ) for endpoint in resp["endpoints"] if endpoint["endpoint_type"] == "async" ] sync_endpoints: List[Union[AsyncModelEndpoint, SyncModelEndpoint]] = [ - SyncModelEndpoint(endpoint_id=endpoint["name"], client=self) + SyncModelEndpoint( + endpoint=Endpoint.from_dict(endpoint), client=self # type: ignore + ) for endpoint in resp["endpoints"] if endpoint["endpoint_type"] == "sync" ] @@ -462,7 +484,7 @@ def delete_model_bundle(self, model_bundle: ModelBundle): """ Deletes the model bundle on the server. """ - route = f"model_bundle/{model_bundle.name}" + route = f"model_bundle/{model_bundle.bundle_name}" resp = self.connection.delete(route) return resp["deleted"] @@ -472,7 +494,7 @@ def delete_model_endpoint( """ Deletes a model endpoint. """ - route = f"{ENDPOINT_PATH}/{model_endpoint.endpoint_id}" + route = f"{ENDPOINT_PATH}/{model_endpoint.endpoint.name}" resp = self.connection.delete(route) return resp["deleted"] diff --git a/nucleus/deploy/model_bundle.py b/nucleus/deploy/model_bundle.py index c8391f04..8bf5b4d5 100644 --- a/nucleus/deploy/model_bundle.py +++ b/nucleus/deploy/model_bundle.py @@ -1,11 +1,23 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from dataclasses_json import Undefined, dataclass_json + + +@dataclass_json(undefined=Undefined.EXCLUDE) +@dataclass class ModelBundle: """ Represents a ModelBundle. - TODO fill this out with more than just a name potentially. """ - def __init__(self, name): - self.name = name + bundle_name: str + bundle_id: Optional[str] = None + env_params: Optional[Dict[str, str]] = None + location: Optional[str] = None + metadata: Optional[Dict[Any, Any]] = None + packaging_type: Optional[str] = None + requirements: Optional[List[str]] = None def __str__(self): - return f"ModelBundle(name={self.name})" + return f"ModelBundle(bundle_name={self.bundle_name})" diff --git a/nucleus/deploy/model_endpoint.py b/nucleus/deploy/model_endpoint.py index 756d79a7..ab6a6df0 100644 --- a/nucleus/deploy/model_endpoint.py +++ b/nucleus/deploy/model_endpoint.py @@ -1,8 +1,11 @@ import concurrent.futures import uuid from collections import Counter +from dataclasses import dataclass from typing import Dict, Optional, Sequence +from dataclasses_json import Undefined, dataclass_json + from nucleus.deploy.request_validation import validate_task_request TASK_PENDING_STATE = "PENDING" @@ -10,6 +13,21 @@ TASK_FAILURE_STATE = "FAILURE" +@dataclass_json(undefined=Undefined.EXCLUDE) +@dataclass +class Endpoint: + """ + Represents an Endpoint from the database. + """ + + name: str + metadata: Optional[Dict] = None + endpoint_type: Optional[str] = None + + def __str__(self): + return f"Endpoint(name={self.name})" + + class EndpointRequest: """ Represents a single request to either a SyncModelEndpoint or AsyncModelEndpoint. @@ -62,16 +80,16 @@ def __str__(self): class SyncModelEndpoint: - def __init__(self, endpoint_id: str, client): - self.endpoint_id = endpoint_id + def __init__(self, endpoint: Endpoint, client): + self.endpoint = endpoint self.client = client def __str__(self): - return f"SyncModelEndpoint " + return f"SyncModelEndpoint " def predict(self, request: EndpointRequest) -> EndpointResponse: raw_response = self.client.sync_request( - self.endpoint_id, + self.endpoint.name, url=request.url, args=request.args, return_pickled=request.return_pickled, @@ -92,17 +110,17 @@ class AsyncModelEndpoint: A higher level abstraction for a Model Endpoint. """ - def __init__(self, endpoint_id: str, client): + def __init__(self, endpoint: Endpoint, client): """ Parameters: - endpoint_id: The unique name of the ModelEndpoint + endpoint: Endpoint object. client: A DeployClient object """ - self.endpoint_id = endpoint_id + self.endpoint = endpoint self.client = client def __str__(self): - return f"AsyncModelEndpoint " + return f"AsyncModelEndpoint " def predict_batch( self, requests: Sequence[EndpointRequest] @@ -129,7 +147,7 @@ def single_request(request): # request has keys url and args inner_inference_request = self.client.async_request( - endpoint_id=self.endpoint_id, + endpoint_id=self.endpoint.name, url=request.url, args=request.args, return_pickled=request.return_pickled, diff --git a/poetry.lock b/poetry.lock index 5a011e27..26be4f43 100644 --- a/poetry.lock +++ b/poetry.lock @@ -481,6 +481,23 @@ category = "main" optional = false python-versions = ">=3.6, <3.7" +[[package]] +name = "dataclasses-json" +version = "0.5.6" +description = "Easily serialize dataclasses to and from JSON" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +dataclasses = {version = "*", markers = "python_version == \"3.6\""} +marshmallow = ">=3.3.0,<4.0.0" +marshmallow-enum = ">=1.5.1,<2.0.0" +typing-inspect = ">=0.4.0" + +[package.extras] +dev = ["pytest (>=6.2.3)", "ipython", "mypy (>=0.710)", "hypothesis", "portray", "flake8", "simplejson", "types-dataclasses"] + [[package]] name = "decorator" version = "5.1.1" @@ -1017,6 +1034,31 @@ category = "dev" optional = false python-versions = ">=3.6" +[[package]] +name = "marshmallow" +version = "3.14.1" +description = "A lightweight library for converting complex datatypes to and from native Python datatypes." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +dev = ["pytest", "pytz", "simplejson", "mypy (==0.910)", "flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "pre-commit (>=2.4,<3.0)", "tox"] +docs = ["sphinx (==4.3.0)", "sphinx-issues (==1.2.0)", "alabaster (==0.7.12)", "sphinx-version-warning (==1.1.2)", "autodocsumm (==0.2.7)"] +lint = ["mypy (==0.910)", "flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "pre-commit (>=2.4,<3.0)"] +tests = ["pytest", "pytz", "simplejson"] + +[[package]] +name = "marshmallow-enum" +version = "1.5.1" +description = "Enum field for Marshmallow" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +marshmallow = ">=2.0.0" + [[package]] name = "mccabe" version = "0.6.1" @@ -1069,7 +1111,7 @@ dmypy = ["psutil (>=4.0)"] name = "mypy-extensions" version = "0.4.3" description = "Experimental type system extensions for programs checked with the mypy typechecker." -category = "dev" +category = "main" optional = false python-versions = "*" @@ -2047,6 +2089,18 @@ category = "main" optional = false python-versions = ">=3.6" +[[package]] +name = "typing-inspect" +version = "0.7.1" +description = "Runtime inspection utilities for typing module." +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "unidecode" version = "1.3.3" @@ -2153,7 +2207,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" python-versions = "^3.6.2" -content-hash = "e3bcd44367afeb16e45c11b50c33d3b320a2ae7d399cf5dad91ac9f886e394eb" +content-hash = "ebd78ebb5b5693589b5dfa684234d880c07f2ef63b080b51fa40b691bdda82fb" [metadata.files] absl-py = [ @@ -2519,6 +2573,10 @@ dataclasses = [ {file = "dataclasses-0.7-py3-none-any.whl", hash = "sha256:3459118f7ede7c8bea0fe795bff7c6c2ce287d01dd226202f7c9ebc0610a7836"}, {file = "dataclasses-0.7.tar.gz", hash = "sha256:494a6dcae3b8bcf80848eea2ef64c0cc5cd307ffc263e17cdf42f3e5420808e6"}, ] +dataclasses-json = [ + {file = "dataclasses-json-0.5.6.tar.gz", hash = "sha256:1f60be3405dee30b86ffbf6a436db8ba5efaeeb676bfda358e516a97aa7dfce4"}, + {file = "dataclasses_json-0.5.6-py3-none-any.whl", hash = "sha256:1d7f3a284a49d350ddbabde0e7d0c5ffa34a144aaf1bcb5b9f2c87673ff0c76e"}, +] decorator = [ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, @@ -2878,6 +2936,14 @@ markupsafe = [ {file = "MarkupSafe-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:693ce3f9e70a6cf7d2fb9e6c9d8b204b6b39897a2c4a1aa65728d5ac97dcc1d8"}, {file = "MarkupSafe-2.0.1.tar.gz", hash = "sha256:594c67807fb16238b30c44bdf74f36c02cdf22d1c8cda91ef8a0ed8dabf5620a"}, ] +marshmallow = [ + {file = "marshmallow-3.14.1-py3-none-any.whl", hash = "sha256:04438610bc6dadbdddb22a4a55bcc7f6f8099e69580b2e67f5a681933a1f4400"}, + {file = "marshmallow-3.14.1.tar.gz", hash = "sha256:4c05c1684e0e97fe779c62b91878f173b937fe097b356cd82f793464f5bc6138"}, +] +marshmallow-enum = [ + {file = "marshmallow-enum-1.5.1.tar.gz", hash = "sha256:38e697e11f45a8e64b4a1e664000897c659b60aa57bfa18d44e226a9920b6e58"}, + {file = "marshmallow_enum-1.5.1-py2.py3-none-any.whl", hash = "sha256:57161ab3dbfde4f57adeb12090f39592e992b9c86d206d02f6bd03ebec60f072"}, +] mccabe = [ {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, @@ -3701,6 +3767,11 @@ typing-extensions = [ {file = "typing_extensions-4.1.1-py3-none-any.whl", hash = "sha256:21c85e0fe4b9a155d0799430b0ad741cdce7e359660ccbd8b530613e8df88ce2"}, {file = "typing_extensions-4.1.1.tar.gz", hash = "sha256:1a9462dcc3347a79b1f1c0271fbe79e844580bb598bafa1ed208b94da3cdcd42"}, ] +typing-inspect = [ + {file = "typing_inspect-0.7.1-py2-none-any.whl", hash = "sha256:b1f56c0783ef0f25fb064a01be6e5407e54cf4a4bf4f3ba3fe51e0bd6dcea9e5"}, + {file = "typing_inspect-0.7.1-py3-none-any.whl", hash = "sha256:3cd7d4563e997719a710a3bfe7ffb544c6b72069b6812a02e9b414a8fa3aaa6b"}, + {file = "typing_inspect-0.7.1.tar.gz", hash = "sha256:047d4097d9b17f46531bf6f014356111a1b6fb821a24fe7ac909853ca2a782aa"}, +] unidecode = [ {file = "Unidecode-1.3.3-py3-none-any.whl", hash = "sha256:a5a8a4b6fb033724ffba8502af2e65ca5bfc3dd53762dedaafe4b0134ad42e3c"}, {file = "Unidecode-1.3.3.tar.gz", hash = "sha256:8521f2853fd250891dc27d156a9d30e61c4e76319da963c4a1c27083a909ac30"}, diff --git a/pyproject.toml b/pyproject.toml index bcc5d758..0cb18889 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ rich = "^10.15.2" shellingham = "^1.4.0" scikit-learn = ">=0.24.0" cloudpickle = "^2.0.0" +dataclasses-json = "^0.5.6" [tool.poetry.dev-dependencies] poetry = "^1.1.5" @@ -72,6 +73,8 @@ smart_open = "^1.9.0" # For a temporary nucleus -> HMI integration [tool.poetry.scripts] nu = "cli.nu:nu" +scale-launch = 'nucleus.deploy.cli.bin:entry_point' + [tool.pytest.ini_options] markers = [