css change
Browse files
- app.py +8 -5
- model.py +0 -1
- get_output.py → output.py +3 -2
app.py
CHANGED
|
@@ -13,7 +13,7 @@ from timm.data import resolve_data_config
|
|
| 13 |
from timm.data.transforms_factory import create_transform
|
| 14 |
|
| 15 |
from model import Model
|
| 16 |
-
from get_output import visualize_output
|
| 17 |
|
| 18 |
|
| 19 |
# Use GPU if available
|
|
@@ -30,13 +30,13 @@ model.eval()
|
|
| 30 |
state = torch.load('saved_model', map_location=torch.device('cpu'))
|
| 31 |
model.load_state_dict(state['val_model_dict'])
|
| 32 |
|
| 33 |
-
#
|
| 34 |
config = resolve_data_config({}, model=vit)
|
| 35 |
config['no_aug'] = True
|
| 36 |
config['interpolation'] = 'bilinear'
|
| 37 |
transform = create_transform(**config)
|
| 38 |
|
| 39 |
-
|
| 40 |
def query_image(input_img, query, binarize, eval_threshold):
|
| 41 |
|
| 42 |
PIL_image = Image.fromarray(input_img, "RGB")
|
|
@@ -49,10 +49,10 @@ def query_image(input_img, query, binarize, eval_threshold):
|
|
| 49 |
img = visualize_output(img, output, binarize, eval_threshold)
|
| 50 |
return img
|
| 51 |
|
| 52 |
-
|
| 53 |
description = """
|
| 54 |
Gradio demo for an object detection architecture,
|
| 55 |
-
introduced in <a href="https://
|
| 56 |
\n\nLorem ipsum ....
|
| 57 |
*"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
|
| 58 |
"""
|
|
@@ -67,6 +67,9 @@ demo = gr.Interface(
|
|
| 67 |
],
|
| 68 |
allow_flagging = "never",
|
| 69 |
cache_examples=False,
|
|
|
|
|
|
|
|
|
|
| 70 |
)
|
| 71 |
demo.launch(debug=True)
|
| 72 |
|
|
|
|
| 13 |
from timm.data.transforms_factory import create_transform
|
| 14 |
|
| 15 |
from model import Model
|
| 16 |
+
from output import visualize_output
|
| 17 |
|
| 18 |
|
| 19 |
# Use GPU if available
|
|
|
|
| 30 |
state = torch.load('saved_model', map_location=torch.device('cpu'))
|
| 31 |
model.load_state_dict(state['val_model_dict'])
|
| 32 |
|
| 33 |
+
# Create transform for input image
|
| 34 |
config = resolve_data_config({}, model=vit)
|
| 35 |
config['no_aug'] = True
|
| 36 |
config['interpolation'] = 'bilinear'
|
| 37 |
transform = create_transform(**config)
|
| 38 |
|
| 39 |
+
# Inference function
|
| 40 |
def query_image(input_img, query, binarize, eval_threshold):
|
| 41 |
|
| 42 |
PIL_image = Image.fromarray(input_img, "RGB")
|
|
|
|
| 49 |
img = visualize_output(img, output, binarize, eval_threshold)
|
| 50 |
return img
|
| 51 |
|
| 52 |
+
# Gradio interface
|
| 53 |
description = """
|
| 54 |
Gradio demo for an object detection architecture,
|
| 55 |
+
introduced in <a href="https://www.google.com/">my bachelor thesis (link will be added)</a>.
|
| 56 |
\n\nLorem ipsum ....
|
| 57 |
*"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
|
| 58 |
"""
|
|
|
|
| 67 |
],
|
| 68 |
allow_flagging = "never",
|
| 69 |
cache_examples=False,
|
| 70 |
+
css = """
|
| 71 |
+
body {background-color : grey}
|
| 72 |
+
""",
|
| 73 |
)
|
| 74 |
demo.launch(debug=True)
|
| 75 |
|
model.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
|
| 4 |
-
|
| 5 |
class Model(nn.Module):
|
| 6 |
def __init__(self, vit, roberta, tokenizer, device):
|
| 7 |
super().__init__()
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
|
|
|
|
| 4 |
class Model(nn.Module):
|
| 5 |
def __init__(self, vit, roberta, tokenizer, device):
|
| 6 |
super().__init__()
|
get_output.py → output.py
RENAMED
|
@@ -25,6 +25,7 @@ def enlarge_array(output):
|
|
| 25 |
|
| 26 |
return output
|
| 27 |
|
|
|
|
| 28 |
def visualize_output(image, output, binarize, threshold):
|
| 29 |
|
| 30 |
image, output = preprocess(image, output, binarize, threshold)
|
|
@@ -35,9 +36,9 @@ def visualize_output(image, output, binarize, threshold):
|
|
| 35 |
plt.axis('off')
|
| 36 |
plt.imshow(image)
|
| 37 |
if binarize:
|
| 38 |
-
plt.imshow(output_mask, alpha=.
|
| 39 |
else:
|
| 40 |
-
plt.imshow(output_mask, alpha=.
|
| 41 |
fig.tight_layout(pad=0)
|
| 42 |
fig.canvas.draw()
|
| 43 |
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
|
|
|
| 25 |
|
| 26 |
return output
|
| 27 |
|
| 28 |
+
|
| 29 |
def visualize_output(image, output, binarize, threshold):
|
| 30 |
|
| 31 |
image, output = preprocess(image, output, binarize, threshold)
|
|
|
|
| 36 |
plt.axis('off')
|
| 37 |
plt.imshow(image)
|
| 38 |
if binarize:
|
| 39 |
+
plt.imshow(output_mask, alpha=.5)
|
| 40 |
else:
|
| 41 |
+
plt.imshow(output_mask, alpha=.6)
|
| 42 |
fig.tight_layout(pad=0)
|
| 43 |
fig.canvas.draw()
|
| 44 |
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|