Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DEVX-896]: Update Model Predict CLI #500

Merged
merged 2 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 10 additions & 36 deletions clarifai/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,6 @@ def run_locally(model_path, port, mode, keep_env, keep_image):
@click.option('--file_path', required=False, help='File path of file for the model to predict')
@click.option('--url', required=False, help='URL to the file for the model to predict')
@click.option('--bytes', required=False, help='Bytes to the file for the model to predict')
@click.option(
'--input_id', required=False, help='Existing input id in the app for the model to predict')
@click.option('--input_type', required=False, help='Type of input')
@click.option(
'-cc_id',
Expand All @@ -188,36 +186,28 @@ def run_locally(model_path, port, mode, keep_env, keep_image):
'--inference_params', required=False, default='{}', help='Inference parameters to override')
@click.option('--output_config', required=False, default='{}', help='Output config to override')
@click.pass_context
def predict(ctx, config, model_id, user_id, app_id, model_url, file_path, url, bytes, input_id,
input_type, compute_cluster_id, nodepool_id, deployment_id, inference_params,
output_config):
def predict(ctx, config, model_id, user_id, app_id, model_url, file_path, url, bytes, input_type,
compute_cluster_id, nodepool_id, deployment_id, inference_params, output_config):
"""Predict using the given model"""
import json

from clarifai.client.deployment import Deployment
from clarifai.client.input import Input
from clarifai.client.model import Model
from clarifai.client.nodepool import Nodepool
from clarifai.utils.cli import from_yaml
if config:
config = from_yaml(config)
model_id, user_id, app_id, model_url, file_path, url, bytes, input_id, input_type, compute_cluster_id, nodepool_id, deployment_id, inference_params, output_config = (
model_id, user_id, app_id, model_url, file_path, url, bytes, input_type, compute_cluster_id, nodepool_id, deployment_id, inference_params, output_config = (
config.get(k, v)
for k, v in [('model_id', model_id), ('user_id', user_id), ('app_id', app_id), (
'model_url', model_url), ('file_path', file_path), ('url', url), ('bytes', bytes), (
'input_id',
input_id), ('input_type',
input_type), ('compute_cluster_id',
compute_cluster_id), ('nodepool_id', nodepool_id), (
'deployment_id',
deployment_id), ('inference_params',
inference_params), ('output_config',
output_config)])
'input_type', input_type), ('compute_cluster_id', compute_cluster_id), (
'nodepool_id',
nodepool_id), ('deployment_id',
deployment_id), ('inference_params',
inference_params), ('output_config',
output_config)])
if sum([opt[1] for opt in [(model_id, 1), (user_id, 1), (app_id, 1), (model_url, 3)]
if opt[0]]) != 3:
raise ValueError("Either --model_id & --user_id & --app_id or --model_url must be provided.")
if sum([1 for opt in [file_path, url, bytes, input_id] if opt]) != 1:
raise ValueError("Exactly one of --file_path, --url, --bytes or --input_id must be provided.")
if compute_cluster_id or nodepool_id or deployment_id:
if sum([
opt[1] for opt in [(compute_cluster_id, 0.5), (nodepool_id, 0.5), (deployment_id, 1)]
Expand Down Expand Up @@ -267,21 +257,5 @@ def predict(ctx, config, model_id, user_id, app_id, model_url, file_path, url, b
nodepool_id=nodepool_id,
deployment_id=deployment_id,
inference_params=inference_params,
output_config=output_config)
elif input_id:
inputs = [Input.get_input(input_id)]
runner_selector = None
if deployment_id:
runner_selector = Deployment.get_runner_selector(
user_id=ctx.obj['user_id'], deployment_id=deployment_id)
elif compute_cluster_id and nodepool_id:
runner_selector = Nodepool.get_runner_selector(
user_id=ctx.obj['user_id'],
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id)
model_prediction = model.predict(
inputs=inputs,
runner_selector=runner_selector,
inference_params=inference_params,
output_config=output_config)
output_config=output_config) ## TO DO: Add support for input_id
click.echo(model_prediction)
18 changes: 18 additions & 0 deletions clarifai/client/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,24 @@ def get_mask_proto(input_id: str,

return input_mask_proto

def get_input(self, input_id: str) -> Input:
    """Get Input object of input with input_id provided from the app.

    Fetches a single input from the app via the GetInput gRPC endpoint,
    scoped to this client's user/app context (``self.user_app_id``).

    Args:
        input_id (str): The ID of the input to fetch from the app.

    Returns:
        Input: The input proto for the specified input ID.

    Example:
        >>> from clarifai.client.input import Inputs
        >>> input_obj = Inputs(user_id = 'user_id', app_id = 'demo_app')
        >>> input_obj.get_input(input_id='demo')
    """
    request = service_pb2.GetInputRequest(user_app_id=self.user_app_id, input_id=input_id)
    response = self._grpc_request(self.STUB.GetInput, request)
    # NOTE(review): response.status is not inspected here — presumably a
    # missing input yields an empty/default proto rather than raising; confirm
    # whether callers expect an explicit error on lookup failure.
    return response.input

def upload_from_url(self,
input_id: str,
image_url: str = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ deployment:
disable_packing: false
worker:
model:
id: "apparel-clusterering"
id: "python_string_cat"
model_version:
id: "cc911f6b0ed748efb89e3d1359c146c4"
id: "b7038e059a0c4ddca29c22aec561824d"
user_id: "clarifai"
app_id: "main"
app_id: "Test-Model-Upload"
scheduling_choice: 4
nodepools:
- id: "test-nodepool-6"
Expand Down
1 change: 0 additions & 1 deletion tests/runners/hf_mbart_model/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ inference_compute_info:
checkpoints:
type: "huggingface"
repo_id: "sshleifer/tiny-mbart"
hf_token: ""
6 changes: 5 additions & 1 deletion tests/test_data_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_upload_text_bytes(self, caplog):
self.input_object.upload_from_bytes(input_id='input_13', text_bytes=text_bytes)
assert "SUCCESS" in caplog.text

def test_get_multimodal_input(self, caplog):
def test_get_multimodal_input(self):
    """A multimodal input should carry both the requested ID and the raw text."""
    expected_text = 'This is a multimodal test text'
    result = self.input_object.get_multimodal_input(
        input_id='input_14', raw_text=expected_text, image_url=IMAGE_URL)
    assert result.id == 'input_14' and result.data.text.raw == expected_text
Expand All @@ -146,6 +146,10 @@ def test_get_text_inputs_from_folder(self):
text_inputs = self.input_object.get_text_inputs_from_folder(TEXTS_FOLDER_PATH)
assert len(text_inputs) == 3

def test_get_input_from_app(self):
    """Fetching an existing input by ID should round-trip that same ID."""
    fetched = self.input_object.get_input(input_id='input_1')
    assert fetched.id == 'input_1'

def test_get_mask_proto(self):
polygon_points = [[.2, .2], [.8, .2], [.8, .8], [.2, .8]]
annotation = self.input_object.get_mask_proto(
Expand Down
Loading