loads n closet via api endpoint
Browse files
app.py
CHANGED
|
@@ -79,30 +79,26 @@ def main(
|
|
| 79 |
seed=None
|
| 80 |
):
|
| 81 |
|
| 82 |
-
if seed == None:
|
| 83 |
-
seed = np.random.randint(2147483647)
|
| 84 |
-
# if device contains cuda
|
| 85 |
-
if device.type == 'cuda':
|
| 86 |
-
generator = torch.Generator(device=device).manual_seed(int(seed))
|
| 87 |
-
else:
|
| 88 |
-
generator = torch.Generator().manual_seed(int(seed)) # use cpu as does not work on mps
|
| 89 |
-
|
| 90 |
embeddings = base64_to_embedding(embeddings)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
# inp.tile(n_samples, 1, 1, 1),
|
| 95 |
-
# [embeddings * n_samples],
|
| 96 |
-
embeddings,
|
| 97 |
-
guidance_scale=scale,
|
| 98 |
-
num_inference_steps=steps,
|
| 99 |
-
generator=generator,
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
images = []
|
| 103 |
-
for
|
| 104 |
-
images
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
return images
|
| 107 |
|
| 108 |
def on_image_load_update_embeddings(image_data):
|
|
@@ -146,10 +142,10 @@ def update_average_embeddings(embedding_base64s_state, embedding_powers):
|
|
| 146 |
return gr.Text.update('')
|
| 147 |
|
| 148 |
# TODO toggle this to support average or sum
|
| 149 |
-
final_embedding = final_embedding / num_embeddings
|
| 150 |
|
| 151 |
# normalize embeddings in numpy
|
| 152 |
-
final_embedding /= np.linalg.norm(final_embedding)
|
| 153 |
|
| 154 |
embeddings_b64 = embedding_to_base64(final_embedding)
|
| 155 |
return embeddings_b64
|
|
@@ -368,12 +364,12 @@ Try uploading a few images and/or add some text prompts and click generate image
|
|
| 368 |
with gr.Accordion(f"Avergage embeddings in base 64", open=False):
|
| 369 |
average_embedding_base64 = gr.Textbox(show_label=False)
|
| 370 |
with gr.Row():
|
| 371 |
-
submit = gr.Button("
|
| 372 |
with gr.Row():
|
| 373 |
with gr.Column(scale=1, min_width=200):
|
| 374 |
scale = gr.Slider(0, 25, value=3, step=1, label="Guidance scale")
|
| 375 |
with gr.Column(scale=1, min_width=200):
|
| 376 |
-
n_samples = gr.Slider(1,
|
| 377 |
with gr.Column(scale=1, min_width=200):
|
| 378 |
steps = gr.Slider(5, 50, value=25, step=5, label="Steps")
|
| 379 |
with gr.Column(scale=1, min_width=200):
|
|
|
|
| 79 |
seed=None
|
| 80 |
):
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
embeddings = base64_to_embedding(embeddings)
|
| 83 |
+
# convert to python array
|
| 84 |
+
embeddings = embeddings.tolist()
|
| 85 |
+
results = clip_retrieval_client.query(embedding_input=embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
images = []
|
| 87 |
+
for result in results:
|
| 88 |
+
if len(images) >= n_samples:
|
| 89 |
+
break
|
| 90 |
+
# dowload image
|
| 91 |
+
import requests
|
| 92 |
+
from io import BytesIO
|
| 93 |
+
response = requests.get(result["url"])
|
| 94 |
+
if not response.ok:
|
| 95 |
+
continue
|
| 96 |
+
try:
|
| 97 |
+
bytes = BytesIO(response.content)
|
| 98 |
+
image = Image.open(bytes)
|
| 99 |
+
images.append(image)
|
| 100 |
+
except Exception as e:
|
| 101 |
+
print(e)
|
| 102 |
return images
|
| 103 |
|
| 104 |
def on_image_load_update_embeddings(image_data):
|
|
|
|
| 142 |
return gr.Text.update('')
|
| 143 |
|
| 144 |
# TODO toggle this to support average or sum
|
| 145 |
+
# final_embedding = final_embedding / num_embeddings
|
| 146 |
|
| 147 |
# normalize embeddings in numpy
|
| 148 |
+
# final_embedding /= np.linalg.norm(final_embedding)
|
| 149 |
|
| 150 |
embeddings_b64 = embedding_to_base64(final_embedding)
|
| 151 |
return embeddings_b64
|
|
|
|
| 364 |
with gr.Accordion(f"Avergage embeddings in base 64", open=False):
|
| 365 |
average_embedding_base64 = gr.Textbox(show_label=False)
|
| 366 |
with gr.Row():
|
| 367 |
+
submit = gr.Button("Search embedding space")
|
| 368 |
with gr.Row():
|
| 369 |
with gr.Column(scale=1, min_width=200):
|
| 370 |
scale = gr.Slider(0, 25, value=3, step=1, label="Guidance scale")
|
| 371 |
with gr.Column(scale=1, min_width=200):
|
| 372 |
+
n_samples = gr.Slider(1, 16, value=4, step=1, label="Number images")
|
| 373 |
with gr.Column(scale=1, min_width=200):
|
| 374 |
steps = gr.Slider(5, 50, value=25, step=5, label="Steps")
|
| 375 |
with gr.Column(scale=1, min_width=200):
|