Update app.py
app.py CHANGED
@@ -1,73 +1,63 @@
 import gradio as gr
 from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
-from transformers.models.smolvlm.video_processing_smolvlm import load_smolvlm_video
-from transformers.image_utils import load_image
 from threading import Thread
 import re
 import time
 import torch
-
+import spaces
 #import subprocess
 #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 from io import BytesIO
-from transformers.image_utils import load_image
 
-processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-
-model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM2-
+processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-500M-Instruct")
+model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM2-500M-Instruct",
     _attn_implementation="flash_attention_2",
-    torch_dtype=torch.bfloat16
+    torch_dtype=torch.bfloat16).to("cuda:0")
 
 
-
+@spaces.GPU
 def model_inference(
-    input_dict, history
+    input_dict, history, max_tokens
 ):
     text = input_dict["text"]
-    # first turn input_dict {'text': 'What', 'files': ['/tmp/gradio/0350274350a64a5737e1a5732f014aee2f28bb7344bbad5105c0d0b7e7334375/cats_2.mp4', '/tmp/gradio/2dd39f382fcf5444a1a2ac57ed6f9acafa775dd855248cf273034e8ce18aeff4/IMG_2201.JPG']}
-    # first turn history []
-    print("input_dict", input_dict)
-    print("history", history)
-    print("model.device", model.device)
     images = []
     # first conv turn
     if history == []:
         text = input_dict["text"]
-        resulting_messages = [{"role": "user", "content": [{"type": "text"
+        resulting_messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
         for file in input_dict["files"]:
             if file.endswith(".mp4"):
-                resulting_messages[0]["content"].append({"type": "video"})
-
-                    file, sampling_fps=1, max_frames=64
-                )
-                print("frames", frames)
-                images.append(frames)
+                resulting_messages[0]["content"].append({"type": "video", "path": file})
+
             elif file.endswith(".jpg") or file.endswith(".jpeg") or file.endswith(".png"):
-                resulting_messages[0]["content"].append({"type": "image"})
-
-                print("images", images)
-
-    # second turn input_dict {'text': 'what', 'files': ['/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG']}
-    # second turn history [[('/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG',), None],
-    # [('/tmp/gradio/5b105e97e4876912b4e763902144540bd3ab00d9fd4016491337ee4f4c36f320/football.mp4',), None], ['what', None]]
-
-    # later conv turn
+                resulting_messages[0]["content"].append({"type": "image", "path": file})
+
     elif len(history) > 0:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        resulting_messages = []
+        for entry in history:
+            if entry["role"] == "user":
+                user_content = []
+                if isinstance(entry["content"], tuple):
+                    file_name = entry["content"][0]
+                    if file_name.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                        user_content.append({"type": "image", "path": file_name})
+                    elif file_name.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+                        user_content.append({"type": "video", "path": file_name})
+                elif isinstance(entry["content"], str):
+                    user_content.insert(0, {"type": "text", "text": entry["content"]})
+
+            elif entry["role"] == "assistant":
+                resulting_messages.append({
+                    "role": "user",
+                    "content": user_content
+                })
+                resulting_messages.append({
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": entry["content"]}]
+                })
+                user_content = []
+
 
 
 
@@ -75,26 +65,22 @@ def model_inference(
         gr.Error("Please input a query and optionally image(s).")
 
     if text == "" and images:
-        gr.Error("Please input a text query along the
+        gr.Error("Please input a text query along the images(s).")
 
-
-
+    inputs = processor.apply_chat_template(
+        resulting_messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
 
-    inputs = processor(text=prompt, images=[images], padding=True, return_tensors="pt")
     inputs = inputs.to(model.device)
-
-        "input_ids": inputs.input_ids,
-        "pixel_values": inputs.pixel_values,
-        "attention_mask": inputs.attention_mask,
-        "num_return_sequences": 1,
-        "no_repeat_ngram_size": 2,
-        "max_new_tokens": 500,
-        "min_new_tokens": 10,
-    }
+
 
     # Generate
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_args = dict(inputs, streamer=streamer, max_new_tokens=
+    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
    generated_text = ""
 
     thread = Thread(target=model.generate, kwargs=generation_args)

@@ -127,6 +113,7 @@ demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video
     examples=examples,
     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
     cache_examples=False,
+    additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
     type="messages"
 )
 
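The changed hunks stop right after the generation thread is created; the loop that streams tokens back to the chat UI sits in the unchanged lines between the hunks and is not shown here. A minimal sketch of that standard TextIteratorStreamer pattern, assuming only the names already visible in the diff (model, streamer, generation_args), would look like:

    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    buffer = ""
    for new_text in streamer:   # the streamer yields decoded text chunks as generate() produces them
        buffer += new_text
        time.sleep(0.01)        # small pause keeps the streamed updates readable in the UI
        yield buffer            # each yield refreshes the partial answer in gr.ChatInterface

The "Max Tokens" slider registered via additional_inputs is what supplies the new third parameter max_tokens of model_inference, which in turn caps generation through max_new_tokens in generation_args.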