Spaces:
Running
Running
| import os | |
| import re | |
| import shutil | |
| import time | |
| from types import SimpleNamespace | |
| from typing import Any | |
| import gradio as gr | |
| import numpy as np | |
| from detectron2 import engine | |
| from PIL import Image | |
| from inference import main, setup_cfg | |
# Internal settings: fixed in code, not exposed through the UI.
NUM_PROCESSES = 1  # process count passed to detectron2's engine.launch
CROP = False  # forwarded unchanged to inference.main
SCORE_THRESHOLD = 0.8  # forwarded unchanged to inference.main
MAX_PARTS = 5  # number of "Part N" image slots created in the UI

# Namespace handed to inference.setup_cfg; ARGS.output is also the directory
# that predict() wipes before each run and scans for result frames afterwards.
ARGS = SimpleNamespace(
    config_file="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
    model=".data/models/motion_state_pred_opdformerp_rgb.pth",
    input_format="RGB",
    output=".output",
    cpu=True,  # presumably selects CPU inference -- confirm in setup_cfg
)

NUM_SAMPLES = 10  # initial value of the "Number of samples" input

# Frames of the most recent prediction, one inner list per part. Written by
# predict() and read by the handlers built in get_trigger(). This is process-wide
# state, so it is shared by every browser session talking to this process.
outputs: list[list[Any]] = []
def predict(rgb_image: str, depth_image: str, intrinsics: np.ndarray, num_samples: int) -> list[Any]:
    """Run the model on one RGB-D pair and show the first frame of each part.

    The function empties ``ARGS.output``, runs ``main`` through
    ``engine.launch`` and then collects the ``.png`` files found in the
    sub-directories of ``ARGS.output``. Every sub-directory that contains at
    least one ``.png`` is treated as one part (presumably ``main`` writes one
    directory of frames per detected part -- confirm in ``inference.main``).
    The loaded frames are published in the module-level ``outputs`` list so
    that the handlers created by ``get_trigger`` can animate them.

    Args:
        rgb_image: Value of the RGB image input, forwarded to ``main``.
        depth_image: Value of the depth image input, forwarded to ``main``.
        intrinsics: Value of the intrinsics table, forwarded to ``main``.
        num_samples: Value of the "Number of samples" input, forwarded to
            ``main``.

    Returns:
        Exactly ``MAX_PARTS`` ``gr.update`` objects: one visible update showing
        the first frame for each loaded part, followed by hidden updates for
        the unused image slots.
    """
    global outputs

    def find_images(path: str) -> dict[str, list[str]]:
        """Map each sub-directory name of *path* to its sorted .png file paths."""
        images: dict[str, list[str]] = {}
        # Sorted so the part -> UI slot mapping is stable across runs and platforms
        # (os.listdir returns entries in arbitrary order).
        for entry in sorted(os.listdir(path)):
            sub_path = os.path.join(path, entry)
            if os.path.isdir(sub_path):
                images[entry] = [
                    os.path.join(sub_path, image_file)
                    for image_file in sorted(os.listdir(sub_path))
                    if image_file.endswith(".png")
                ]
        return images

    def load_image(path: str) -> Any:
        """Read an image fully into memory and release its file handle."""
        # Image.open is lazy and keeps the file open; copying inside the context
        # manager forces the read so the file can be deleted by the next run.
        with Image.open(path) as handle:
            return handle.copy()

    # clear old predictions
    os.makedirs(ARGS.output, exist_ok=True)
    for path in os.listdir(ARGS.output):
        full_path = os.path.join(ARGS.output, path)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path)
        else:
            os.remove(full_path)

    cfg = setup_cfg(ARGS)

    engine.launch(
        main,
        NUM_PROCESSES,
        args=(
            cfg,
            rgb_image,
            depth_image,
            intrinsics,
            num_samples,
            CROP,
            SCORE_THRESHOLD,
        ),
    )

    # process output
    # TODO: may want to select these in decreasing order of score
    # Directories without frames are skipped: they have nothing to display and
    # would otherwise make `frames[0]` below raise IndexError.
    frame_paths = [paths for paths in find_images(ARGS.output).values() if paths]
    loaded = [[load_image(path) for path in paths] for paths in frame_paths[:MAX_PARTS]]

    # Publish the finished result in one assignment so readers of `outputs`
    # never observe a half-filled list.
    outputs = loaded

    return [
        *(gr.update(value=frames[0], visible=True) for frames in loaded),
        *(gr.update(visible=False) for _ in range(MAX_PARTS - len(loaded))),
    ]
