diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bee8a64
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+__pycache__
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..3f9b90b
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1 @@
+from .modal_inference import *
diff --git a/image.py b/image.py
new file mode 100644
index 0000000..5ccc077
--- /dev/null
+++ b/image.py
@@ -0,0 +1,108 @@
+from modal import Volume, Image, Mount
+import os
+from pathlib import Path
+from ai_video_editor.stub import stub, REPO_HOME, LOCAL_CERT_PATH, CERT_PATH, EXTRA_ENV
+
+LOCAL_VOLUME_DIR = "/video_llava_volume"
+HF_DATASETS_CACHE = str(Path(LOCAL_VOLUME_DIR) / "hf_datasets_cache")
+MODEL_CACHE = Path(LOCAL_VOLUME_DIR, "models")
+
+LOCAL_VOLUME_NAME = "video-llava-volume"
+local_volume = Volume.from_name(LOCAL_VOLUME_NAME, create_if_missing=True)
+local_volumes = {
+ LOCAL_VOLUME_DIR: local_volume,
+}
+local_mounts = [
+ Mount.from_local_dir("./ai_video_editor/video_llava", remote_path=REPO_HOME),
+]
+
+
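+# Build-time cleanup (run via run_function below): removes '/volume/models', presumably a
+# leftover from an earlier volume layout (note this is not the current MODEL_CACHE path).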
+def remove_old_files():
+ import shutil
+ shutil.rmtree('/volume/models', ignore_errors=True)
+
+image = (
+ Image.from_registry(
+ "nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04", add_python="3.11"
+ )
+ .apt_install(
+ "git",
+ "curl",
+ "libgl1-mesa-glx",
+ "libglib2.0-0",
+ "libsm6",
+ "libxrender1",
+ "libxext6",
+ "ffmpeg",
+ "clang",
+ "libopenmpi-dev",
+ gpu="any",
+ )
+
+ .pip_install(
+ # "torch==2.1.2",
+ # "transformers==4.37.2",
+ # "bitsandbytes==0.42.0",
+ "torch==2.0.1", "torchvision==0.15.2",
+ "transformers==4.31.0", "tokenizers>=0.12.1,<0.14", "sentencepiece==0.1.99", "shortuuid",
+ "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0",
+ "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
+ "requests", "httpx==0.24.0", "uvicorn", "fastapi",
+ "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
+ "tensorboardX==2.6.2.2", "gradio==3.37.0", "gradio_client==0.7.0",
+ "deepspeed==0.9.5", "ninja", "wandb",
+ "wheel",
+ gpu="any",
+ )
+ .run_commands(
+ "python -m bitsandbytes",
+ gpu="any"
+ )
+ .run_commands("pip install flash-attn --no-build-isolation", gpu="any")
+ .env({"PYTHONPATH": REPO_HOME, "HF_DATASETS_CACHE": HF_DATASETS_CACHE})
+ .pip_install(
+ "decord",
+ "opencv-python",
+ "git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d",
+ gpu="any",
+ )
+ .pip_install(
+ "aiofiles",
+ "aioboto3",
+ )
+ .run_function(remove_old_files)
+ .copy_local_file(LOCAL_CERT_PATH, CERT_PATH)
+ .pip_install("boto3", "aioboto3")
+ .env(EXTRA_ENV)
+ .pip_install("diskcache")
+)
+# TODO: bitsandbytes does not seem to be working with the GPU in this image yet
+
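+# Shared decorators that attach this image, the volumes, and the repo mount to Modal functions and classes.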
+def function_dec(**extras):
+ return stub.function(
+ image=image,
+ timeout=80000,
+ # checkpointing doesn't work because it restricts internet access
+ #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts.
+ _allow_background_volume_commits=True,
+ container_idle_timeout=120,
+ volumes=local_volumes,
+ mounts=local_mounts,
+ **extras,
+ )
+
+def cls_dec(**extras):
+ return stub.cls(
+ image=image,
+ timeout=80000,
+ # checkpointing doesn't work because it restricts internet access
+ #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts.
+ container_idle_timeout=1200,
+ # TODO maybe turn on
+ allow_concurrent_inputs=4,
+ retries=3,
+ _allow_background_volume_commits=True,
+ volumes=local_volumes,
+ mounts=local_mounts,
+ **extras,
+ )
diff --git a/modal_inference.py b/modal_inference.py
new file mode 100644
index 0000000..14891c7
--- /dev/null
+++ b/modal_inference.py
@@ -0,0 +1,180 @@
+import os
+import shutil
+import urllib
+
+from modal import asgi_app, method, enter, build
+from ai_video_editor.utils.fs_utils import async_copy_from_s3
+from .image import LOCAL_VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, local_volume
+from ai_video_editor.stub import stub, S3_VIDEO_PATH, VOLUME_DIR, volume as remote_volume
+import diskcache as dc
+from pathlib import Path
+# for local testing
+#S3_VIDEO_PATH= "s3_videos"
+#MODEL_CACHE = "models"
+#Path(VOLUME_DIR).mkdir(exist_ok=True, parents=True)
+VIDEOS_DIR = Path(S3_VIDEO_PATH) / "videos"
+IMAGES_DIR = Path(S3_VIDEO_PATH) / "images"
+
+
+
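+# Modal class wrapping Video-LLaVA inference; the model is loaded once per container in load_model() (the @enter() hook).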
+@cls_dec(gpu="any")
+class VideoLlavaModel:
+ @enter()
+ def load_model(self):
+ self.cache = dc.Cache('.cache')
+ local_volume.reload()
+ import torch
+ from videollava.serve.gradio_utils import Chat
+ self.conv_mode = "llava_v1"
+ model_path = 'LanguageBind/Video-LLaVA-7B'
+ device = 'cuda'
+ load_8bit = False
+ load_4bit = True
+ self.dtype = torch.float16
+ self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=str(MODEL_CACHE))
+ # self.handler.model.to(dtype=self.dtype)
+
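+    # Mirror a file from the shared remote volume into this container's local volume (no-op if it already exists locally).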
+ def copy_file_from_remote_volume(self, filepath):
+ in_volume_path = filepath.split('/', 2)[-1]
+ local_volume_path = Path(LOCAL_VOLUME_DIR) / in_volume_path
+ local_volume_path.parent.mkdir(parents=True, exist_ok=True)
+ if not local_volume_path.exists():
+ shutil.copy(filepath, str(local_volume_path))
+
+ async def copy_file_from_s3(self, filepath):
+ bucket, in_bucket_path = filepath.replace('s3://','').split('/', 1)
+ await async_copy_from_s3(bucket, in_bucket_path, str(Path(VOLUME_DIR) / in_bucket_path))
+
+ async def copy_file_to_local(self, filepath):
+ if not filepath:
+ return
+ if filepath.startswith('s3://'):
+ await self.copy_file_from_s3(filepath)
+ else:
+ self.copy_file_from_remote_volume(filepath)
+
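+    # Remote inference entrypoint: results are memoized with diskcache keyed by (image1, video, textbox_in);
+    # pass use_existing_output=False to force regeneration.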
+ @method()
+ async def generate(self, image1, video, textbox_in, use_existing_output=True):
+ inputs = (image1, video, textbox_in)
+ if inputs in self.cache and use_existing_output:
+ res = self.cache[inputs]
+ self.cache.close()
+ return res
+ remote_volume.reload()
+ local_volume.reload()
+ await self.copy_file_to_local(image1)
+ await self.copy_file_to_local(video)
+
+ from videollava.conversation import conv_templates
+ from videollava.constants import DEFAULT_IMAGE_TOKEN
+ if not textbox_in:
+ raise ValueError("no prompt provided")
+
+ image1 = image1 if image1 else "none"
+ video = video if video else "none"
+
+ state_ = conv_templates[self.conv_mode].copy()
+ images_tensor = []
+
+ text_en_in = textbox_in.replace("picture", "image")
+
+ image_processor = self.handler.image_processor
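+        # Preprocess whichever inputs exist: image only, video only, or both (video tensor first, then image tensor).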
+ if os.path.exists(image1) and not os.path.exists(video):
+ tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
+ tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
+ images_tensor.append(tensor)
+ video_processor = self.handler.video_processor
+ if not os.path.exists(image1) and os.path.exists(video):
+ tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
+ tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
+ images_tensor.append(tensor)
+ if os.path.exists(image1) and os.path.exists(video):
+ tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
+ tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
+ images_tensor.append(tensor)
+
+ tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
+ tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
+ images_tensor.append(tensor)
+
+ if os.path.exists(image1) and not os.path.exists(video):
+ text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
+ elif not os.path.exists(image1) and os.path.exists(video):
+ text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in
+ elif os.path.exists(image1) and os.path.exists(video):
+ text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN
+ else:
+ print("WARNING: No image or video supplied")
+
+ text_en_out, _ = self.handler.generate(images_tensor, text_en_in, first_run=True, state=state_)
+
+ text_en_out = text_en_out.split('#')[0]
+ textbox_out = text_en_out
+
+ if not textbox_out:
+ raise ValueError("no text generated")
+ self.cache.set(inputs, textbox_out)
+ self.cache.close()
+ return textbox_out
+
+
+
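+# Build the FastAPI app exposing /upload and /inference; also used directly for local testing (see bottom of file).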
+def fastapi_app():
+ from fastapi import FastAPI, UploadFile, File
+ import aiofiles
+
+ Path(MODEL_CACHE).mkdir(exist_ok=True, parents=True)
+ VIDEOS_DIR.mkdir(exist_ok=True, parents=True)
+ IMAGES_DIR.mkdir(exist_ok=True, parents=True)
+
+ app = FastAPI()
+ model = VideoLlavaModel()
+
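+    # Upload endpoint: stream the file into the local volume in 1 KiB chunks, then commit so other containers can see it.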
+ @app.post("/upload")
+ async def upload(
+ file: UploadFile = File(...),
+ ):
+ local_volume.reload()
+ filename_decoded = urllib.parse.unquote(file.filename)
+ file_path = str(Path(LOCAL_VOLUME_DIR) / filename_decoded)
+ async with aiofiles.open(file_path, "wb") as buffer:
+ while content := await file.read(1024): # Read chunks of 1024 bytes
+ await buffer.write(content)
+ local_volume.commit()
+ return {"file_path": file_path}
+
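+    # Inference endpoint: bare file names are resolved to paths under VIDEOS_DIR / IMAGES_DIR before calling the model remotely.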
+ @app.post("/inference")
+ async def inference(
+ video_file_name: str = '',
+ video_file_path: str = '',
+ image_file_name: str = '',
+ image_file_path: str = '',
+ prompt: str = '',
+ ):
+ video_file_name = urllib.parse.unquote(video_file_name)
+ video_file_path = urllib.parse.unquote(video_file_path)
+ if video_file_path is None or video_file_path == '':
+ if video_file_name is None or video_file_name == '':
+ raise ValueError("one of video_file_path or video_file_name must be specified")
+ video_file_path = str(VIDEOS_DIR / video_file_name)
+
+ image_file_name = urllib.parse.unquote(image_file_name)
+ image_file_path = urllib.parse.unquote(image_file_path)
+ if image_file_path is None or image_file_path == '':
+ if image_file_name is not None and image_file_name != '':
+ image_file_path = str(IMAGES_DIR / image_file_name)
+
+ return model.generate.remote(image_file_path, video_file_path, prompt)
+ return app
+
+
+@function_dec()
+@asgi_app()
+def fastapi_app_modal():
+ return fastapi_app()
+
+# local testing: uncomment the line below (and re-comment it before deploying)
+# app = fastapi_app()
+# then serve it with:
+#   conda activate videollava
+#   uvicorn modal_inference:app
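+# example request against the deployed app (hypothetical URL and file name; substitute your Modal endpoint
+# and a file previously sent to /upload or present on the volume):
+#   curl -X POST "https://<your-modal-app-url>/inference?video_file_name=clip.mp4&prompt=Describe%20the%20video"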
diff --git a/pyproject.toml b/pyproject.toml
index 7c0b7f1..7f21463 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
"pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
"requests", "httpx==0.24.0", "uvicorn", "fastapi",
"einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
- "tensorboardX==2.6.2.2", "gradio==3.37.0", "gradio_client==0.7.0"
+ "tensorboardX==2.6.2.2", "gradio==3.37.0", "gradio_client==0.7.0", "modal"
]
[project.optional-dependencies]
diff --git a/videollava/serve/gradio_web_server.py b/videollava/serve/gradio_web_server.py
index dfce85a..fc7d7de 100644
--- a/videollava/serve/gradio_web_server.py
+++ b/videollava/serve/gradio_web_server.py
@@ -1,249 +1,281 @@
-import shutil
-import subprocess
-
-import torch
-import gradio as gr
-from fastapi import FastAPI
-import os
-from PIL import Image
-import tempfile
-from decord import VideoReader, cpu
-from transformers import TextStreamer
-
-from videollava.constants import DEFAULT_IMAGE_TOKEN
-from videollava.conversation import conv_templates, SeparatorStyle, Conversation
-from videollava.serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
-
-
-
-def save_image_to_local(image):
- filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
- image = Image.open(image)
- image.save(filename)
- # print(filename)
- return filename
-
-
-def save_video_to_local(video_path):
- filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
- shutil.copyfile(video_path, filename)
- return filename
-
-
-def generate(image1, video, textbox_in, first_run, state, state_, images_tensor):
- flag = 1
- if not textbox_in:
- if len(state_.messages) > 0:
- textbox_in = state_.messages[-1][1]
- state_.messages.pop(-1)
- flag = 0
- else:
- return "Please enter instruction"
-
- image1 = image1 if image1 else "none"
- video = video if video else "none"
- # assert not (os.path.exists(image1) and os.path.exists(video))
-
- if type(state) is not Conversation:
- state = conv_templates[conv_mode].copy()
- state_ = conv_templates[conv_mode].copy()
- images_tensor = []
-
- first_run = False if len(state.messages) > 0 else True
-
- text_en_in = textbox_in.replace("picture", "image")
-
- # images_tensor = [[], []]
- image_processor = handler.image_processor
- if os.path.exists(image1) and not os.path.exists(video):
- tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
- # print(tensor.shape)
- tensor = tensor.to(handler.model.device, dtype=dtype)
- images_tensor.append(tensor)
- video_processor = handler.video_processor
- if not os.path.exists(image1) and os.path.exists(video):
- tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
- # print(tensor.shape)
- tensor = tensor.to(handler.model.device, dtype=dtype)
- images_tensor.append(tensor)
- if os.path.exists(image1) and os.path.exists(video):
- tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
- # print(tensor.shape)
- tensor = tensor.to(handler.model.device, dtype=dtype)
- images_tensor.append(tensor)
-
- tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
- # print(tensor.shape)
- tensor = tensor.to(handler.model.device, dtype=dtype)
- images_tensor.append(tensor)
-
- if os.path.exists(image1) and not os.path.exists(video):
- text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
- if not os.path.exists(image1) and os.path.exists(video):
- text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in
- if os.path.exists(image1) and os.path.exists(video):
- text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN
- # print(text_en_in)
- text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
- state_.messages[-1] = (state_.roles[1], text_en_out)
-
- text_en_out = text_en_out.split('#')[0]
- textbox_out = text_en_out
-
- show_images = ""
- if os.path.exists(image1):
- filename = save_image_to_local(image1)
- show_images += f''
- if os.path.exists(video):
- filename = save_video_to_local(video)
- show_images += f''
-
- if flag:
- state.append_message(state.roles[0], textbox_in + "\n" + show_images)
- state.append_message(state.roles[1], textbox_out)
-
- return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True))
-
-
-def regenerate(state, state_):
- state.messages.pop(-1)
- state_.messages.pop(-1)
- if len(state.messages) > 0:
- return state, state_, state.to_gradio_chatbot(), False
- return (state, state_, state.to_gradio_chatbot(), True)
-
-
-def clear_history(state, state_):
- state = conv_templates[conv_mode].copy()
- state_ = conv_templates[conv_mode].copy()
- return (gr.update(value=None, interactive=True),
- gr.update(value=None, interactive=True), \
- gr.update(value=None, interactive=True), \
- True, state, state_, state.to_gradio_chatbot(), [])
-
-
-conv_mode = "llava_v1"
-model_path = 'LanguageBind/Video-LLaVA-7B'
-cache_dir = './cache_dir'
-device = 'cuda'
-load_8bit = True
-load_4bit = False
-dtype = torch.float16
-handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_8bit, device=device, cache_dir=cache_dir)
-# handler.model.to(dtype=dtype)
-if not os.path.exists("temp"):
- os.makedirs("temp")
-
-app = FastAPI()
-
-
-textbox = gr.Textbox(
- show_label=False, placeholder="Enter text and press ENTER", container=False
-)
-with gr.Blocks(title='Video-LLaVA🚀', theme=gr.themes.Default(), css=block_css) as demo:
- gr.Markdown(title_markdown)
- state = gr.State()
- state_ = gr.State()
- first_run = gr.State()
- images_tensor = gr.State()
-
- with gr.Row():
- with gr.Column(scale=3):
- image1 = gr.Image(label="Input Image", type="filepath")
- video = gr.Video(label="Input Video")
-
- cur_dir = os.path.dirname(os.path.abspath(__file__))
- gr.Examples(
- examples=[
- [
- f"{cur_dir}/examples/extreme_ironing.jpg",
- "What is unusual about this image?",
- ],
- [
- f"{cur_dir}/examples/waterview.jpg",
- "What are the things I should be cautious about when I visit here?",
- ],
- [
- f"{cur_dir}/examples/desert.jpg",
-                    "If there are factual errors in the questions, point it out; if not, proceed answering the question. What's happening in the desert?",
- ],
- ],
- inputs=[image1, textbox],
- )
-
- with gr.Column(scale=7):
- chatbot = gr.Chatbot(label="Video-LLaVA", bubble_full_width=True).style(height=750)
- with gr.Row():
- with gr.Column(scale=8):
- textbox.render()
- with gr.Column(scale=1, min_width=50):
- submit_btn = gr.Button(
- value="Send", variant="primary", interactive=True
- )
- with gr.Row(elem_id="buttons") as button_row:
-            upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
-            downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
-            flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
-            # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
-            regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
-            clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
-
- with gr.Row():
- gr.Examples(
- examples=[
- [
- f"{cur_dir}/examples/sample_img_22.png",
- f"{cur_dir}/examples/sample_demo_22.mp4",
- "Are the instruments in the pictures used in the video?",
- ],
- [
- f"{cur_dir}/examples/sample_img_13.png",
- f"{cur_dir}/examples/sample_demo_13.mp4",
- "Does the flag in the image appear in the video?",
- ],
- [
- f"{cur_dir}/examples/sample_img_8.png",
- f"{cur_dir}/examples/sample_demo_8.mp4",
- "Are the image and the video depicting the same place?",
- ],
- ],
- inputs=[image1, video, textbox],
- )
- gr.Examples(
- examples=[
- [
- f"{cur_dir}/examples/sample_demo_1.mp4",
- "Why is this video funny?",
- ],
- [
- f"{cur_dir}/examples/sample_demo_3.mp4",
- "Can you identify any safety hazards in this video?"
- ],
- [
- f"{cur_dir}/examples/sample_demo_9.mp4",
- "Describe the video.",
- ],
- [
- f"{cur_dir}/examples/sample_demo_22.mp4",
- "Describe the activity in the video.",
- ],
- ],
- inputs=[video, textbox],
- )
- gr.Markdown(tos_markdown)
- gr.Markdown(learn_more_markdown)
-
- submit_btn.click(generate, [image1, video, textbox, first_run, state, state_, images_tensor],
- [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
-
- regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
- generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
-
- clear_btn.click(clear_history, [state, state_],
- [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])
-
-# app = gr.mount_gradio_app(app, demo, path="/")
-demo.launch()
-
-# uvicorn videollava.serve.gradio_web_server:app
-# python -m videollava.serve.gradio_web_server
+# import shutil
+# import os
+# import tempfile
+
+# from modal import asgi_app, method, enter
+# from ...stub import VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec
+
+
+# def save_image_to_local(image):
+ # from PIL import Image
+ # filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.jpg')
+ # image = Image.open(image)
+ # image.save(filename)
+ # # print(filename)
+ # return filename
+
+
+# def save_video_to_local(video_path):
+ # filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.mp4')
+ # shutil.copyfile(video_path, filename)
+ # return filename
+
+
+# @cls_dec(gpu="any")
+# class VideoLlavaModel:
+ # @enter()
+ # def load_model(self):
+ # import torch
+ # from videollava.serve.gradio_utils import Chat
+ # self.conv_mode = "llava_v1"
+ # model_path = 'LanguageBind/Video-LLaVA-7B'
+ # device = 'cuda'
+ # load_8bit = True
+ # load_4bit = False
+ # self.dtype = torch.float16
+ # self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=MODEL_CACHE)
+ # # self.handler.model.to(dtype=self.dtype)
+
+ # @method()
+ # def generate(self, image1, video, textbox_in, first_run, state, state_, images_tensor):
+ # from videollava.conversation import conv_templates, Conversation
+ # import gradio as gr
+ # from videollava.constants import DEFAULT_IMAGE_TOKEN
+ # flag = 1
+ # if not textbox_in:
+ # if len(state_.messages) > 0:
+ # textbox_in = state_.messages[-1][1]
+ # state_.messages.pop(-1)
+ # flag = 0
+ # else:
+ # return "Please enter instruction"
+
+ # image1 = image1 if image1 else "none"
+ # video = video if video else "none"
+ # # assert not (os.path.exists(image1) and os.path.exists(video))
+
+ # if type(state) is not Conversation:
+ # state = conv_templates[self.conv_mode].copy()
+ # state_ = conv_templates[self.conv_mode].copy()
+ # images_tensor = []
+
+ # first_run = False if len(state.messages) > 0 else True
+
+ # text_en_in = textbox_in.replace("picture", "image")
+
+ # # images_tensor = [[], []]
+ # image_processor = self.handler.image_processor
+ # if os.path.exists(image1) and not os.path.exists(video):
+ # tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
+ # # print(tensor.shape)
+ # tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
+ # images_tensor.append(tensor)
+ # video_processor = self.handler.video_processor
+ # if not os.path.exists(image1) and os.path.exists(video):
+ # tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
+ # # print(tensor.shape)
+ # tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
+ # images_tensor.append(tensor)
+ # if os.path.exists(image1) and os.path.exists(video):
+ # tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
+ # # print(tensor.shape)
+ # tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
+ # images_tensor.append(tensor)
+
+ # tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
+ # # print(tensor.shape)
+ # tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
+ # images_tensor.append(tensor)
+
+ # if os.path.exists(image1) and not os.path.exists(video):
+ # text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
+ # if not os.path.exists(image1) and os.path.exists(video):
+ # text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in
+ # if os.path.exists(image1) and os.path.exists(video):
+ # text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN
+ # # print(text_en_in)
+ # text_en_out, state_ = self.handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
+ # state_.messages[-1] = (state_.roles[1], text_en_out)
+
+ # text_en_out = text_en_out.split('#')[0]
+ # textbox_out = text_en_out
+
+ # show_images = ""
+ # if os.path.exists(image1):
+ # filename = save_image_to_local(image1)
+ # show_images += f''
+ # if os.path.exists(video):
+ # filename = save_video_to_local(video)
+ # show_images += f''
+
+ # if flag:
+ # state.append_message(state.roles[0], textbox_in + "\n" + show_images)
+ # state.append_message(state.roles[1], textbox_out)
+
+ # return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True))
+
+ # @method()
+ # def clear_history(self, state, state_):
+ # from videollava.conversation import conv_templates
+ # import gradio as gr
+ # state = conv_templates[self.conv_mode].copy()
+ # state_ = conv_templates[self.conv_mode].copy()
+ # return (gr.update(value=None, interactive=True),
+ # gr.update(value=None, interactive=True), \
+ # gr.update(value=None, interactive=True), \
+ # True, state, state_, state.to_gradio_chatbot(), [])
+
+
+
+
+# def regenerate(state, state_):
+ # state.messages.pop(-1)
+ # state_.messages.pop(-1)
+ # if len(state.messages) > 0:
+ # return state, state_, state.to_gradio_chatbot(), False
+ # return (state, state_, state.to_gradio_chatbot(), True)
+
+
+
+
+
+# def build_gradio_interface(model):
+ # import gradio as gr
+ # from videollava.serve.gradio_utils import tos_markdown, learn_more_markdown, title_markdown, block_css
+
+ # # if not os.path.exists("temp"):
+ # # os.makedirs("temp")
+
+
+ # textbox = gr.Textbox(
+ # show_label=False, placeholder="Enter text and press ENTER", container=False
+ # )
+    # with gr.Blocks(title='Video-LLaVA🚀', theme=gr.themes.Default(), css=block_css) as interface:
+ # gr.Markdown(title_markdown)
+ # state = gr.State()
+ # state_ = gr.State()
+ # first_run = gr.State()
+ # images_tensor = gr.State()
+
+ # with gr.Row():
+ # with gr.Column(scale=3):
+ # image1 = gr.Image(label="Input Image", type="filepath")
+ # video = gr.Video(label="Input Video")
+
+ # cur_dir = os.path.dirname(os.path.abspath(__file__))
+ # gr.Examples(
+ # examples=[
+ # [
+ # f"{cur_dir}/examples/extreme_ironing.jpg",
+ # "What is unusual about this image?",
+ # ],
+ # [
+ # f"{cur_dir}/examples/waterview.jpg",
+ # "What are the things I should be cautious about when I visit here?",
+ # ],
+ # [
+ # f"{cur_dir}/examples/desert.jpg",
+                    # "If there are factual errors in the questions, point it out; if not, proceed answering the question. What's happening in the desert?",
+ # ],
+ # ],
+ # inputs=[image1, textbox],
+ # )
+
+ # with gr.Column(scale=7):
+ # chatbot = gr.Chatbot(label="Video-LLaVA", bubble_full_width=True).style(height=750)
+ # with gr.Row():
+ # with gr.Column(scale=8):
+ # textbox.render()
+ # with gr.Column(scale=1, min_width=50):
+ # submit_btn = gr.Button(
+ # value="Send", variant="primary", interactive=True
+ # )
+ # with gr.Row(elem_id="buttons") as button_row:
+            # upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
+            # downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
+            # flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
+            # # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+            # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
+            # clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
+
+ # with gr.Row():
+ # gr.Examples(
+ # examples=[
+ # [
+ # f"{cur_dir}/examples/sample_img_22.png",
+ # f"{cur_dir}/examples/sample_demo_22.mp4",
+ # "Are the instruments in the pictures used in the video?",
+ # ],
+ # [
+ # f"{cur_dir}/examples/sample_img_13.png",
+ # f"{cur_dir}/examples/sample_demo_13.mp4",
+ # "Does the flag in the image appear in the video?",
+ # ],
+ # [
+ # f"{cur_dir}/examples/sample_img_8.png",
+ # f"{cur_dir}/examples/sample_demo_8.mp4",
+ # "Are the image and the video depicting the same place?",
+ # ],
+ # ],
+ # inputs=[image1, video, textbox],
+ # )
+ # gr.Examples(
+ # examples=[
+ # [
+ # f"{cur_dir}/examples/sample_demo_1.mp4",
+ # "Why is this video funny?",
+ # ],
+ # [
+ # f"{cur_dir}/examples/sample_demo_3.mp4",
+ # "Can you identify any safety hazards in this video?"
+ # ],
+ # [
+ # f"{cur_dir}/examples/sample_demo_9.mp4",
+ # "Describe the video.",
+ # ],
+ # [
+ # f"{cur_dir}/examples/sample_demo_22.mp4",
+ # "Describe the activity in the video.",
+ # ],
+ # ],
+ # inputs=[video, textbox],
+ # )
+ # gr.Markdown(tos_markdown)
+ # gr.Markdown(learn_more_markdown)
+
+ # submit_btn.click(model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor],
+ # [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
+
+ # regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
+ # model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
+
+ # clear_btn.click(model.clear_history.remote, [state, state_],
+ # [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])
+ # return interface
+
+
+# @function_dec(gpu="any")
+# @asgi_app()
+# def fastapi_app():
+ # from gradio.routes import mount_gradio_app
+ # from fastapi import FastAPI
+ # app = FastAPI()
+
+ # model = VideoLlavaModel()
+ # # interface = gr.Interface(
+ # # fn=classifier.predict.remote,
+ # # inputs=gr.Image(shape=(224, 224)),
+ # # outputs="label",
+ # # examples=create_demo_examples(),
+ # # css="/assets/index.css",
+ # # )
+ # return mount_gradio_app(
+ # app=app,
+ # blocks=build_gradio_interface(model),
+ # path="/",
+ # )
+# # app = gr.mount_gradio_app(app, demo, path="/")
+# # demo.launch()
+
+# # uvicorn videollava.serve.gradio_web_server:app
+# # python -m videollava.serve.gradio_web_server