Spaces:
Runtime error
Runtime error
update code
Browse files
app.py
CHANGED
|
@@ -88,8 +88,8 @@ def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float =
|
|
| 88 |
|
| 89 |
def inference(ic_image, ic_mask, image1, image2):
|
| 90 |
# in context image and mask
|
| 91 |
-
ic_image =
|
| 92 |
-
|
| 93 |
|
| 94 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 95 |
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
|
@@ -114,7 +114,7 @@ def inference(ic_image, ic_mask, image1, image2):
|
|
| 114 |
|
| 115 |
for test_image in [image1, image2]:
|
| 116 |
print("======> Testing Image" )
|
| 117 |
-
test_image =
|
| 118 |
|
| 119 |
# Image feature encoding
|
| 120 |
predictor.set_image(test_image)
|
|
@@ -188,8 +188,8 @@ def inference_scribble(image, image1, image2):
|
|
| 188 |
# in context image and mask
|
| 189 |
ic_image = image["image"]
|
| 190 |
ic_mask = image["mask"]
|
| 191 |
-
ic_image =
|
| 192 |
-
|
| 193 |
|
| 194 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 195 |
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
|
@@ -214,7 +214,7 @@ def inference_scribble(image, image1, image2):
|
|
| 214 |
|
| 215 |
for test_image in [image1, image2]:
|
| 216 |
print("======> Testing Image" )
|
| 217 |
-
test_image =
|
| 218 |
|
| 219 |
# Image feature encoding
|
| 220 |
predictor.set_image(test_image)
|
|
@@ -286,8 +286,8 @@ def inference_scribble(image, image1, image2):
|
|
| 286 |
|
| 287 |
def inference_finetune(ic_image, ic_mask, image1, image2):
|
| 288 |
# in context image and mask
|
| 289 |
-
ic_image =
|
| 290 |
-
|
| 291 |
|
| 292 |
gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
|
| 293 |
gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
|
|
@@ -377,7 +377,7 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
|
|
| 377 |
output_image = []
|
| 378 |
|
| 379 |
for test_image in [image1, image2]:
|
| 380 |
-
test_image =
|
| 381 |
|
| 382 |
# Image feature encoding
|
| 383 |
predictor.set_image(test_image)
|
|
@@ -466,14 +466,14 @@ description = """
|
|
| 466 |
main = gr.Interface(
|
| 467 |
fn=inference,
|
| 468 |
inputs=[
|
| 469 |
-
gr.Image(label="in context image",),
|
| 470 |
-
gr.Image(label="in context mask"),
|
| 471 |
-
gr.Image(label="test image1"),
|
| 472 |
-
gr.Image(label="test image2"),
|
| 473 |
],
|
| 474 |
outputs=[
|
| 475 |
-
gr.Image(label="output image1").style(height=256, width=256),
|
| 476 |
-
gr.Image(label="output image2").style(height=256, width=256),
|
| 477 |
],
|
| 478 |
allow_flagging="never",
|
| 479 |
cache_examples=False,
|
|
@@ -490,13 +490,13 @@ main = gr.Interface(
|
|
| 490 |
main_scribble = gr.Interface(
|
| 491 |
fn=inference_scribble,
|
| 492 |
inputs=[
|
| 493 |
-
gr.ImageMask(label="[Stroke] Draw on Image"),
|
| 494 |
-
gr.Image(label="test image1"),
|
| 495 |
-
gr.Image(label="test image2"),
|
| 496 |
],
|
| 497 |
outputs=[
|
| 498 |
-
gr.Image(label="output image1").style(height=256, width=256),
|
| 499 |
-
gr.Image(label="output image2").style(height=256, width=256),
|
| 500 |
],
|
| 501 |
allow_flagging="never",
|
| 502 |
cache_examples=False,
|
|
@@ -510,17 +510,18 @@ main_scribble = gr.Interface(
|
|
| 510 |
)
|
| 511 |
"""
|
| 512 |
|
|
|
|
| 513 |
main_finetune = gr.Interface(
|
| 514 |
fn=inference_finetune,
|
| 515 |
inputs=[
|
| 516 |
-
gr.Image(label="in context image",),
|
| 517 |
-
gr.Image(label="in context mask"),
|
| 518 |
-
gr.Image(label="test image1"),
|
| 519 |
-
gr.Image(label="test image2"),
|
| 520 |
],
|
| 521 |
outputs=[
|
| 522 |
-
gr.Image(label="output image1").style(height=256, width=256),
|
| 523 |
-
gr.Image(label="output image2").style(height=256, width=256),
|
| 524 |
],
|
| 525 |
allow_flagging="never",
|
| 526 |
cache_examples=False,
|
|
|
|
| 88 |
|
| 89 |
def inference(ic_image, ic_mask, image1, image2):
|
| 90 |
# in context image and mask
|
| 91 |
+
ic_image = np.array(ic_image.convert("RGB"))
|
| 92 |
+
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 93 |
|
| 94 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 95 |
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
|
|
|
| 114 |
|
| 115 |
for test_image in [image1, image2]:
|
| 116 |
print("======> Testing Image" )
|
| 117 |
+
test_image = np.array(test_image.convert("RGB"))
|
| 118 |
|
| 119 |
# Image feature encoding
|
| 120 |
predictor.set_image(test_image)
|
|
|
|
| 188 |
# in context image and mask
|
| 189 |
ic_image = image["image"]
|
| 190 |
ic_mask = image["mask"]
|
| 191 |
+
ic_image = np.array(ic_image.convert("RGB"))
|
| 192 |
+
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 193 |
|
| 194 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 195 |
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
|
|
|
| 214 |
|
| 215 |
for test_image in [image1, image2]:
|
| 216 |
print("======> Testing Image" )
|
| 217 |
+
test_image = np.array(test_image.convert("RGB"))
|
| 218 |
|
| 219 |
# Image feature encoding
|
| 220 |
predictor.set_image(test_image)
|
|
|
|
| 286 |
|
| 287 |
def inference_finetune(ic_image, ic_mask, image1, image2):
|
| 288 |
# in context image and mask
|
| 289 |
+
ic_image = np.array(ic_image.convert("RGB"))
|
| 290 |
+
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 291 |
|
| 292 |
gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
|
| 293 |
gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
|
|
|
|
| 377 |
output_image = []
|
| 378 |
|
| 379 |
for test_image in [image1, image2]:
|
| 380 |
+
test_image = np.array(test_image.convert("RGB"))
|
| 381 |
|
| 382 |
# Image feature encoding
|
| 383 |
predictor.set_image(test_image)
|
|
|
|
| 466 |
main = gr.Interface(
|
| 467 |
fn=inference,
|
| 468 |
inputs=[
|
| 469 |
+
gr.Image(label="in context image", type='pil'),
|
| 470 |
+
gr.Image(label="in context mask", type='pil'),
|
| 471 |
+
gr.Image(label="test image1", type='pil'),
|
| 472 |
+
gr.Image(label="test image2", type='pil'),
|
| 473 |
],
|
| 474 |
outputs=[
|
| 475 |
+
gr.Image(label="output image1", type='pil').style(height=256, width=256),
|
| 476 |
+
gr.Image(label="output image2", type='pil').style(height=256, width=256),
|
| 477 |
],
|
| 478 |
allow_flagging="never",
|
| 479 |
cache_examples=False,
|
|
|
|
| 490 |
main_scribble = gr.Interface(
|
| 491 |
fn=inference_scribble,
|
| 492 |
inputs=[
|
| 493 |
+
gr.ImageMask(label="[Stroke] Draw on Image", brush_radius=4, type='pil'),
|
| 494 |
+
gr.Image(label="test image1", type='pil'),
|
| 495 |
+
gr.Image(label="test image2", type='pil'),
|
| 496 |
],
|
| 497 |
outputs=[
|
| 498 |
+
gr.Image(label="output image1", type='pil').style(height=256, width=256),
|
| 499 |
+
gr.Image(label="output image2", type='pil').style(height=256, width=256),
|
| 500 |
],
|
| 501 |
allow_flagging="never",
|
| 502 |
cache_examples=False,
|
|
|
|
| 510 |
)
|
| 511 |
"""
|
| 512 |
|
| 513 |
+
|
| 514 |
main_finetune = gr.Interface(
|
| 515 |
fn=inference_finetune,
|
| 516 |
inputs=[
|
| 517 |
+
gr.Image(label="in context image", type='pil'),
|
| 518 |
+
gr.Image(label="in context mask", type='pil'),
|
| 519 |
+
gr.Image(label="test image1", type='pil'),
|
| 520 |
+
gr.Image(label="test image2", type='pil'),
|
| 521 |
],
|
| 522 |
outputs=[
|
| 523 |
+
gr.Image(label="output image1", type='pil').style(height=256, width=256),
|
| 524 |
+
gr.Image(label="output image2", type='pil').style(height=256, width=256),
|
| 525 |
],
|
| 526 |
allow_flagging="never",
|
| 527 |
cache_examples=False,
|