def get_trigger(idx: int, fps: int = 40, oscillate: bool = True):
    """Build an event handler that plays back the frames of part ``idx``.

    Args:
        idx: Index into the module-level ``outputs`` list.
        fps: Playback rate; the handler sleeps ``1 / fps`` seconds before each
            frame it yields.
        oscillate: When true, the frames are yielded forwards and then again in
            reverse order.

    Returns:
        A generator function accepting (and ignoring) any arguments. Iterating
        it yields the frames one by one; it raises ``ValueError`` on the first
        iteration if ``outputs`` has no entry at ``idx``.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")
    delay = 1.0 / fps

    def iter_images(*args, **kwargs):
        # Snapshot the global once: a prediction finishing while this generator
        # is still sleeping between frames rebinds `outputs`, and the reverse
        # pass must replay the same frames as the forward pass.
        current = outputs
        if idx >= len(current):
            raise ValueError("Could not find any images to load into this module.")
        frames = current[idx]

        for im in frames:
            time.sleep(delay)
            yield im
        if oscillate:
            for im in reversed(frames):
                time.sleep(delay)
                yield im

    return iter_images
# UI definition: inputs, examples, the run button and MAX_PARTS result slots.
# NOTE(review): the nesting below (only the two image inputs inside gr.Row, and
# queue()/launch() outside the Blocks context) is reconstructed -- confirm it
# against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # OPDMulti Demo
        Upload an image to see its range of motion.
        """
    )

    # Example inputs are registered further down via gr.Examples.

    with gr.Row():
        # type="filepath": predict() receives paths on disk, not decoded arrays.
        rgb_image = gr.Image(
            image_mode="RGB", source="upload", type="filepath", label="RGB Image", show_label=True, interactive=True
        )
        depth_image = gr.Image(
            image_mode="I;16", source="upload", type="filepath", label="Depth Image", show_label=True, interactive=True
        )

    # Editable 3x3 table, delivered to predict() as a numpy array (type="numpy").
    # The default presumably follows the pinhole layout [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
    # for the bundled example images -- TODO confirm.
    intrinsics = gr.Dataframe(
        value=[
            [
                214.85935872395834,
                0.0,
                125.90160319010417,
            ],
            [
                0.0,
                214.85935872395834,
                95.13726399739583,
            ],
            [
                0.0,
                0.0,
                1.0,
            ],
        ],
        row_count=(3, "fixed"),
        col_count=(3, "fixed"),
        datatype="number",
        type="numpy",
        label="Intrinsics matrix",
        show_label=True,
        interactive=True,
    )
    # precision=0 makes this an integer input; the value is forwarded to predict() as num_samples.
    num_samples = gr.Number(
        value=NUM_SAMPLES,
        label="Number of samples",
        show_label=True,
        interactive=True,
        precision=0,
        minimum=3,
        maximum=20,
    )

    # Each example is an [RGB path, depth path] pair that fills the two image inputs.
    examples = gr.Examples(
        examples=[
            ["examples/59-4860.png", "examples/59-4860_d.png"],
            ["examples/174-8460.png", "examples/174-8460_d.png"],
            ["examples/187-0.png", "examples/187-0_d.png"],
            ["examples/187-23040.png", "examples/187-23040_d.png"],
        ],
        inputs=[rgb_image, depth_image],
        api_name=False,
        examples_per_page=2,
    )

    submit_btn = gr.Button("Run model")

    # TODO: do we want to set a maximum limit on how many parts we render? We could also show the number of components
    # identified.
    # One hidden image slot per possible part; predict() returns one update per slot.
    images = [gr.Image(type="pil", label=f"Part {idx + 1}", visible=False) for idx in range(MAX_PARTS)]
    for idx, image_comp in enumerate(images):
        # Selecting a part image streams that part's frames back into the same component.
        image_comp.select(get_trigger(idx), inputs=[], outputs=image_comp, api_name=False)

    # Running the model updates all MAX_PARTS image slots at once.
    submit_btn.click(
        fn=predict, inputs=[rgb_image, depth_image, intrinsics, num_samples], outputs=images, api_name=False
    )

# The select handlers are generator functions; presumably this is why queuing is enabled -- TODO confirm.
demo.queue(api_open=False)
demo.launch()