Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import subprocess | |
| import os | |
| import shutil | |
| from pathlib import Path | |
| import spaces | |
| # import the updated recursive_multiscale_sr that expects a list of centers | |
| from inference_coz_single import recursive_multiscale_sr | |
| from PIL import Image, ImageDraw | |
# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------
# Folder with sample input images (the UI examples use "samples/..." paths).
INPUT_DIR: str = "samples"
# Folder for inference results.
# NOTE(review): neither path constant is read by the code in this file — confirm
# whether another module relies on them before changing or removing them.
OUTPUT_DIR: str = "inference_results/coz_vlmprompt"
# ------------------------------------------------------------------
# HELPER: Resize (preserving aspect ratio) & center-crop to a square
# ------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """
    Resize ``img`` so its shorter side is exactly ``size`` pixels (keeping the
    aspect ratio), then center-crop it to ``size x size``.

    Args:
        img: Source PIL image.
        size: Edge length, in pixels, of the square result.

    Returns:
        A new ``size x size`` PIL image.
    """
    w, h = img.size
    # Exact integer arithmetic. The float form ``int(w * (size / w))`` can come
    # out one pixel short (e.g. a 392- or 784-pixel shorter side with size=512
    # yields 511), which made the crop below run past the image edge and pad
    # the result with a black row/column. The longer side is still floored.
    if w <= h:
        new_w, new_h = size, h * size // w
    else:
        new_w, new_h = w * size // h, size
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
# ------------------------------------------------------------------
# HELPER: Draw "nested" rectangles, matching the SR logic
# ------------------------------------------------------------------
def _nested_boxes(
    abs_cx: float,
    abs_cy: float,
    sizes: list[float],
    canvas: float,
) -> tuple[list[tuple[float, float, float]], list[tuple[float, float]]]:
    """Place square boxes of the given sizes around (abs_cx, abs_cy), each clamped inside its parent.

    Returns ``(boxes, centers)``: ``boxes`` holds ``(x0, y0, edge)`` in canvas
    pixels; ``centers`` holds each box's centre normalized to its parent box
    (the whole ``canvas`` for the first box).
    """
    prev_x0, prev_y0, prev_size = 0.0, 0.0, canvas
    boxes: list[tuple[float, float, float]] = []
    centers: list[tuple[float, float]] = []
    for crop_size in sizes:
        # Centre the box on the requested point, then clamp it so it never
        # leaves the previous (parent) box.
        x0 = max(prev_x0, min(abs_cx - crop_size / 2.0, prev_x0 + prev_size - crop_size))
        y0 = max(prev_y0, min(abs_cy - crop_size / 2.0, prev_y0 + prev_size - crop_size))
        boxes.append((x0, y0, crop_size))
        centers.append((
            ((x0 - prev_x0) + crop_size / 2.0) / prev_size,
            ((y0 - prev_y0) + crop_size / 2.0) / prev_size,
        ))
        prev_x0, prev_y0, prev_size = x0, y0, crop_size
    return boxes, centers


def make_preview_with_boxes(
    image_path: str,
    scale_option: str,
    cx_norm: float,
    cy_norm: float,
    num_boxes: int = 4,
) -> tuple[Image.Image, list[tuple[float, float]]]:
    """
    Build the 512x512 zoom preview and the per-level zoom centres.

    The image is resized and center-cropped to 512x512. Then ``num_boxes``
    nested squares are drawn around (cx_norm, cy_norm):
    - Box ``i`` has edge ``512 / scale**(i + 1)``, or 512 for every box when
      the scale is 1x or less.
    - Each box is clamped to lie inside the previous box; the first box is
      clamped to the full preview.

    Args:
        image_path: Path of the image to preview.
        scale_option: Upscale factor string such as "2x". The "x" is stripped
            and the remainder is parsed with ``int`` (ValueError if malformed).
        cx_norm: Requested zoom X, normalized to the 512x512 preview.
        cy_norm: Requested zoom Y, normalized to the 512x512 preview.
        num_boxes: Number of nested boxes / centres to produce (default 4).

    Returns:
        - The preview image with the boxes drawn. If ``image_path`` cannot be
          opened, this is a grey image showing the error instead.
        - One (cx, cy) per box: the box centre normalized to its *parent* box
          (the full preview for the first box). Empty on the error path.
    """
    preview_size = 512
    try:
        with Image.open(image_path) as src:
            orig = src.convert("RGB")
    except Exception as e:
        # Best-effort preview: show the error on a grey canvas rather than raise.
        fallback = Image.new("RGB", (preview_size, preview_size), (200, 200, 200))
        ImageDraw.Draw(fallback).text((20, 20), f"Error:\n{e}", fill="red")
        return fallback, []

    base = resize_and_center_crop(orig, preview_size)

    scale_int = int(scale_option.replace("x", ""))
    if scale_int <= 1:
        sizes = [float(preview_size)] * num_boxes
    else:
        sizes = [preview_size / (scale_int ** (i + 1)) for i in range(num_boxes)]

    boxes, centers = _nested_boxes(
        cx_norm * preview_size, cy_norm * preview_size, sizes, float(preview_size)
    )

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]
    for i, (x0, y0, crop_size) in enumerate(boxes):
        # ImageDraw.rectangle includes both corner pixels, so a box that is
        # `crop_size` pixels wide ends at x0 + crop_size - 1 (the old code drew
        # each box 1px too large and pushed the full-size box off the canvas).
        # max() keeps x1 >= x0 for very small boxes after rounding.
        left, top = int(round(x0)), int(round(y0))
        right = max(left, int(round(x0 + crop_size)) - 1)
        bottom = max(top, int(round(y0 + crop_size)) - 1)
        draw.rectangle(
            [(left, top), (right, bottom)],
            outline=colors[i % len(colors)],
            width=3,
        )
    return base, centers
# ------------------------------------------------------------------
# HELPER FUNCTION FOR INFERENCE (build a list of identical centers)
# ------------------------------------------------------------------
def run_with_upload(
    uploaded_image_path: str,
    upscale_option: str,
    cx_norm: float,
    cy_norm: float,
):
    """
    Run chain-of-zoom super-resolution, zooming repeatedly on a single point.

    The same normalized point is used at each of the 4 recursion levels. The
    image path, upscale factor, depth and centre list are handed to the
    modified `recursive_multiscale_sr` pipeline.

    Args:
        uploaded_image_path (str): Path to the input image file on disk.
        upscale_option (str): Upscale factor per zoom step as a string. Valid options are "1x", "2x", and "4x".
            - "1x" means no upscaling.
            - "2x" means 2× enlargement per zoom step.
            - "4x" means 4× enlargement per zoom step.
        cx_norm (float): Normalized X-coordinate (0 to 1) of the zoom center.
        cy_norm (float): Normalized Y-coordinate (0 to 1) of the zoom center.

    Returns:
        list[PIL.Image.Image]: The super-resolved images returned by the pipeline, typically one per
            recursion step; an empty list when no image was uploaded.
    """
    # Nothing uploaded -> empty gallery (checked before parsing anything else).
    if uploaded_image_path is None:
        return []

    factor = int(upscale_option.replace("x", ""))
    depth = 4  # recursion depth handed to the SR pipeline
    # Identical centre at every level.
    zoom_points = [(cx_norm, cy_norm) for _ in range(depth)]

    results, _unused = recursive_multiscale_sr(
        uploaded_image_path,
        upscale=factor,
        rec_num=depth,
        centers=zoom_points,
    )
    # The Gradio Gallery expects a list of images.
    return results
def magnify(
    uploaded_image_path: str,
    upscale_option: str,
    centres: list
):
    """
    Perform chain-of-zoom super-resolution on a given image, zooming through a sequence of per-level centre points.
    Each recursion level of the `recursive_multiscale_sr` pipeline zooms on its own centre; the number of centres sets the recursion depth.
    Args:
        uploaded_image_path (str): Path to the input image file on disk.
        upscale_option (str): The desired upscale factor as a string. Valid options are "1x", "2x", and "4x".
            - "1x" means no upscaling.
            - "2x" means 2× enlargement per zoom step.
            - "4x" means 4× enlargement per zoom step.
        centres (list): One (x, y) pair per recursion level, each coordinate normalized to 0..1. The list length is used as the recursion depth.
    Returns:
        list[PIL.Image.Image]: The super-resolved images returned by the pipeline, typically one per recursion step.
            Empty when no image or no centres are provided.
    """
    # Nothing to run without an image or centres. `centres` is None while the
    # session gr.State has not been filled yet (or when an API caller omits it),
    # and [] when the preview could not open the image. Previously None crashed
    # on len() and [] launched the pipeline with rec_num=0.
    if uploaded_image_path is None or not centres:
        return []

    upscale_value = int(upscale_option.replace("x", ""))
    # Normalize to float tuples: API/MCP callers send JSON lists, not tuples.
    centers = [(float(cx), float(cy)) for cx, cy in centres]

    sr_list, _ = recursive_multiscale_sr(
        uploaded_image_path,
        upscale=upscale_value,
        rec_num=len(centers),
        centers=centers,
    )
    # Return the list of PIL images (Gradio Gallery expects a list)
    return sr_list
# ------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE (two sliders + correct preview)
# ------------------------------------------------------------------
# Stylesheet passed to gr.Blocks: horizontally centres the element with
# elem_id="col-container" and caps its width at 1024px.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""
with gr.Blocks(css=css) as demo:
    # Per-session zoom centres (one (x, y) pair per recursion level) produced by
    # update_preview and consumed by the "Chain-of-Zoom it" button. It has no
    # initial value, so it is None until the first preview is drawn.
    session_centres = gr.State()
    with gr.Column(elem_id="col-container"):
        gr.HTML(
            """
