Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| os.environ["KERAS_BACKEND"] = "jax" | |
| import gradio as gr | |
| import jax | |
| import numpy as np | |
| import spaces | |
| from PIL import Image | |
| from zea import init_device | |
| from main import Config, init, run | |
| from utils import load_image | |
CONFIG_PATH = "configs/semantic_dps.yaml"  # model / guidance config loaded by initialize_model()
SLIDER_CONFIG_PATH = "configs/slider_params.yaml"  # UI slider defaults and ranges
ASSETS_DIR = "assets"  # directory scanned for example .png inputs
DEVICE = None  # set lazily via init_device() when the model is first loaded

# Inline CSS shared by the status banners. Callers wrap it as
# 'background:<bg>;' + STYLE + 'color:<fg>;', so each string must end with ';'.
# The loading banner uses more top padding and a looser line height.
# (Each previously declared `align-items:center;` twice; the duplicate was redundant.)
STATUS_STYLE_LOAD = (
    "display:flex;align-items:center;justify-content:center;"
    "padding:40px 10px 18px 10px;border-radius:8px;font-weight:bold;"
    "font-size:1.15em;line-height:1.5;"
)
STATUS_STYLE = (
    "display:flex;align-items:center;justify-content:center;"
    "padding:18px 18px 18px 10px;border-radius:8px;font-weight:bold;"
    "font-size:1.15em;line-height:1.1;"
)
# Markdown header shown at the top of the demo. Slider names quoted in the tip
# must match the labels of the corresponding gr.Slider components in the UI.
description = """
# Cardiac Ultrasound Dehazing with Semantic Diffusion

Select an example image below to see the dehazing algorithm in action. The algorithm was tuned for the DehazingEcho2025 challenge dataset, so be wary of using it on other datasets.

Tip: Adjust "Omega Ventricle" and "Eta (haze prior)" to control the dehazing effect.
"""
# Lazily populated model state. Nothing is loaded at import time; the demo's
# load event fills these in once the UI has been rendered.
config = None
diffusion_model = None
model_loaded = False  # set to True by initialize_model() after its warm-up run
def initialize_model():
    """Load the config and diffusion model once, warm them up, and mark them ready.

    Reads ``CONFIG_PATH``, builds the model with ``init`` and runs a single
    1-step dummy inference. This pays first-call costs (weight init, JIT, etc.)
    here rather than on the user's first request.

    The module-level ``config`` / ``diffusion_model`` / ``model_loaded`` globals
    are updated only after the warm-up succeeds. A failed attempt therefore
    leaves no half-initialised state behind, and calling this function again
    retries the full load.

    Returns:
        tuple: ``(config, diffusion_model)``.

    Raises:
        Whatever ``Config.from_yaml``, ``init`` or ``run`` raise. The globals are
        left untouched in that case.
    """
    global config, diffusion_model, model_loaded
    if model_loaded and config is not None and diffusion_model is not None:
        # Already loaded and warmed up: reuse the cached objects.
        return config, diffusion_model

    # Build into locals first. Publishing to the globals before the warm-up
    # would make a later call skip loading (the objects are no longer None)
    # while model_loaded stays False.
    new_config = Config.from_yaml(CONFIG_PATH)
    new_model = init(new_config)

    # Warm-up: run a dummy inference to initialize weights, JIT, etc.
    h, w = new_model.input_shape[:2]
    dummy_img = np.zeros((1, h, w), dtype=np.float32)  # (1, h, w) all-zero input
    params = new_config.params
    guidance = params["guidance_kwargs"]
    guidance_kwargs = {
        "omega": guidance["omega"],
        "omega_vent": guidance.get("omega_vent", 1.0),
        "omega_sept": guidance.get("omega_sept", 1.0),
        "eta": guidance.get("eta", 1.0),
        "smooth_l1_beta": guidance["smooth_l1_beta"],
    }
    seed = jax.random.PRNGKey(new_config.seed)
    # NOTE(review): the warm-up uses diffusion_steps=1 while real requests use
    # the slider value. If run() compiles per step count, the first real request
    # may still pay that cost — confirm.
    run(
        hazy_images=dummy_img,
        diffusion_model=new_model,
        seed=seed,
        guidance_kwargs=guidance_kwargs,
        mask_params=params["mask_params"],
        fixed_mask_params=params["fixed_mask_params"],
        skeleton_params=params["skeleton_params"],
        batch_size=1,
        diffusion_steps=1,
        verbose=False,
    )

    # Warm-up succeeded: publish the ready model.
    config, diffusion_model = new_config, new_model
    model_loaded = True
    return config, diffusion_model
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
    """Dehaze one image with the loaded diffusion model, streaming status updates.

    Generator used as the Run button's Gradio handler. Every ``yield`` is a
    ``(status_update, slider_value)`` pair for the ``[status, img2]`` outputs.

    Args:
        input_img: PIL image from the input ``gr.Image`` (``type="pil"``), or None.
        diffusion_steps: Number of diffusion steps chosen on the slider.
        omega, omega_vent, omega_sept, eta: Guidance weights forwarded to ``run``
            in ``guidance_kwargs``.

    Yields:
        ``(gr.update(value=<status html>), None)`` for progress and error states.
        On success, a final ``(gr.update(...), (input_img, dehazed_pil))`` where
        ``dehazed_pil`` is resized back to the input's size if it was resized.

    NOTE(review): ``spaces`` is imported by this file but no function is
    decorated with ``@spaces.GPU``. On ZeroGPU hardware this presumably means
    inference runs without a GPU allocation — confirm.
    """
    import html  # escape exception text before embedding it in status HTML

    def _status(background, color, message):
        # Banner markup shared by every status message of this handler.
        return gr.update(
            value=f'<div style="background:{background};{STATUS_STYLE}color:{color};">{message}</div>'
        )

    def _error(message, exc):
        # Exception text may contain '<', '>' or '&'; escape it so it displays verbatim.
        return _status("#f8d7da", "#721c24", f"❌ {message}: {html.escape(str(exc))}")

    if not model_loaded:
        yield _status("#ffeeba", "#856404", "⏳ Model is still loading. Please wait..."), None
        return
    if input_img is None:
        yield (
            _status(
                "#ffeeba",
                "#856404",
                "⚠️ No input image was provided. Please select or upload an image before running.",
            ),
            None,
        )
        return

    params = config.params

    def _prepare_image(image):
        """Return ``(array, resized, orig_shape)`` for a PIL image.

        ``array`` is the grayscale ("L") image as float32 with a leading batch
        axis, bilinearly resized to ``diffusion_model.input_shape[:2]`` when the
        sizes differ. ``orig_shape`` is the (height, width) before resizing.
        """
        resized = False
        if image.mode != "L":
            image = image.convert("L")
        orig_shape = image.size[::-1]  # PIL size is (width, height)
        h, w = diffusion_model.input_shape[:2]
        if image.size != (w, h):
            image = image.resize((w, h), Image.BILINEAR)
            resized = True
        return np.array(image).astype(np.float32)[None, ...], resized, orig_shape

    try:
        image, resized, orig_shape = _prepare_image(input_img)
    except Exception as e:
        yield _error("Error preparing input image", e), None
        return

    guidance_kwargs = {
        "omega": omega,
        "omega_vent": omega_vent,
        "omega_sept": omega_sept,
        "eta": eta,
        "smooth_l1_beta": params["guidance_kwargs"]["smooth_l1_beta"],
    }
    seed = jax.random.PRNGKey(config.seed)

    yield _status("#cce5ff", "#004085", "🌀 Running dehazing algorithm..."), None
    try:
        _, pred_tissue_images, *_ = run(
            hazy_images=image,
            diffusion_model=diffusion_model,
            seed=seed,
            guidance_kwargs=guidance_kwargs,
            mask_params=params["mask_params"],
            fixed_mask_params=params["fixed_mask_params"],
            skeleton_params=params["skeleton_params"],
            batch_size=1,
            # Slider values may arrive as floats; the step count is an integer.
            diffusion_steps=int(diffusion_steps),
            threshold_output_quantile=params.get("threshold_output_quantile", None),
            preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
            bottom_transition_width=params.get("bottom_transition_width", 10.0),
            verbose=False,
        )
    except Exception as e:
        yield _error("The algorithm failed to process the image", e), None
        return

    try:
        out_img = np.squeeze(pred_tissue_images[0])
        # Values outside [0, 255] are clipped before converting to 8-bit.
        out_img = np.clip(out_img, 0, 255).astype(np.uint8)
        out_pil = Image.fromarray(out_img)
        orig_h, orig_w = orig_shape
        if resized and out_pil.size != (orig_w, orig_h):
            out_pil = out_pil.resize((orig_w, orig_h), Image.BILINEAR)
    except Exception as e:
        yield _error("Error post-processing the output image", e), None
        return

    yield _status("#d4edda", "#155724", "✅ Done!"), (input_img, out_pil)
