Spaces:
Runtime error
Runtime error
Commit
·
f21d6cb
1
Parent(s):
818b0c6
Update app.py
Browse files
app.py
CHANGED
|
@@ -221,19 +221,22 @@ def pil_image_to_base64(pil_image):
|
|
| 221 |
return base64_image
|
| 222 |
|
| 223 |
image_cache = {}
|
|
|
|
|
|
|
| 224 |
def compute_image_state(image):
|
| 225 |
base64_image = pil_image_to_base64(image)
|
| 226 |
if base64_image in image_cache:
|
| 227 |
image_state = image_cache[base64_image]
|
| 228 |
else:
|
| 229 |
-
image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
|
|
|
|
| 230 |
image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
|
| 231 |
# apply layer norm to image feature, very important
|
| 232 |
image_features = F.layer_norm(image_features,
|
| 233 |
(image_features.shape[-1],),
|
| 234 |
-
weight=
|
| 235 |
-
bias=
|
| 236 |
-
_, image_state =
|
| 237 |
image_cache[base64_image] = image_state
|
| 238 |
return image_state
|
| 239 |
|
|
|
|
| 221 |
return base64_image
|
| 222 |
|
| 223 |
image_cache = {}
|
| 224 |
+
ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
|
| 225 |
+
ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
|
| 226 |
def compute_image_state(image):
|
| 227 |
base64_image = pil_image_to_base64(image)
|
| 228 |
if base64_image in image_cache:
|
| 229 |
image_state = image_cache[base64_image]
|
| 230 |
else:
|
| 231 |
+
image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
|
| 232 |
+
image = image.to(device)
|
| 233 |
image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
|
| 234 |
# apply layer norm to image feature, very important
|
| 235 |
image_features = F.layer_norm(image_features,
|
| 236 |
(image_features.shape[-1],),
|
| 237 |
+
weight=ln0_weight,
|
| 238 |
+
bias=ln0_bias)
|
| 239 |
+
_, image_state = model.forward(embs=image_features, state=None)
|
| 240 |
image_cache[base64_image] = image_state
|
| 241 |
return image_state
|
| 242 |
|