<div style="text-align: left;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>Chain-of-Zoom</strong> – Extreme Super-Resolution via Scale Autoregression and Preference Alignment
</p>
<a href="https://github.com/bryanswkim/Chain-of-Zoom" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
[Github]
</a>
</div>
<div style="text-align: left;">
<strong>HF Space by:</strong>
<a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
</a>
</div>
"""
        )
        with gr.Row():
            with gr.Column():
                # 1) Image upload component (passes a file path to callbacks)
                upload_image = gr.Image(
                    label="Input image",
                    type="filepath"
                )
                # 2) Radio for choosing 1× / 2× / 4× upscaling
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False
                )
                # 3) Two sliders for normalized center (0..1)
                center_x = gr.Slider(
                    label="Center X (normalized)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5
                )
                center_y = gr.Slider(
                    label="Center Y (normalized)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5
                )
                # 4) Button to launch inference
                run_button = gr.Button("🔎 Chain-of-Zoom it", variant="primary")
                gr.Markdown("*Click anywhere on the preview image to select coordinates to zoom*")
                # 5) Preview (512×512 + nested boxes), delivered as a PIL image
                preview_with_box = gr.Image(
                    label="Preview",
                    type="pil",
                    interactive=False
                )
            with gr.Column():
                # 6) Gallery to display multiple output images
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2], rows=[2]
                )
                # NOTE(review): cached examples go through run_with_upload, which
                # repeats the raw (cx, cy) at every level, while the button uses
                # magnify with the preview's parent-relative nested centres. The
                # two agree at (0.5, 0.5), which is all the examples use.
                examples = gr.Examples(
                    # List of example-rows. Each row is [input_image, scale, cx, cy]
                    examples=[["samples/0479.png", "4x", 0.5, 0.5], ["samples/0064.png", "4x", 0.5, 0.5], ["samples/0245.png", "4x", 0.5, 0.5], ["samples/0393.png", "4x", 0.5, 0.5]],
                    inputs=[upload_image, upscale_radio, center_x, center_y],
                    outputs=[output_gallery],
                    fn=run_with_upload,
                    cache_examples=True
                )

    # ------------------------------------------------------------------
    # CALLBACK #1: update the preview whenever inputs change
    # ------------------------------------------------------------------
    def update_preview(
        img_path: str,
        scale_opt: str,
        cx: float,
        cy: float
    ) -> tuple[Image.Image | None, list[tuple[float, float]]]:
        """
        Return ``(preview_image, centres)`` for the preview and session state.

        With no image, returns ``(None, [])`` to clear the preview and centres.
        Otherwise delegates to make_preview_with_boxes, which draws the nested
        boxes and computes one centre per recursion level.
        """
        if img_path is None:
            return None, []
        return make_preview_with_boxes(img_path, scale_opt, cx, cy)

    def get_select_coords(input_img, evt: gr.SelectData):
        """Convert a click on the preview into normalized (x, y) slider values."""
        # Nothing to map if the preview is empty; leave the sliders unchanged.
        if input_img is None:
            return gr.update(), gr.update()
        # SelectData.index for an image is [x, y] in pixels.
        x, y = evt.index[0], evt.index[1]
        w, h = input_img.size
        return gr.update(value=x / w), gr.update(value=y / h)

    preview_with_box.select(get_select_coords, [preview_with_box], [center_x, center_y])

    # Any change to the image, scale or centre redraws the preview and refreshes
    # the per-session centres (registered in the same order as before).
    preview_inputs = [upload_image, upscale_radio, center_x, center_y]
    for component in preview_inputs:
        component.change(
            fn=update_preview,
            inputs=preview_inputs,
            outputs=[preview_with_box, session_centres],
            show_api=False
        )

    # ------------------------------------------------------------------
    # CALLBACK #2: on button-click, run the SR pipeline
    # ------------------------------------------------------------------
    run_button.click(
        fn=magnify,
        inputs=[upload_image, upscale_radio, session_centres],
        outputs=[output_gallery]
    )
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
# Turn on request queuing, then serve the UI (mcp_server=True also exposes the
# API endpoints as MCP tools). Blocks.queue() returns the Blocks, so it chains.
demo.queue().launch(share=True, mcp_server=True)