# UI slider settings read from SLIDER_CONFIG_PATH. Every entry is indexed with
# the keys "default", "min", "max" and "step".
slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)


def _slider_settings(name):
    """Return ``(default, min, max, step)`` for the slider entry ``name``."""
    entry = slider_params[name]
    return entry["default"], entry["min"], entry["max"], entry["step"]


(
    diffusion_steps_default,
    diffusion_steps_min,
    diffusion_steps_max,
    diffusion_steps_step,
) = _slider_settings("diffusion_steps")
omega_default, omega_min, omega_max, omega_step = _slider_settings("omega")
(
    omega_vent_default,
    omega_vent_min,
    omega_vent_max,
    omega_vent_step,
) = _slider_settings("omega_vent")
(
    omega_sept_default,
    omega_sept_min,
    omega_sept_max,
    omega_sept_step,
) = _slider_settings("omega_sept")
eta_default, eta_min, eta_max, eta_step = _slider_settings("eta")
# Example inputs: every .png in ASSETS_DIR (case-insensitive extension).
# os.listdir returns entries in arbitrary order, so sort them. This keeps the
# gallery order and the default input image (example_images[0]) stable
# across machines and filesystems.
example_image_paths = [
    os.path.join(ASSETS_DIR, name)
    for name in sorted(os.listdir(ASSETS_DIR))
    if name.lower().endswith(".png")
]
example_images = [load_image(path) for path in example_image_paths]
examples = [[img] for img in example_images]  # gr.Examples expects one row per example
# Gradio UI. The model is not loaded at import time. The Run button starts
# disabled and is enabled by the page-load event below once initialize_model()
# has succeeded.
with gr.Blocks() as demo:
    gr.Markdown(description)
    status = gr.Markdown(
        f'<div style="background:#ffeeba;{STATUS_STYLE_LOAD}color:#856404;">⏳ Loading model...</div>',
        visible=True,
    )
    with gr.Row():
        with gr.Column():
            img1 = gr.Image(
                label="Input Image",
                type="pil",
                webcam_options=False,
                value=example_images[0] if example_images else None,
            )
            gr.Examples(examples=examples, inputs=[img1])
        with gr.Column():
            img2 = gr.ImageSlider(label="Dehazed Image", type="pil")
    with gr.Row():
        diffusion_steps_slider = gr.Slider(
            minimum=diffusion_steps_min,
            maximum=diffusion_steps_max,
            step=diffusion_steps_step,
            value=diffusion_steps_default,
            label="Diffusion Steps",
        )
        omega_slider = gr.Slider(
            minimum=omega_min,
            maximum=omega_max,
            step=omega_step,
            value=omega_default,
            label="Omega (background)",
        )
        omega_vent_slider = gr.Slider(
            minimum=omega_vent_min,
            maximum=omega_vent_max,
            step=omega_vent_step,
            value=omega_vent_default,
            label="Omega Ventricle",
        )
        omega_sept_slider = gr.Slider(
            minimum=omega_sept_min,
            maximum=omega_sept_max,
            step=omega_sept_step,
            value=omega_sept_default,
            label="Omega Septum",
        )
        eta_slider = gr.Slider(
            minimum=eta_min,
            maximum=eta_max,
            step=eta_step,
            value=eta_default,
            label="Eta (haze prior)",
        )
    run_btn = gr.Button("Run", interactive=False)
    # process_image is a generator, so status messages stream into `status`.
    run_btn.click(
        process_image,
        inputs=[
            img1,
            diffusion_steps_slider,
            omega_slider,
            omega_vent_slider,
            omega_sept_slider,
            eta_slider,
        ],
        outputs=[status, img2],
        queue=True,
    )

    def load_model_event():
        """Initialise the device and model on page load and report the outcome.

        Returns:
            tuple: ``(status_update, run_button_update)`` for ``[status, run_btn]``.
            The Run button is made interactive only if loading succeeded.
        """
        global config, diffusion_model, model_loaded, DEVICE
        import html  # escape exception text embedded in the status HTML

        try:
            if DEVICE is None:
                DEVICE = init_device()
            config, diffusion_model = initialize_model()
        except Exception as e:
            return gr.update(
                value=f'<div style="background:#f8d7da;{STATUS_STYLE}color:#721c24;">❌ Error loading model: {html.escape(str(e))}</div>'
            ), gr.update(interactive=False)
        return gr.update(
            value=f'<div style="background:#d4edda;{STATUS_STYLE}color:#155724;">✅ Model loaded! You can now press Run.</div>'
        ), gr.update(interactive=True)

    demo.load(
        load_model_event,
        inputs=None,
        outputs=[status, run_btn],
    )

if __name__ == "__main__":
    demo.launch()