diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bee8a64 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..3f9b90b --- /dev/null +++ b/__init__.py @@ -0,0 +1 @@ +from .modal_inference import * diff --git a/image.py b/image.py new file mode 100644 index 0000000..5ccc077 --- /dev/null +++ b/image.py @@ -0,0 +1,108 @@ +from modal import Volume, Image, Mount +import os +from pathlib import Path +from ai_video_editor.stub import stub, REPO_HOME, LOCAL_CERT_PATH, CERT_PATH, EXTRA_ENV + +LOCAL_VOLUME_DIR = "/video_llava_volume" +HF_DATASETS_CACHE = str(Path(LOCAL_VOLUME_DIR) / "hf_datasets_cache") +MODEL_CACHE = Path(LOCAL_VOLUME_DIR, "models") + +LOCAL_VOLUME_NAME = "video-llava-volume" +local_volume = Volume.from_name(LOCAL_VOLUME_NAME, create_if_missing=True) +local_volumes = { + LOCAL_VOLUME_DIR: local_volume, +} +local_mounts = [ + Mount.from_local_dir("./ai_video_editor/video_llava", remote_path=REPO_HOME), +] + + +def remove_old_files(): + import shutil + shutil.rmtree('/volume/models', ignore_errors=True) + +image = ( + Image.from_registry( + "nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04", add_python="3.11" + ) + .apt_install( + "git", + "curl", + "libgl1-mesa-glx", + "libglib2.0-0", + "libsm6", + "libxrender1", + "libxext6", + "ffmpeg", + "clang", + "libopenmpi-dev", + gpu="any", + ) + + .pip_install( + # "torch==2.1.2", + # "transformers==4.37.2", + # "bitsandbytes==0.42.0", + "torch==2.0.1", "torchvision==0.15.2", + "transformers==4.31.0", "tokenizers>=0.12.1,<0.14", "sentencepiece==0.1.99", "shortuuid", + "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0", + "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2", + "requests", "httpx==0.24.0", "uvicorn", "fastapi", + "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", + "tensorboardX==2.6.2.2", "gradio==3.37.0", "gradio_client==0.7.0", + "deepspeed==0.9.5", "ninja", "wandb", + "wheel", + gpu="any", + ) + .run_commands( + "python -m bitsandbytes", + gpu="any" + ) + .run_commands("pip install flash-attn --no-build-isolation", gpu="any") + .env({"PYTHONPATH": REPO_HOME, "HF_DATASETS_CACHE": HF_DATASETS_CACHE}) + .pip_install( + "decord", + "opencv-python", + "git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d", + gpu="any", + ) + .pip_install( + "aiofiles", + "aioboto3", + ) + .run_function(remove_old_files) + .copy_local_file(LOCAL_CERT_PATH, CERT_PATH) + .pip_install("boto3", "aioboto3") + .env(EXTRA_ENV) + .pip_install("diskcache") +) +# TODO bitsandbytes seems to not be working with gpu + +def function_dec(**extras): + return stub.function( + image=image, + timeout=80000, + # checkpointing doesn't work because it restricts internet access + #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. + _allow_background_volume_commits=True, + container_idle_timeout=120, + volumes=local_volumes, + mounts=local_mounts, + **extras, + ) + +def cls_dec(**extras): + return stub.cls( + image=image, + timeout=80000, + # checkpointing doesn't work because it restricts internet access + #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. 
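+        # note: keep this class's containers warm longer than function_dec's 120s, so a warm container (and the model loaded in its @enter hook) can be reused across calls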
+ container_idle_timeout=1200, + # TODO maybe turn on + allow_concurrent_inputs=4, + retries=3, + _allow_background_volume_commits=True, + volumes=local_volumes, + mounts=local_mounts, + **extras, + ) diff --git a/modal_inference.py b/modal_inference.py new file mode 100644 index 0000000..14891c7 --- /dev/null +++ b/modal_inference.py @@ -0,0 +1,180 @@ +import os +import shutil +import urllib + +from modal import asgi_app, method, enter, build +from ai_video_editor.utils.fs_utils import async_copy_from_s3 +from .image import LOCAL_VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, local_volume +from ai_video_editor.stub import stub, S3_VIDEO_PATH, VOLUME_DIR, volume as remote_volume +import diskcache as dc +from pathlib import Path +# for local testing +#S3_VIDEO_PATH= "s3_videos" +#MODEL_CACHE = "models" +#Path(VOLUME_DIR).mkdir(exist_ok=True, parents=True) +VIDEOS_DIR = Path(S3_VIDEO_PATH) / "videos" +IMAGES_DIR = Path(S3_VIDEO_PATH) / "images" + + + +@cls_dec(gpu="any") +class VideoLlavaModel: + @enter() + def load_model(self): + self.cache = dc.Cache('.cache') + local_volume.reload() + import torch + from videollava.serve.gradio_utils import Chat + self.conv_mode = "llava_v1" + model_path = 'LanguageBind/Video-LLaVA-7B' + device = 'cuda' + load_8bit = False + load_4bit = True + self.dtype = torch.float16 + self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=str(MODEL_CACHE)) + # self.handler.model.to(dtype=self.dtype) + + def copy_file_from_remote_volume(self, filepath): + in_volume_path = filepath.split('/', 2)[-1] + local_volume_path = Path(LOCAL_VOLUME_DIR) / in_volume_path + local_volume_path.parent.mkdir(parents=True, exist_ok=True) + if not local_volume_path.exists(): + shutil.copy(filepath, str(local_volume_path)) + + async def copy_file_from_s3(self, filepath): + bucket, in_bucket_path = filepath.replace('s3://','').split('/', 1) + await async_copy_from_s3(bucket, in_bucket_path, str(Path(VOLUME_DIR) / in_bucket_path)) + + async def copy_file_to_local(self, filepath): + if not filepath: + return + if filepath.startswith('s3://'): + await self.copy_file_from_s3(filepath) + else: + self.copy_file_from_remote_volume(filepath) + + @method() + async def generate(self, image1, video, textbox_in, use_existing_output=True): + inputs = (image1, video, textbox_in) + if inputs in self.cache and use_existing_output: + res = self.cache[inputs] + self.cache.close() + return res + remote_volume.reload() + local_volume.reload() + await self.copy_file_to_local(image1) + await self.copy_file_to_local(video) + + from videollava.conversation import conv_templates + from videollava.constants import DEFAULT_IMAGE_TOKEN + if not textbox_in: + raise ValueError("no prompt provided") + + image1 = image1 if image1 else "none" + video = video if video else "none" + + state_ = conv_templates[self.conv_mode].copy() + images_tensor = [] + + text_en_in = textbox_in.replace("picture", "image") + + image_processor = self.handler.image_processor + if os.path.exists(image1) and not os.path.exists(video): + tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + video_processor = self.handler.video_processor + if not os.path.exists(image1) and os.path.exists(video): + tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + 
images_tensor.append(tensor) + if os.path.exists(image1) and os.path.exists(video): + tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + + tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + + if os.path.exists(image1) and not os.path.exists(video): + text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in + elif not os.path.exists(image1) and os.path.exists(video): + text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + elif os.path.exists(image1) and os.path.exists(video): + text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN + else: + print("WARNING: No image or video supplied") + + text_en_out, _ = self.handler.generate(images_tensor, text_en_in, first_run=True, state=state_) + + text_en_out = text_en_out.split('#')[0] + textbox_out = text_en_out + + if not textbox_out: + raise ValueError("no text generated") + self.cache.set(inputs, textbox_out) + self.cache.close() + return textbox_out + + + +def fastapi_app(): + from fastapi import FastAPI, UploadFile, File + import aiofiles + + Path(MODEL_CACHE).mkdir(exist_ok=True, parents=True) + VIDEOS_DIR.mkdir(exist_ok=True, parents=True) + IMAGES_DIR.mkdir(exist_ok=True, parents=True) + + app = FastAPI() + model = VideoLlavaModel() + + @app.post("/upload") + async def upload( + file: UploadFile = File(...), + ): + local_volume.reload() + filename_decoded = urllib.parse.unquote(file.filename) + file_path = str(Path(LOCAL_VOLUME_DIR) / filename_decoded) + async with aiofiles.open(file_path, "wb") as buffer: + while content := await file.read(1024): # Read chunks of 1024 bytes + await buffer.write(content) + local_volume.commit() + return {"file_path": file_path} + + @app.post("/inference") + async def inference( + video_file_name: str = '', + video_file_path: str = '', + image_file_name: str = '', + image_file_path: str = '', + prompt: str = '', + ): + video_file_name = urllib.parse.unquote(video_file_name) + video_file_path = urllib.parse.unquote(video_file_path) + if video_file_path is None or video_file_path == '': + if video_file_name is None or video_file_name == '': + raise ValueError("one of video_file_path or video_file_name must be specified") + video_file_path = str(VIDEOS_DIR / video_file_name) + + image_file_name = urllib.parse.unquote(image_file_name) + image_file_path = urllib.parse.unquote(image_file_path) + if image_file_path is None or image_file_path == '': + if image_file_name is not None and image_file_name != '': + image_file_path = str(IMAGES_DIR / image_file_name) + + return model.generate.remote(image_file_path, video_file_path, prompt) + return app + + +@function_dec() +@asgi_app() +def fastapi_app_modal(): + return fastapi_app() + +# local testing: +# comment this out to deploy +# app = fastapi_app() +# conda activate videollava +# uvicorn modal_inference:app diff --git a/pyproject.toml b/pyproject.toml index 7c0b7f1..7f21463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2", "requests", "httpx==0.24.0", "uvicorn", "fastapi", "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", - "tensorboardX==2.6.2.2", 
"gradio==3.37.0", "gradio_client==0.7.0" + "tensorboardX==2.6.2.2", "gradio==3.37.0", "gradio_client==0.7.0", "modal" ] [project.optional-dependencies] diff --git a/videollava/serve/gradio_web_server.py b/videollava/serve/gradio_web_server.py index dfce85a..fc7d7de 100644 --- a/videollava/serve/gradio_web_server.py +++ b/videollava/serve/gradio_web_server.py @@ -1,249 +1,281 @@ -import shutil -import subprocess - -import torch -import gradio as gr -from fastapi import FastAPI -import os -from PIL import Image -import tempfile -from decord import VideoReader, cpu -from transformers import TextStreamer - -from videollava.constants import DEFAULT_IMAGE_TOKEN -from videollava.conversation import conv_templates, SeparatorStyle, Conversation -from videollava.serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css - - - -def save_image_to_local(image): - filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg') - image = Image.open(image) - image.save(filename) - # print(filename) - return filename - - -def save_video_to_local(video_path): - filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4') - shutil.copyfile(video_path, filename) - return filename - - -def generate(image1, video, textbox_in, first_run, state, state_, images_tensor): - flag = 1 - if not textbox_in: - if len(state_.messages) > 0: - textbox_in = state_.messages[-1][1] - state_.messages.pop(-1) - flag = 0 - else: - return "Please enter instruction" - - image1 = image1 if image1 else "none" - video = video if video else "none" - # assert not (os.path.exists(image1) and os.path.exists(video)) - - if type(state) is not Conversation: - state = conv_templates[conv_mode].copy() - state_ = conv_templates[conv_mode].copy() - images_tensor = [] - - first_run = False if len(state.messages) > 0 else True - - text_en_in = textbox_in.replace("picture", "image") - - # images_tensor = [[], []] - image_processor = handler.image_processor - if os.path.exists(image1) and not os.path.exists(video): - tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(handler.model.device, dtype=dtype) - images_tensor.append(tensor) - video_processor = handler.video_processor - if not os.path.exists(image1) and os.path.exists(video): - tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(handler.model.device, dtype=dtype) - images_tensor.append(tensor) - if os.path.exists(image1) and os.path.exists(video): - tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(handler.model.device, dtype=dtype) - images_tensor.append(tensor) - - tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(handler.model.device, dtype=dtype) - images_tensor.append(tensor) - - if os.path.exists(image1) and not os.path.exists(video): - text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in - if not os.path.exists(image1) and os.path.exists(video): - text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in - if os.path.exists(image1) and os.path.exists(video): - text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN - # print(text_en_in) - text_en_out, state_ = 
handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_) - state_.messages[-1] = (state_.roles[1], text_en_out) - - text_en_out = text_en_out.split('#')[0] - textbox_out = text_en_out - - show_images = "" - if os.path.exists(image1): - filename = save_image_to_local(image1) - show_images += f'' - if os.path.exists(video): - filename = save_video_to_local(video) - show_images += f'' - - if flag: - state.append_message(state.roles[0], textbox_in + "\n" + show_images) - state.append_message(state.roles[1], textbox_out) - - return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True)) - - -def regenerate(state, state_): - state.messages.pop(-1) - state_.messages.pop(-1) - if len(state.messages) > 0: - return state, state_, state.to_gradio_chatbot(), False - return (state, state_, state.to_gradio_chatbot(), True) - - -def clear_history(state, state_): - state = conv_templates[conv_mode].copy() - state_ = conv_templates[conv_mode].copy() - return (gr.update(value=None, interactive=True), - gr.update(value=None, interactive=True), \ - gr.update(value=None, interactive=True), \ - True, state, state_, state.to_gradio_chatbot(), []) - - -conv_mode = "llava_v1" -model_path = 'LanguageBind/Video-LLaVA-7B' -cache_dir = './cache_dir' -device = 'cuda' -load_8bit = True -load_4bit = False -dtype = torch.float16 -handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_8bit, device=device, cache_dir=cache_dir) -# handler.model.to(dtype=dtype) -if not os.path.exists("temp"): - os.makedirs("temp") - -app = FastAPI() - - -textbox = gr.Textbox( - show_label=False, placeholder="Enter text and press ENTER", container=False -) -with gr.Blocks(title='Video-LLaVA๐Ÿš€', theme=gr.themes.Default(), css=block_css) as demo: - gr.Markdown(title_markdown) - state = gr.State() - state_ = gr.State() - first_run = gr.State() - images_tensor = gr.State() - - with gr.Row(): - with gr.Column(scale=3): - image1 = gr.Image(label="Input Image", type="filepath") - video = gr.Video(label="Input Video") - - cur_dir = os.path.dirname(os.path.abspath(__file__)) - gr.Examples( - examples=[ - [ - f"{cur_dir}/examples/extreme_ironing.jpg", - "What is unusual about this image?", - ], - [ - f"{cur_dir}/examples/waterview.jpg", - "What are the things I should be cautious about when I visit here?", - ], - [ - f"{cur_dir}/examples/desert.jpg", - "If there are factual errors in the questions, point it out; if not, proceed answering the question. 
Whatโ€™s happening in the desert?", - ], - ], - inputs=[image1, textbox], - ) - - with gr.Column(scale=7): - chatbot = gr.Chatbot(label="Video-LLaVA", bubble_full_width=True).style(height=750) - with gr.Row(): - with gr.Column(scale=8): - textbox.render() - with gr.Column(scale=1, min_width=50): - submit_btn = gr.Button( - value="Send", variant="primary", interactive=True - ) - with gr.Row(elem_id="buttons") as button_row: - upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=True) - downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=True) - flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=True) - # stop_btn = gr.Button(value="โน๏ธ Stop Generation", interactive=False) - regenerate_btn = gr.Button(value="๐Ÿ”„ Regenerate", interactive=True) - clear_btn = gr.Button(value="๐Ÿ—‘๏ธ Clear history", interactive=True) - - with gr.Row(): - gr.Examples( - examples=[ - [ - f"{cur_dir}/examples/sample_img_22.png", - f"{cur_dir}/examples/sample_demo_22.mp4", - "Are the instruments in the pictures used in the video?", - ], - [ - f"{cur_dir}/examples/sample_img_13.png", - f"{cur_dir}/examples/sample_demo_13.mp4", - "Does the flag in the image appear in the video?", - ], - [ - f"{cur_dir}/examples/sample_img_8.png", - f"{cur_dir}/examples/sample_demo_8.mp4", - "Are the image and the video depicting the same place?", - ], - ], - inputs=[image1, video, textbox], - ) - gr.Examples( - examples=[ - [ - f"{cur_dir}/examples/sample_demo_1.mp4", - "Why is this video funny?", - ], - [ - f"{cur_dir}/examples/sample_demo_3.mp4", - "Can you identify any safety hazards in this video?" - ], - [ - f"{cur_dir}/examples/sample_demo_9.mp4", - "Describe the video.", - ], - [ - f"{cur_dir}/examples/sample_demo_22.mp4", - "Describe the activity in the video.", - ], - ], - inputs=[video, textbox], - ) - gr.Markdown(tos_markdown) - gr.Markdown(learn_more_markdown) - - submit_btn.click(generate, [image1, video, textbox, first_run, state, state_, images_tensor], - [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) - - regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then( - generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) - - clear_btn.click(clear_history, [state, state_], - [image1, video, textbox, first_run, state, state_, chatbot, images_tensor]) - -# app = gr.mount_gradio_app(app, demo, path="/") -demo.launch() - -# uvicorn videollava.serve.gradio_web_server:app -# python -m videollava.serve.gradio_web_server +# import shutil +# import os +# import tempfile + +# from modal import asgi_app, method, enter +# from ...stub import VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec + + +# def save_image_to_local(image): + # from PIL import Image + # filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.jpg') + # image = Image.open(image) + # image.save(filename) + # # print(filename) + # return filename + + +# def save_video_to_local(video_path): + # filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.mp4') + # shutil.copyfile(video_path, filename) + # return filename + + +# @cls_dec(gpu="any") +# class VideoLlavaModel: + # @enter() + # def load_model(self): + # import torch + # from videollava.serve.gradio_utils import Chat + # self.conv_mode = "llava_v1" + # model_path = 'LanguageBind/Video-LLaVA-7B' + # device = 'cuda' + # load_8bit = True + # load_4bit = False + # self.dtype = torch.float16 + # 
self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=MODEL_CACHE) + # # self.handler.model.to(dtype=self.dtype) + + # @method() + # def generate(self, image1, video, textbox_in, first_run, state, state_, images_tensor): + # from videollava.conversation import conv_templates, Conversation + # import gradio as gr + # from videollava.constants import DEFAULT_IMAGE_TOKEN + # flag = 1 + # if not textbox_in: + # if len(state_.messages) > 0: + # textbox_in = state_.messages[-1][1] + # state_.messages.pop(-1) + # flag = 0 + # else: + # return "Please enter instruction" + + # image1 = image1 if image1 else "none" + # video = video if video else "none" + # # assert not (os.path.exists(image1) and os.path.exists(video)) + + # if type(state) is not Conversation: + # state = conv_templates[self.conv_mode].copy() + # state_ = conv_templates[self.conv_mode].copy() + # images_tensor = [] + + # first_run = False if len(state.messages) > 0 else True + + # text_en_in = textbox_in.replace("picture", "image") + + # # images_tensor = [[], []] + # image_processor = self.handler.image_processor + # if os.path.exists(image1) and not os.path.exists(video): + # tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + # # print(tensor.shape) + # tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + # images_tensor.append(tensor) + # video_processor = self.handler.video_processor + # if not os.path.exists(image1) and os.path.exists(video): + # tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + # # print(tensor.shape) + # tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + # images_tensor.append(tensor) + # if os.path.exists(image1) and os.path.exists(video): + # tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + # # print(tensor.shape) + # tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + # images_tensor.append(tensor) + + # tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + # # print(tensor.shape) + # tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + # images_tensor.append(tensor) + + # if os.path.exists(image1) and not os.path.exists(video): + # text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in + # if not os.path.exists(image1) and os.path.exists(video): + # text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + # if os.path.exists(image1) and os.path.exists(video): + # text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN + # # print(text_en_in) + # text_en_out, state_ = self.handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_) + # state_.messages[-1] = (state_.roles[1], text_en_out) + + # text_en_out = text_en_out.split('#')[0] + # textbox_out = text_en_out + + # show_images = "" + # if os.path.exists(image1): + # filename = save_image_to_local(image1) + # show_images += f'' + # if os.path.exists(video): + # filename = save_video_to_local(video) + # show_images += f'' + + # if flag: + # state.append_message(state.roles[0], textbox_in + "\n" + show_images) + # state.append_message(state.roles[1], textbox_out) + + # return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, 
interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True)) + + # @method() + # def clear_history(self, state, state_): + # from videollava.conversation import conv_templates + # import gradio as gr + # state = conv_templates[self.conv_mode].copy() + # state_ = conv_templates[self.conv_mode].copy() + # return (gr.update(value=None, interactive=True), + # gr.update(value=None, interactive=True), \ + # gr.update(value=None, interactive=True), \ + # True, state, state_, state.to_gradio_chatbot(), []) + + + + +# def regenerate(state, state_): + # state.messages.pop(-1) + # state_.messages.pop(-1) + # if len(state.messages) > 0: + # return state, state_, state.to_gradio_chatbot(), False + # return (state, state_, state.to_gradio_chatbot(), True) + + + + + +# def build_gradio_interface(model): + # import gradio as gr + # from videollava.serve.gradio_utils import tos_markdown, learn_more_markdown, title_markdown, block_css + + # # if not os.path.exists("temp"): + # # os.makedirs("temp") + + + # textbox = gr.Textbox( + # show_label=False, placeholder="Enter text and press ENTER", container=False + # ) + # with gr.Blocks(title='Video-LLaVA๐Ÿš€', theme=gr.themes.Default(), css=block_css) as interface: + # gr.Markdown(title_markdown) + # state = gr.State() + # state_ = gr.State() + # first_run = gr.State() + # images_tensor = gr.State() + + # with gr.Row(): + # with gr.Column(scale=3): + # image1 = gr.Image(label="Input Image", type="filepath") + # video = gr.Video(label="Input Video") + + # cur_dir = os.path.dirname(os.path.abspath(__file__)) + # gr.Examples( + # examples=[ + # [ + # f"{cur_dir}/examples/extreme_ironing.jpg", + # "What is unusual about this image?", + # ], + # [ + # f"{cur_dir}/examples/waterview.jpg", + # "What are the things I should be cautious about when I visit here?", + # ], + # [ + # f"{cur_dir}/examples/desert.jpg", + # "If there are factual errors in the questions, point it out; if not, proceed answering the question. 
Whatโ€™s happening in the desert?", + # ], + # ], + # inputs=[image1, textbox], + # ) + + # with gr.Column(scale=7): + # chatbot = gr.Chatbot(label="Video-LLaVA", bubble_full_width=True).style(height=750) + # with gr.Row(): + # with gr.Column(scale=8): + # textbox.render() + # with gr.Column(scale=1, min_width=50): + # submit_btn = gr.Button( + # value="Send", variant="primary", interactive=True + # ) + # with gr.Row(elem_id="buttons") as button_row: + # upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=True) + # downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=True) + # flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=True) + # # stop_btn = gr.Button(value="โน๏ธ Stop Generation", interactive=False) + # regenerate_btn = gr.Button(value="๐Ÿ”„ Regenerate", interactive=True) + # clear_btn = gr.Button(value="๐Ÿ—‘๏ธ Clear history", interactive=True) + + # with gr.Row(): + # gr.Examples( + # examples=[ + # [ + # f"{cur_dir}/examples/sample_img_22.png", + # f"{cur_dir}/examples/sample_demo_22.mp4", + # "Are the instruments in the pictures used in the video?", + # ], + # [ + # f"{cur_dir}/examples/sample_img_13.png", + # f"{cur_dir}/examples/sample_demo_13.mp4", + # "Does the flag in the image appear in the video?", + # ], + # [ + # f"{cur_dir}/examples/sample_img_8.png", + # f"{cur_dir}/examples/sample_demo_8.mp4", + # "Are the image and the video depicting the same place?", + # ], + # ], + # inputs=[image1, video, textbox], + # ) + # gr.Examples( + # examples=[ + # [ + # f"{cur_dir}/examples/sample_demo_1.mp4", + # "Why is this video funny?", + # ], + # [ + # f"{cur_dir}/examples/sample_demo_3.mp4", + # "Can you identify any safety hazards in this video?" + # ], + # [ + # f"{cur_dir}/examples/sample_demo_9.mp4", + # "Describe the video.", + # ], + # [ + # f"{cur_dir}/examples/sample_demo_22.mp4", + # "Describe the activity in the video.", + # ], + # ], + # inputs=[video, textbox], + # ) + # gr.Markdown(tos_markdown) + # gr.Markdown(learn_more_markdown) + + # submit_btn.click(model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], + # [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) + + # regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then( + # model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) + + # clear_btn.click(model.clear_history.remote, [state, state_], + # [image1, video, textbox, first_run, state, state_, chatbot, images_tensor]) + # return interface + + +# @function_dec(gpu="any") +# @asgi_app() +# def fastapi_app(): + # from gradio.routes import mount_gradio_app + # from fastapi import FastAPI + # app = FastAPI() + + # model = VideoLlavaModel() + # # interface = gr.Interface( + # # fn=classifier.predict.remote, + # # inputs=gr.Image(shape=(224, 224)), + # # outputs="label", + # # examples=create_demo_examples(), + # # css="/assets/index.css", + # # ) + # return mount_gradio_app( + # app=app, + # blocks=build_gradio_interface(model), + # path="/", + # ) +# # app = gr.mount_gradio_app(app, demo, path="/") +# # demo.launch() + +# # uvicorn videollava.serve.gradio_web_server:app +# # python -m videollava.serve.gradio_web_server
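For reference, here is a minimal client sketch against the /upload and /inference endpoints defined in modal_inference.py above. It is an illustration only: the base URL is a placeholder for whatever hostname Modal assigns after deployment, the file name and prompt are made up, and it assumes the `requests` package is available on the client side (it is not part of the project's dependency list).

# hypothetical client for the FastAPI app in modal_inference.py
# BASE_URL is a placeholder; substitute the URL Modal prints after deployment
import requests

BASE_URL = "https://example--video-llava.modal.run"  # placeholder, not a real deployment

# 1) upload a local video into the Modal volume; the server streams it to disk
#    in 1024-byte chunks and returns the stored file path
with open("clip.mp4", "rb") as f:
    upload = requests.post(f"{BASE_URL}/upload", files={"file": ("clip.mp4", f)})
upload.raise_for_status()
file_path = upload.json()["file_path"]

# 2) run inference against the uploaded file; the endpoint takes its arguments
#    as query parameters and forwards them to VideoLlavaModel.generate.remote
result = requests.post(
    f"{BASE_URL}/inference",
    params={"video_file_path": file_path, "prompt": "Describe the activity in the video."},
)
result.raise_for_status()
print(result.json())

Deploying the app itself amounts to running `modal deploy` on the module that exposes fastapi_app_modal; for local testing, the commented-out `app = fastapi_app()` plus `uvicorn modal_inference:app` route noted at the bottom of modal_inference.py still applies.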