Spaces:
Runtime error
Runtime error
Commit
·
794ada2
1
Parent(s):
898a24b
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import os, gc
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
from transformers import CLIPImageProcessor
|
|
@@ -103,17 +105,35 @@ examples = [
|
|
| 103 |
]
|
| 104 |
]
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
def chatbot(image, question):
    """Stream answers to *question* about *image*.

    When *image* is ``None`` a single upload reminder is yielded and the
    generator ends. Otherwise the image is encoded, layer-normalised with the
    model's ``blocks.0.ln0`` parameters, and every value produced by
    ``generate`` for the prompt built from *question* is yielded in order.
    """
    if image is not None:
        encoded = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0)  # [L, D]
        # Layer norm on the image features is required before generation.
        ln_weight = model.w['blocks.0.ln0.weight']
        ln_bias = model.w['blocks.0.ln0.bias']
        normalized_shape = (encoded.shape[-1],)
        encoded = F.layer_norm(encoded, normalized_shape, weight=ln_weight, bias=ln_bias)
        prompt = generate_prompt(question)
        for partial in generate(prompt, encoded):
            yield partial
    else:
        yield "Please upload an image."
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import os, gc
|
| 3 |
+
import base64
|
| 4 |
+
from io import BytesIO
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from transformers import CLIPImageProcessor
|
|
|
|
| 105 |
]
|
| 106 |
]
|
| 107 |
|
| 108 |
+
|
| 109 |
+
def pil_image_to_base64(pil_image):
    """Return *pil_image* serialised as JPEG and encoded as a base64 string.

    Images whose mode the JPEG encoder cannot write (for example ``RGBA``,
    ``LA`` or ``P``, as commonly produced by PNG uploads with transparency)
    are converted to ``RGB`` first; without that step ``save`` raises
    ``OSError`` for them.

    Args:
        pil_image: Image object exposing ``mode``, ``convert`` and ``save``
            (a ``PIL.Image.Image``).

    Returns:
        The base64-encoded JPEG bytes, decoded to a UTF-8 ``str``.
    """
    # Modes Pillow's JPEG writer accepts as-is; anything else must be converted.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if pil_image.mode not in jpeg_writable_modes:
        pil_image = pil_image.convert("RGB")

    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    # b64encode returns bytes; decode so callers get a plain str.
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 115 |
+
|
| 116 |
+
# Upper bound on cached feature tensors so a long-running app does not keep
# one tensor per distinct upload forever. Tune to the available memory.
IMAGE_CACHE_MAX_ENTRIES = 32

# Maps the base64 JPEG encoding of an image to its normalised feature tensor.
image_cache = {}


def get_image_features(image):
    """Return the layer-normalised visual features for *image*, with caching.

    The image is converted to RGB once; that RGB image is used both to build
    the cache key (via ``pil_image_to_base64``) and as input to
    ``image_processor``. Converting before building the key means images in
    modes that cannot be written as JPEG (e.g. ``RGBA``) no longer fail while
    the key is computed.

    On a cache miss the features are computed with
    ``visual_encoder.encode_images``, normalised with the model's
    ``blocks.0.ln0`` weight and bias, and stored in ``image_cache``. When the
    cache already holds ``IMAGE_CACHE_MAX_ENTRIES`` items the oldest inserted
    entry is dropped first.

    Args:
        image: Image object supporting ``convert('RGB')`` (a PIL image).

    Returns:
        The feature tensor for *image*.
    """
    rgb_image = image.convert('RGB')
    cache_key = pil_image_to_base64(rgb_image)

    image_features = image_cache.get(cache_key)
    if image_features is None:
        pixel_values = image_processor(images=rgb_image, return_tensors='pt')['pixel_values']
        image_features = visual_encoder.encode_images(pixel_values.unsqueeze(0)).squeeze(0)  # [L, D]
        # apply layer norm to image feature, very important
        image_features = F.layer_norm(image_features,
                                      (image_features.shape[-1],),
                                      weight=model.w['blocks.0.ln0.weight'],
                                      bias=model.w['blocks.0.ln0.bias'])
        if len(image_cache) >= IMAGE_CACHE_MAX_ENTRIES:
            # dicts preserve insertion order, so the first key is the oldest.
            image_cache.pop(next(iter(image_cache)), None)
        image_cache[cache_key] = image_features
    return image_features
|
| 131 |
+
|
| 132 |
def chatbot(image, question):
    """Stream answers to *question* about *image*.

    When *image* is ``None`` a single upload reminder is yielded and the
    generator ends. Otherwise the features returned by
    ``get_image_features`` and the prompt built by ``generate_prompt`` are
    passed to ``generate``, and each value it produces is yielded in order.
    """
    if image is not None:
        features = get_image_features(image)
        prompt = generate_prompt(question)
        for partial in generate(prompt, features):
            yield partial
    else:
        yield "Please upload an image."
|