Skip to content

Commit

Permalink
Remove is_trusted argument from APIs (#1405)
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten authored Feb 20, 2025
1 parent 43f61aa commit d5f1794
Show file tree
Hide file tree
Showing 18 changed files with 108 additions and 141 deletions.
6 changes: 1 addition & 5 deletions docs/examples/09-private-huggingface.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,8 @@ system_packages: []
```
# Deploying the model

An important note for deploying models with secrets is that
you must use the `--trusted` flag to give the model access to
secrets stored on the remote secrets manager.

```bash
$ truss push --trusted
$ truss push
```

After the model finishes deploying, you can invoke it with:
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/performance/cached-weights.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ You'll need a [Baseten API key](https://app.baseten.co/settings/account/api_keys
We have successfully packaged Llama 2 as a Truss. Let's deploy!

```sh
truss push --trusted
truss push
```

### Step 5: Invoke the model
Expand Down
4 changes: 1 addition & 3 deletions docs/examples/private-model.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,8 @@ You'll need a [Baseten API key](https://app.baseten.co/settings/account/api_keys

We have successfully packaged a gated model as a Truss. Let's deploy!

Use `--trusted` with `truss push` to give the model server access to secrets stored on the remote host.

```sh
truss push --trusted
truss push
```

Wait for the model to finish deployment before invoking.
Expand Down
4 changes: 2 additions & 2 deletions truss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
__version__ = get_version(__name__, Path(__file__).parent.parent)


def version():
return __version__
def version() -> str:
return __version__ or ""


from truss.api import login, push, whoami
Expand Down
12 changes: 9 additions & 3 deletions truss/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import TYPE_CHECKING, Optional, Type, cast

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,7 +58,7 @@ def push(
publish: bool = False,
promote: bool = False,
preserve_previous_production_deployment: bool = False,
trusted: bool = False,
trusted: Optional[bool] = None,
deployment_name: Optional[str] = None,
environment: Optional[str] = None,
progress_bar: Optional[Type["progress.Progress"]] = None,
Expand All @@ -76,7 +77,7 @@ def push(
preserve_previous_production_deployment: Preserve the previous production deployment’s autoscaling
setting. When not specified, the previous production deployment will be updated to allow it to
scale to zero. Can only be used in combination with `promote` option.
trusted: Give Truss access to secrets on remote host.
trusted: [DEPRECATED]
deployment_name: Name of the deployment created by the push. Can only be
used in combination with `publish` or `promote`. Deployment name must
only contain alphanumeric, ’.’, ’-’ or ’_’ characters.
Expand All @@ -86,6 +87,12 @@ def push(
Returns:
The newly created ModelDeployment.
"""
if trusted is not None:
warnings.warn(
"`trusted` is deprecated and will be ignored, all models are "
"trusted by default now.",
DeprecationWarning,
)

if not remote:
available_remotes = RemoteFactory.get_available_config_names()
Expand All @@ -112,7 +119,6 @@ def push(
tr,
model_name=model_name,
publish=publish,
trusted=trusted,
promote=promote,
preserve_previous_prod_deployment=preserve_previous_production_deployment,
deployment_name=deployment_name,
Expand Down
5 changes: 2 additions & 3 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ def push(
remote: str,
model_name: str,
publish: bool = False,
trusted: bool = False,
trusted: Optional[bool] = None,
disable_truss_download: bool = False,
promote: bool = False,
preserve_previous_production_deployment: bool = False,
Expand Down Expand Up @@ -1167,7 +1167,7 @@ def push(
tr.spec.config.write_to_yaml_file(tr.spec.config_path, verbose=False)

# Log a warning if using --trusted.
if trusted:
if trusted is not None:
trusted_deprecation_notice = (
"[DEPRECATED] `--trusted` option is deprecated and no longer needed"
)
Expand Down Expand Up @@ -1208,7 +1208,6 @@ def push(
tr,
model_name=model_name,
publish=publish,
trusted=True,
promote=promote,
preserve_previous_prod_deployment=preserve_previous_production_deployment,
deployment_name=deployment_name,
Expand Down
51 changes: 23 additions & 28 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def _oracle_data_to_graphql_mutation(oracle: b10_types.OracleData) -> str:
f'model_name: "{oracle.model_name}"',
f's3_key: "{oracle.s3_key}"',
f'encoded_config_str: "{oracle.encoded_config_str}"',
f"is_trusted: {str(oracle.is_trusted).lower()}",
]

if oracle.semver_bump:
Expand Down Expand Up @@ -132,7 +131,6 @@ def create_model_from_truss(
config: str,
semver_bump: str,
client_version: str,
is_trusted: bool,
allow_truss_download: bool = True,
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
Expand All @@ -145,7 +143,6 @@ def create_model_from_truss(
config: "{config}"
semver_bump: "{semver_bump}"
client_version: "{client_version}"
is_trusted: {"true" if is_trusted else "false"}
allow_truss_download: {"true" if allow_truss_download else "false"}
{f'version_name: "{deployment_name}"' if deployment_name else ""}
{f"model_origin: {origin.value}" if origin else ""}
Expand All @@ -172,7 +169,6 @@ def create_model_version_from_truss(
config: str,
semver_bump: str,
client_version: str,
is_trusted: bool,
preserve_previous_prod_deployment: bool = False,
deployment_name: Optional[str] = None,
environment: Optional[str] = None,
Expand All @@ -185,7 +181,6 @@ def create_model_version_from_truss(
config: "{config}"
semver_bump: "{semver_bump}"
client_version: "{client_version}"
is_trusted: {"true" if is_trusted else "false"}
scale_down_old_production: {"false" if preserve_previous_prod_deployment else "true"}
{f'name: "{deployment_name}"' if deployment_name else ""}
{f'environment_name: "{environment}"' if environment else ""}
Expand All @@ -209,7 +204,6 @@ def create_development_model_from_truss(
s3_key,
config,
client_version,
is_trusted=False,
allow_truss_download=True,
origin: Optional[b10_types.ModelOrigin] = None,
):
Expand All @@ -219,7 +213,6 @@ def create_development_model_from_truss(
s3_key: "{s3_key}"
config: "{config}"
client_version: "{client_version}"
is_trusted: {"true" if is_trusted else "false"}
allow_truss_download: {"true" if allow_truss_download else "false"}
{f"model_origin: {origin.value}" if origin else ""}
) {{
Expand Down Expand Up @@ -470,17 +463,18 @@ def patch_draft_truss_two_step(self, model_name, patch_request):
patch = base64_encoded_json_str(patch_request.to_dict())
query_string = f"""
mutation {{
stage_patch_for_draft_truss(name: "{model_name}",
client_version: "{truss.version()}",
patch: "{patch}",
) {{
id,
name,
version_id
succeeded
needs_full_deploy
error
}}
stage_patch_for_draft_truss(
name: "{model_name}"
client_version: "{truss.version()}",
patch: "{patch}"
) {{
id
name
version_id
succeeded
needs_full_deploy
error
}}
}}
"""
resp = self._post_graphql_query(query_string)
Expand All @@ -495,16 +489,17 @@ def patch_draft_truss_two_step(self, model_name, patch_request):
def sync_draft_truss(self, model_name):
query_string = f"""
mutation {{
sync_draft_truss(name: "{model_name}",
client_version: "{truss.version()}",
) {{
id,
name,
version_id
succeeded
needs_full_deploy
error
}}
sync_draft_truss(
name: "{model_name}"
client_version: "{truss.version()}",
) {{
id
name
version_id
succeeded
needs_full_deploy
error
}}
}}
"""
resp = self._post_graphql_query(query_string)
Expand Down
13 changes: 4 additions & 9 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
DEPLOYING_STATUSES = ["BUILDING", "DEPLOYING", "LOADING_MODEL", "UPDATING"]
ACTIVE_STATUS = "ACTIVE"
NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING = (
"Model hasn't been deployed yet. No evironments exist."
"Model hasn't been deployed yet. No environments exist."
)


Expand Down Expand Up @@ -330,7 +330,6 @@ def create_truss_service(
s3_key: str,
config: str,
semver_bump: str = "MINOR",
is_trusted: bool = False,
preserve_previous_prod_deployment: bool = False,
allow_truss_download: bool = False,
is_draft: Optional[bool] = False,
Expand All @@ -348,7 +347,6 @@ def create_truss_service(
s3_key: S3 key of the uploaded TrussHandle.
config: Base64 encoded JSON string of the Truss config.
semver_bump: Semver bump type, defaults to "MINOR".
is_trusted: Whether the model is trusted, defaults to False.
promote: Whether to promote the model after deploy, defaults to False.
preserve_previous_prod_deployment: Whether to preserve the previous production
deployment's autoscaling setting instead of allowing it to scale to zero.
Expand All @@ -358,14 +356,12 @@ def create_truss_service(
Returns:
A Model Version handle.
"""

if is_draft:
model_version_json = api.create_development_model_from_truss(
model_name,
s3_key,
config,
truss.version(),
is_trusted=is_trusted,
client_version=truss.version(),
allow_truss_download=allow_truss_download,
origin=origin,
)
Expand All @@ -386,7 +382,6 @@ def create_truss_service(
config=config,
semver_bump=semver_bump,
client_version=truss.version(),
is_trusted=is_trusted,
allow_truss_download=allow_truss_download,
deployment_name=deployment_name,
origin=origin,
Expand All @@ -405,7 +400,6 @@ def create_truss_service(
config=config,
semver_bump=semver_bump,
client_version=truss.version(),
is_trusted=is_trusted,
preserve_previous_prod_deployment=preserve_previous_prod_deployment,
deployment_name=deployment_name,
environment=environment,
Expand All @@ -416,7 +410,8 @@ def create_truss_service(
== BasetenApi.GraphQLErrorCodes.RESOURCE_NOT_FOUND.value
):
raise ValueError(
f'Environment "{environment}" does not exist. You can create environments in the Baseten UI.'
f"Environment `{environment}` does not exist. You can create "
"environments in the Baseten UI."
) from e
raise e

Expand Down
1 change: 0 additions & 1 deletion truss/remote/baseten/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Config:
s3_key: str
encoded_config_str: str
semver_bump: Optional[str] = "MINOR"
is_trusted: bool
version_name: Optional[str] = None


Expand Down
8 changes: 0 additions & 8 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def _prepare_push(
truss_handle: TrussHandle,
model_name: str,
publish: bool = True,
trusted: bool = False,
promote: bool = False,
preserve_previous_prod_deployment: bool = False,
disable_truss_download: bool = False,
Expand Down Expand Up @@ -182,7 +181,6 @@ def _prepare_push(
encoded_config_str=encoded_config_str,
is_draft=not publish,
model_id=model_id,
is_trusted=trusted,
preserve_previous_prod_deployment=preserve_previous_prod_deployment,
version_name=deployment_name,
origin=origin,
Expand All @@ -195,7 +193,6 @@ def push( # type: ignore
truss_handle: TrussHandle,
model_name: str,
publish: bool = True,
trusted: bool = False,
promote: bool = False,
preserve_previous_prod_deployment: bool = False,
disable_truss_download: bool = False,
Expand All @@ -208,7 +205,6 @@ def push( # type: ignore
truss_handle=truss_handle,
model_name=model_name,
publish=publish,
trusted=trusted,
promote=promote,
preserve_previous_prod_deployment=preserve_previous_prod_deployment,
disable_truss_download=disable_truss_download,
Expand All @@ -229,7 +225,6 @@ def push( # type: ignore
config=push_data.encoded_config_str,
is_draft=push_data.is_draft,
model_id=push_data.model_id,
is_trusted=push_data.is_trusted,
preserve_previous_prod_deployment=push_data.preserve_previous_prod_deployment,
allow_truss_download=push_data.allow_truss_download,
deployment_name=push_data.version_name,
Expand Down Expand Up @@ -271,8 +266,6 @@ def push_chain_atomic(
push_data = self._prepare_push(
truss_handle=truss_handle,
model_name=model_name,
# Models must be trusted to use the API KEY secret.
trusted=True,
publish=publish,
origin=custom_types.ModelOrigin.CHAINS,
progress_bar=progress_bar,
Expand All @@ -283,7 +276,6 @@ def push_chain_atomic(
encoded_config_str=push_data.encoded_config_str,
is_draft=push_data.is_draft,
model_id=push_data.model_id,
is_trusted=push_data.is_trusted,
version_name=push_data.version_name,
)
chainlet_data.append(
Expand Down
Loading

0 comments on commit d5f1794

Please sign in to comment.