Commit b0d85ba · Parent(s): 21aea4b

Update app.py

app.py CHANGED
@@ -35,11 +35,11 @@ image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
 ##########################################################################
 def generate_prompt(instruction):
     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
-    return f"{instruction}\n\nAssistant:"
+    return f"\n{instruction}\n\nAssistant:"
 
 def generate(
     ctx,
-    image_features,
+    image_state,
     token_count=128,
     temperature=0.2,
     top_p=0.3,
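Two changes land in this hunk: generate_prompt gains a leading newline, and generate now takes a precomputed image_state in place of the raw image features (more on that in the next hunk). For the template, a quick before/after check — the question string is a hypothetical example, and the strip/replace preprocessing is omitted for brevity:

def old_prompt(instruction):
    return f"{instruction}\n\nAssistant:"

def new_prompt(instruction):
    return f"\n{instruction}\n\nAssistant:"

q = "What is in the image?"  # hypothetical example question
assert old_prompt(q) == "What is in the image?\n\nAssistant:"
assert new_prompt(q) == "\nWhat is in the image?\n\nAssistant:"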
@@ -58,10 +58,8 @@ def generate(
     occurrence = {}
     for i in range(int(token_count)):
         if i == 0:
-            input_ids = pipeline.encode(ctx)
-            text_embs = model.w['emb.weight'][input_ids]
-            input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
-            out, state = model.forward(embs=input_embs, state=None)
+            input_ids = pipeline.encode(ctx)[-ctx_limit:]
+            out, state = model.forward(tokens=input_ids, state=image_state)
         else:
             input_ids = [token]
             out, state = model.forward(tokens=input_ids, state=state)
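The old first step rebuilt the full input on every request: encode the prompt, look up text embeddings, concatenate the image features in front, and forward everything with a fresh state. The new first step forwards only the prompt tokens, starting from a state that already encodes the image; as a side effect, the -ctx_limit truncation now trims only prompt tokens rather than risking cutting into the image features. A minimal sketch of this prefill-then-decode pattern, assuming an RWKV-style forward() that accepts either embeddings or token ids plus a recurrent state — the helper names and greedy sampling are illustrative, not the app's API (the app samples with temperature/top_p):

def prefill_image_state(model, image_features):
    # One pass over the image embeddings; only the recurrent state is kept.
    _, state = model.forward(embs=image_features, state=None)
    return state

def decode(model, pipeline, prompt, image_state, steps=32):
    # Text tokens are fed on top of the prefilled state, so the image
    # is never re-encoded while decoding.
    out, state = model.forward(tokens=pipeline.encode(prompt), state=image_state)
    for _ in range(steps):
        token = int(out.argmax())  # greedy pick, for brevity
        yield token
        out, state = model.forward(tokens=[token], state=state)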
@@ -113,11 +111,10 @@ def pil_image_to_base64(pil_image):
     return base64_image
 
 image_cache = {}
-def get_image_features(image):
+def compute_image_state(image):
     base64_image = pil_image_to_base64(image)
     if base64_image in image_cache:
-        image_features = image_cache[base64_image]
-        print(f"use cache {base64_image[:10]}")
+        image_state = image_cache[base64_image]
     else:
         image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
         image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
@@ -126,16 +123,17 @@ def get_image_features(image):
                                       (image_features.shape[-1],),
                                       weight=model.w['blocks.0.ln0.weight'],
                                       bias=model.w['blocks.0.ln0.bias'])
-        image_cache[base64_image] = image_features
-    return image_features
+        _, image_state = model.forward(embs=image_features, state=None)
+        image_cache[base64_image] = image_state
+    return image_state
 
 def chatbot(image, question):
     if image is None:
         yield "Please upload an image."
         return
-    image_features = get_image_features(image)
+    image_state = compute_image_state(image)
     input_text = generate_prompt(question)
-    for output in generate(input_text, image_features):
+    for output in generate(input_text, image_state):
         yield output
 
 with gr.Blocks(title=title) as demo:
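The net effect: the expensive work — the CLIP encoding plus the prefill forward pass — now happens once per image inside compute_image_state, and the result is memoized under the image's base64 string, so follow-up questions about the same image reuse the cached state. The caching pattern in isolation; encode and prefill below are stand-ins for the app's visual_encoder and model.forward, and image_key mirrors what pil_image_to_base64 (not shown in this diff) appears to do:

import base64
import io

_state_cache = {}  # base64(image bytes) -> prefilled model state

def image_key(pil_image):
    # Serialize the image deterministically and key on the bytes.
    buf = io.BytesIO()
    pil_image.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()

def cached_image_state(pil_image, encode, prefill):
    key = image_key(pil_image)
    if key not in _state_cache:
        # Cache miss: run the vision encoder and prefill the state once.
        _state_cache[key] = prefill(encode(pil_image))
    return _state_cache[key]

The manual key sidesteps the fact that PIL images are not hashable, which rules out functools.lru_cache directly; note also that the dict is unbounded, so a long-running Space may want to cap its size.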