import gradio as gr
import subprocess  # noqa: F401 -- kept: unused here but may be relied on elsewhere
import os  # noqa: F401
import shutil  # noqa: F401
from pathlib import Path  # noqa: F401

import spaces

# import the updated recursive_multiscale_sr that expects a list of centers
from inference_coz_single import recursive_multiscale_sr
from PIL import Image, ImageDraw

# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------
INPUT_DIR = "samples"
OUTPUT_DIR = "inference_results/coz_vlmprompt"

# The preview canvas is a fixed 512x512 center crop; every normalized
# coordinate below is expressed relative to this square.
PREVIEW_SIZE = 512
# Default recursion depth of the SR pipeline (number of nested zoom steps).
DEFAULT_REC_NUM = 4


# ------------------------------------------------------------------
# HELPER: Resize & center-crop, preserving aspect ratio
# ------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """
    Resize the input PIL image so that its shorter side == `size`,
    then center-crop to exactly (size x size).
    """
    w, h = img.size
    scale = size / min(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))


# ------------------------------------------------------------------
# HELPER: Draw four true "nested" rectangles, matching the SR logic
# ------------------------------------------------------------------
def make_preview_with_boxes(
    image_path: str,
    scale_option: str,
    cx_norm: float,
    cy_norm: float,
) -> tuple[Image.Image, list[tuple[float, float]]]:
    """
    Draw the four nested crop rectangles the SR pipeline would use.

    Args:
        image_path: Path to the uploaded image on disk.
        scale_option: Upscale factor as a string ("1x", "2x", "4x").
        cx_norm: Normalized X of the requested zoom center (0..1).
        cy_norm: Normalized Y of the requested zoom center (0..1).

    Returns:
        - The 512x512 preview image with the boxes drawn on it.
        - A list of per-step (cx, cy) centers, each normalized to its
          PARENT box (this is the coordinate convention the recursive
          SR pipeline consumes).
    """
    try:
        orig = Image.open(image_path).convert("RGB")
    except Exception as e:
        # Can't open the file: show the error on a gray placeholder
        # instead of raising inside a UI callback.
        fallback = Image.new("RGB", (PREVIEW_SIZE, PREVIEW_SIZE), (200, 200, 200))
        ImageDraw.Draw(fallback).text((20, 20), f"Error:\n{e}", fill="red")
        return fallback, []

    base = resize_and_center_crop(orig, PREVIEW_SIZE)

    scale_int = int(scale_option.replace("x", ""))
    if scale_int <= 1:
        # "1x": no zoom; every "crop" is the full frame.
        sizes = [float(PREVIEW_SIZE)] * 4
    else:
        # Each recursion step shrinks the crop by the scale factor.
        sizes = [PREVIEW_SIZE / (scale_int ** (i + 1)) for i in range(4)]

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]
    width = 3

    abs_cx = cx_norm * PREVIEW_SIZE
    abs_cy = cy_norm * PREVIEW_SIZE

    prev_x0, prev_y0, prev_size = 0.0, 0.0, float(PREVIEW_SIZE)
    centers: list[tuple[float, float]] = []

    for i, crop_size in enumerate(sizes):
        x0 = abs_cx - (crop_size / 2.0)
        y0 = abs_cy - (crop_size / 2.0)

        # Clamp this box so it stays fully inside its parent box,
        # mirroring how the SR pipeline nests its crops.
        min_x0 = prev_x0
        max_x0 = prev_x0 + prev_size - crop_size
        min_y0 = prev_y0
        max_y0 = prev_y0 + prev_size - crop_size
        x0 = max(min_x0, min(x0, max_x0))
        y0 = max(min_y0, min(y0, max_y0))

        x1 = x0 + crop_size
        y1 = y0 + crop_size
        draw.rectangle(
            [(int(round(x0)), int(round(y0))), (int(round(x1)), int(round(y1)))],
            outline=colors[i % len(colors)],
            width=width,
        )

        # --- compute normalized center of this box (relative to parent) ---
        cx_box = ((x0 - prev_x0) + crop_size / 2.0) / float(prev_size)
        cy_box = ((y0 - prev_y0) + crop_size / 2.0) / float(prev_size)
        centers.append((cx_box, cy_box))

        prev_x0, prev_y0, prev_size = x0, y0, crop_size

    return base, centers


# ------------------------------------------------------------------
# HELPER FUNCTION FOR INFERENCE (build a list of identical centers)
# ------------------------------------------------------------------
@spaces.GPU()
def run_with_upload(
    uploaded_image_path: str,
    upscale_option: str,
    cx_norm: float,
    cy_norm: float,
):
    """
    Perform chain-of-zoom super-resolution on a given image, using
    recursive multi-scale upscaling centered on a specific point.

    Args:
        uploaded_image_path (str): Path to the input image file on disk.
        upscale_option (str): Desired upscale factor: "1x", "2x", or "4x"
            (enlargement per zoom step; "1x" means no upscaling).
        cx_norm (float): Normalized X-coordinate (0..1) of the zoom center.
        cy_norm (float): Normalized Y-coordinate (0..1) of the zoom center.

    Returns:
        list[PIL.Image.Image]: Progressively zoomed-in, super-resolved
        images, one per recursion step (typically 4), centered around
        the user-specified point.

    Note:
        The same center is repeated for each recursion level; this entry
        point is used by the cached `gr.Examples` rows.
    """
    if uploaded_image_path is None:
        return []

    upscale_value = int(upscale_option.replace("x", ""))
    rec_num = DEFAULT_REC_NUM  # match the SR pipeline's default recursion depth
    centers = [(cx_norm, cy_norm)] * rec_num

    # Call the modified SR function
    sr_list, _ = recursive_multiscale_sr(
        uploaded_image_path,
        upscale=upscale_value,
        rec_num=rec_num,
        centers=centers,
    )

    # Return the list of PIL images (Gradio Gallery expects a list)
    return sr_list


@spaces.GPU()
def magnify(
    uploaded_image_path: str,
    upscale_option: str,
    centres: list,
):
    """
    Perform chain-of-zoom super-resolution with an explicit per-step
    center list (as produced by the preview callback).

    Args:
        uploaded_image_path (str): Path to the input image file on disk.
        upscale_option (str): Desired upscale factor: "1x", "2x", or "4x".
        centres (list): One normalized (cx, cy) pair per recursion step,
            each relative to its parent crop.

    Returns:
        list[PIL.Image.Image]: Progressively zoomed-in, super-resolved
        images, one per recursion step.
    """
    if uploaded_image_path is None:
        return []

    # BUGFIX: `session_centres` is a gr.State() initialized to None, so
    # clicking "Run" before any preview callback has fired used to raise
    # `TypeError: object of type 'NoneType' has no len()`. Fall back to a
    # centered zoom at the default depth instead of crashing.
    if not centres:
        centres = [(0.5, 0.5)] * DEFAULT_REC_NUM

    upscale_value = int(upscale_option.replace("x", ""))
    rec_num = len(centres)

    # Call the modified SR function
    sr_list, _ = recursive_multiscale_sr(
        uploaded_image_path,
        upscale=upscale_value,
        rec_num=rec_num,
        centers=centres,
    )

    # Return the list of PIL images (Gradio Gallery expects a list)
    return sr_list


# ------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE (two sliders + correct preview)
# ------------------------------------------------------------------
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    # Per-session list of per-step zoom centers, filled by the preview
    # callback and consumed by `magnify`.
    session_centres = gr.State()

    with gr.Column(elem_id="col-container"):
        # NOTE(review): the original HTML markup was mangled in this copy of
        # the file; the visible text content is preserved below — restore the
        # original headings/anchors from version control if they differ.
        gr.HTML(
            """
            <h1>Chain-of-Zoom &ndash; Extreme Super-Resolution via Scale
            Autoregression and Preference Alignment</h1>
            <p>[Github] &mdash; HF Space by: GitHub Repo</p>
            """
        )

        with gr.Row():
            with gr.Column():
                # 1) Image upload component
                upload_image = gr.Image(
                    label="Input image",
                    type="filepath",
                )
                # 2) Radio for choosing 1x / 2x / 4x upscaling
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False,
                )
                # 3) Two sliders for normalized center (0..1)
                center_x = gr.Slider(
                    label="Center X (normalized)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5,
                )
                center_y = gr.Slider(
                    label="Center Y (normalized)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5,
                )
                # 4) Button to launch inference
                run_button = gr.Button("🔎 Chain-of-Zoom it", variant="primary")

                gr.Markdown(
                    "*Click anywhere on the preview image to select coordinates to zoom*"
                )

                # 5) Preview (512x512 + four truly nested boxes)
                preview_with_box = gr.Image(
                    label="Preview",
                    type="pil",
                    interactive=False,
                )

            with gr.Column():
                # 6) Gallery to display multiple output images
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2],
                    rows=[2],
                )

                examples = gr.Examples(
                    # List of example-rows. Each row is [input_image, scale, cx, cy]
                    examples=[
                        ["samples/0479.png", "4x", 0.5, 0.5],
                        ["samples/0064.png", "4x", 0.5, 0.5],
                        ["samples/0245.png", "4x", 0.5, 0.5],
                        ["samples/0393.png", "4x", 0.5, 0.5],
                    ],
                    inputs=[upload_image, upscale_radio, center_x, center_y],
                    outputs=[output_gallery],
                    fn=run_with_upload,
                    cache_examples=True,
                )

    # ------------------------------------------------------------------
    # CALLBACK #1: update the preview whenever inputs change
    # ------------------------------------------------------------------
    def update_preview(
        img_path: str,
        scale_opt: str,
        cx: float,
        cy: float,
    ) -> tuple[Image.Image | None, list[tuple[float, float]]]:
        """
        If no image is uploaded, clear the preview and state; otherwise
        draw the four nested boxes exactly as the SR pipeline would crop
        at each recursion, and return the matching per-step centers.
        """
        if img_path is None:
            return None, []
        return make_preview_with_boxes(img_path, scale_opt, cx, cy)

    def get_select_coords(input_img, evt: gr.SelectData):
        """Map a click on the preview to the normalized-center sliders."""
        print("coordinates selected")
        # evt.index is (x, y) in pixel coordinates of the displayed image.
        i = evt.index[1]
        j = evt.index[0]
        w, h = input_img.size
        return gr.update(value=j / w), gr.update(value=i / h)

    preview_with_box.select(
        get_select_coords,
        [preview_with_box],
        [center_x, center_y],
    )

    # Re-render the preview (and refresh the session centers) whenever
    # any of the four inputs changes.
    for _component in (upload_image, upscale_radio, center_x, center_y):
        _component.change(
            fn=update_preview,
            inputs=[upload_image, upscale_radio, center_x, center_y],
            outputs=[preview_with_box, session_centres],
            show_api=False,
        )

    # ------------------------------------------------------------------
    # CALLBACK #2: on button-click, run the SR pipeline
    # ------------------------------------------------------------------
    run_button.click(
        fn=magnify,
        inputs=[upload_image, upscale_radio, session_centres],
        outputs=[output_gallery],
    )

# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
demo.queue()
demo.launch(share=True, mcp_server=True)