utilities for video streaming #503

Open: wants to merge 23 commits into master
80 changes: 63 additions & 17 deletions clarifai/client/model.py
@@ -424,14 +424,14 @@ def predict(self,
raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
) # TODO Use Chunker for inputs len > 128

self._override_model_version(inference_params, output_config)
model_info = self._get_model_info_for_inference(inference_params, output_config)
request = service_pb2.PostModelOutputsRequest(
user_app_id=self.user_app_id,
model_id=self.id,
version_id=self.model_version.id,
inputs=inputs,
runner_selector=runner_selector,
model=self.model_info)
model=model_info)

start_time = time.time()
backoff_iterator = BackoffIterator(10)
@@ -704,14 +704,14 @@ def generate(self,
raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
) # TODO Use Chunker for inputs len > 128

self._override_model_version(inference_params, output_config)
model_info = self._get_model_info_for_inference(inference_params, output_config)
request = service_pb2.PostModelOutputsRequest(
user_app_id=self.user_app_id,
model_id=self.id,
version_id=self.model_version.id,
inputs=inputs,
runner_selector=runner_selector,
model=self.model_info)
model=model_info)

start_time = time.time()
backoff_iterator = BackoffIterator(10)
@@ -922,15 +922,16 @@ def generate_by_url(self,
inference_params=inference_params,
output_config=output_config)

def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector):
def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector,
model_info: resources_pb2.Model):
for inputs in input_iterator:
yield service_pb2.PostModelOutputsRequest(
user_app_id=self.user_app_id,
model_id=self.id,
version_id=self.model_version.id,
inputs=inputs,
runner_selector=runner_selector,
model=self.model_info)
model=model_info)

def stream(self,
inputs: Iterator[List[Input]],
@@ -954,8 +955,8 @@ def stream(self,
# if not isinstance(inputs, Iterator[List[Input]]):
# raise UserError('Invalid inputs, inputs must be an iterator of list of Input objects.')

self._override_model_version(inference_params, output_config)
request = self._req_iterator(inputs, runner_selector)
model_info = self._get_model_info_for_inference(inference_params, output_config)
request = self._req_iterator(inputs, runner_selector, model_info)

start_time = time.time()
backoff_iterator = BackoffIterator(10)
@@ -1168,8 +1169,54 @@ def input_generator():
inference_params=inference_params,
output_config=output_config)

def _override_model_version(self, inference_params: Dict = {}, output_config: Dict = {}) -> None:
"""Overrides the model version.
def stream_by_video_file(self,
filepath: str,
input_type: str = 'video',
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
user_id: str = None,
inference_params: Dict = {},
output_config: Dict = {}):
"""
Stream the model output based on the given video file.

Converts the video file to a streamable format, streams as bytes to the model,
and streams back the model outputs.

Args:
filepath (str): The path of the video file to stream.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
user_id (str): The user ID to use for the model.
inference_params (dict): The inference params to override.
output_config (dict): The output config to override.
"""

if not os.path.isfile(filepath):
raise UserError('Invalid filepath.')

# TODO check if the file is streamable already

# Convert the video file to a streamable format
# TODO this conversion can offset the start time by a little bit; we should account for this
# by getting the original start time via ffprobe and either sending that to the model so it can adjust
# with the ts of the first frame (too fragile to do all of this adjustment in the client input stream)
# or by adjusting the timestamps in the output stream
from clarifai.runners.utils import video_utils
stream = video_utils.recontain_as_streamable(filepath)

# TODO accumulate reads to fill the chunk size
chunk_size = 1024 * 1024 # 1 MB
chunk_iterator = iter(lambda: stream.read(chunk_size), b'')

return self.stream_by_bytes(chunk_iterator, input_type, compute_cluster_id, nodepool_id,
deployment_id, user_id, inference_params, output_config)
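
For context, a minimal usage sketch of the new method. The model URL, PAT, and file path below are placeholder assumptions, and the response handling mirrors whatever `stream_by_bytes` yields; this is an illustration, not part of the diff:

```python
from clarifai.client.model import Model

# Placeholder URL and PAT; any video model that accepts streamed bytes works here.
model = Model("https://clarifai.com/<user>/<app>/models/<video-model>", pat="<PAT>")

# The file is re-containered into a streamable format, sent as 1 MB chunks,
# and model outputs are streamed back as they are produced.
for response in model.stream_by_video_file("/path/to/video.mp4", input_type='video'):
  for output in response.outputs:
    print(output.data)
```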

def _get_model_info_for_inference(self, inference_params: Dict = {},
output_config: Dict = {}) -> resources_pb2.Model:
"""Returns a copy of the model_info with the inference params and output config applied.

Args:
inference_params (dict): The inference params to override.
@@ -1179,13 +1226,12 @@ def _override_model_version(self, inference_params: Dict = {}, output_config: Dict = {}) -> None:
select_concepts (list[Concept]): The concepts to select.
sample_ms (int): The number of milliseconds to sample.
"""
params = Struct()
if inference_params is not None:
params.update(inference_params)

self.model_info.model_version.output_info.CopyFrom(
resources_pb2.OutputInfo(
output_config=resources_pb2.OutputConfig(**output_config), params=params))
model_info = resources_pb2.Model()
model_info.CopyFrom(self.model_info)
if inference_params is not None:
model_info.model_version.output_info.params.update(inference_params)
model_info.model_version.output_info.output_config.CopyFrom(
resources_pb2.OutputConfig(**output_config))
return model_info
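
Worth noting for review: returning a fresh copy instead of mutating `self.model_info` in place (as `_override_model_version` did) means per-request overrides no longer leak between calls. A hypothetical illustration (the parameter name is made up):

```python
# Before: both calls mutated the shared self.model_info, so the second call
# could silently inherit the first call's override. Now each request builds
# its own Model proto.
model.predict(inputs, inference_params={'temperature': 0.2})
model.predict(inputs)  # no stale 'temperature' carried over from the previous call
```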

def _list_concepts(self) -> List[str]:
"""Lists all the concepts for the model type.
15 changes: 12 additions & 3 deletions clarifai/runners/models/base_typed_model.py
@@ -6,6 +6,9 @@
from clarifai_grpc.grpc.api.service_pb2 import PostModelOutputsRequest
from google.protobuf import json_format

from clarifai.runners.utils.stream_utils import readahead
from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded

from ..utils.data_handler import InputDataHandler, OutputDataHandler
from .model_class import ModelClass

@@ -46,12 +49,16 @@ def convert_output_to_proto(self, outputs: list):

def predict_wrapper(
self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
if self.download_request_urls:
ensure_urls_downloaded(request)
list_dict_input, inference_params = self.parse_input_request(request)
outputs = self.predict(list_dict_input, inference_parameters=inference_params)
return self.convert_output_to_proto(outputs)

def generate_wrapper(
self, request: PostModelOutputsRequest) -> Iterator[service_pb2.MultiOutputResponse]:
if self.download_request_urls:
ensure_urls_downloaded(request)
list_dict_input, inference_params = self.parse_input_request(request)
outputs = self.generate(list_dict_input, inference_parameters=inference_params)
for output in outputs:
@@ -64,11 +71,13 @@ def _preprocess_stream(
input_data, _ = self.parse_input_request(req)
yield input_data

def stream_wrapper(self, request: Iterator[PostModelOutputsRequest]
def stream_wrapper(self, request_iterator: Iterator[PostModelOutputsRequest]
) -> Iterator[service_pb2.MultiOutputResponse]:
first_request = next(request)
if self.download_request_urls:
request_iterator = readahead(map(ensure_urls_downloaded, request_iterator))
first_request = next(request_iterator)
_, inference_params = self.parse_input_request(first_request)
request_iterator = itertools.chain([first_request], request)
request_iterator = itertools.chain([first_request], request_iterator)
outputs = self.stream(self._preprocess_stream(request_iterator), inference_params)
for output in outputs:
yield self.convert_output_to_proto(output)
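The `readahead` helper imported from `clarifai.runners.utils.stream_utils` is not shown in this diff. Presumably it prefetches items from the URL-downloading iterator on a background thread so downloads overlap with model execution, and the `map(ensure_urls_downloaded, ...)` call implies `ensure_urls_downloaded` now returns the request it was given. A minimal sketch of such a helper, under those assumptions:

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')

def readahead(iterable: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
  """Consume `iterable` on a background thread, keeping up to
  `buffer_size` items ready before the consumer asks for them."""
  q = queue.Queue(maxsize=buffer_size)
  done = object()  # sentinel marking the end of the stream

  def producer():
    try:
      for item in iterable:
        q.put(item)
    finally:
      q.put(done)

  threading.Thread(target=producer, daemon=True).start()
  while True:
    item = q.get()
    if item is done:
      return
    yield item
```

A production version would also propagate producer exceptions to the consumer; this sketch omits that for brevity.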
22 changes: 20 additions & 2 deletions clarifai/runners/models/model_class.py
@@ -3,23 +3,41 @@

from clarifai_grpc.grpc.api import service_pb2

from clarifai.runners.utils.stream_utils import readahead
from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded


class ModelClass(ABC):

download_request_urls = True

def predict_wrapper(
self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
"""This method is used for input/output proto data conversion"""
# Download any urls that are not already bytes.
if self.download_request_urls:
ensure_urls_downloaded(request)

return self.predict(request)

def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest
) -> Iterator[service_pb2.MultiOutputResponse]:
"""This method is used for input/output proto data conversion and yield outcome"""
# Download any urls that are not already bytes.
if self.download_request_urls:
ensure_urls_downloaded(request)

return self.generate(request)

def stream_wrapper(self, request: service_pb2.PostModelOutputsRequest
def stream_wrapper(self, request_stream: Iterator[service_pb2.PostModelOutputsRequest]
) -> Iterator[service_pb2.MultiOutputResponse]:
"""This method is used for input/output proto data conversion and yield outcome"""
return self.stream(request)

# Download any urls that are not already bytes.
if self.download_request_urls:
request_stream = readahead(map(ensure_urls_downloaded, request_stream))

return self.stream(request_stream)

@abstractmethod
def load_model(self):
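The new `download_request_urls` class attribute lets a model opt out of eager URL fetching, e.g. when it wants to consume input URLs directly. A hypothetical subclass (the class name and method bodies are illustrative):

```python
from clarifai.runners.models.model_class import ModelClass

class UrlPassthroughModel(ModelClass):
  # Skip the eager download step in the wrappers; this model reads the
  # input URL fields directly instead of expecting raw bytes.
  download_request_urls = False

  def load_model(self):
    pass

  def predict(self, request):
    ...  # operate on request.inputs[i].data url fields directly
```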
4 changes: 0 additions & 4 deletions clarifai/runners/models/model_runner.py
@@ -5,7 +5,6 @@

from clarifai_protocol import BaseRunner
from clarifai_protocol.utils.health import HealthProbeRequestHandler
from ..utils.url_fetcher import ensure_urls_downloaded

from .model_class import ModelClass

@@ -79,7 +78,6 @@ def runner_item_predict(self,
if not runner_item.HasField('post_model_outputs_request'):
raise Exception("Unexpected work item type: {}".format(runner_item))
request = runner_item.post_model_outputs_request
ensure_urls_downloaded(request)

resp = self.model.predict_wrapper(request)
successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
@@ -109,7 +107,6 @@ def runner_item_generate(
if not runner_item.HasField('post_model_outputs_request'):
raise Exception("Unexpected work item type: {}".format(runner_item))
request = runner_item.post_model_outputs_request
ensure_urls_downloaded(request)

for resp in self.model.generate_wrapper(request):
successes = []
@@ -169,5 +166,4 @@ def pmo_iterator(runner_item_iterator):
for runner_item in runner_item_iterator:
if not runner_item.HasField('post_model_outputs_request'):
raise Exception("Unexpected work item type: {}".format(runner_item))
ensure_urls_downloaded(runner_item.post_model_outputs_request)
yield runner_item.post_model_outputs_request
18 changes: 1 addition & 17 deletions clarifai/runners/models/model_servicer.py
@@ -1,11 +1,8 @@
from itertools import tee
from typing import Iterator

from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2

from ..utils.url_fetcher import ensure_urls_downloaded


class ModelServicer(service_pb2_grpc.V2Servicer):
"""
@@ -27,9 +24,6 @@ def PostModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
returns an output.
"""

# Download any urls that are not already bytes.
ensure_urls_downloaded(request)

try:
return self.model.predict_wrapper(request)
except Exception as e:
@@ -46,9 +40,6 @@ def GenerateModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
This is the method that will be called when the servicer is run. It takes in an input and
returns an output.
"""
# Download any urls that are not already bytes.
ensure_urls_downloaded(request)

try:
return self.model.generate_wrapper(request)
except Exception as e:
@@ -66,15 +57,8 @@ def StreamModelOutputs(self,
This is the method that will be called when the servicer is run. It takes in an input and
returns an output.
"""
# Duplicate the iterator
request, request_copy = tee(request)

# Download any urls that are not already bytes.
for req in request:
ensure_urls_downloaded(req)

try:
return self.model.stream_wrapper(request_copy)
return self.model.stream_wrapper(request)
except Exception as e:
yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
code=status_code_pb2.MODEL_PREDICTION_FAILED,