Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import subprocess | |
| import os | |
| import shutil | |
| from pathlib import Path | |
| from PIL import Image, ImageDraw | |
| import spaces | |
# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------
# Directory handed to the inference script as its input (`-i`).
INPUT_DIR = "samples"
# Directory handed to the inference script as its output (`-o`).
OUTPUT_DIR = "inference_results/coz_vlmprompt"
# ------------------------------------------------------------------
# HELPER: Resize & center-crop to a square, preserving aspect ratio
# ------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Resize so the shorter side equals `size`, then center-crop to a square.

    Args:
        img: Source PIL image.
        size: Edge length, in pixels, of the square result.

    Returns:
        A new `size` x `size` image taken from the center of the resized input.

    Raises:
        ValueError: If `size` is not positive or the image has an empty side.
    """
    w, h = img.size
    short = min(w, h)
    if size <= 0 or short <= 0:
        raise ValueError(
            f"cannot crop a {w}x{h} image to a {size}x{size} square"
        )

    # Multiply before dividing and round the result: truncating
    # `w * (size / short)` can land on size - 1 through float rounding
    # (e.g. a 49-pixel side with size 512), which would let the crop box
    # overhang the image and pad the result with a black edge.
    new_w = max(size, round(w * size / short))
    new_h = max(size, round(h * size / short))
    resized = img.resize((new_w, new_h), Image.LANCZOS)

    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return resized.crop((left, top, left + size, top + size))
# ------------------------------------------------------------------
# HELPER: Draw four concentric, centered rectangles on a 512×512 image
# ------------------------------------------------------------------
def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
    """Build the 512×512 preview with four centered boxes drawn on it.

    1) Open the uploaded image from disk.
    2) Resize & center-crop it to exactly 512×512.
    3) Depending on scale_option ("1x", "2x", "4x"), compute four box sizes:
         - "1x": [512, 512, 512, 512]
         - "2x": [256, 128, 64, 32]
         - "4x": [128, 64, 32, 16]
    4) Draw each of those rectangles (outline only), all centered.
    5) Return the modified PIL image.

    Args:
        image_path: Path of the image file to preview.
        scale_option: Scale label such as "2x"; the "x" is stripped and the
            remainder is parsed as an integer.

    Returns:
        The annotated 512×512 RGB image, or a gray 512×512 placeholder with
        the error text drawn on it when the file cannot be opened.
    """
    canvas = 512
    try:
        # The context manager closes the underlying file; convert() returns
        # an independent, fully loaded copy that outlives it.
        with Image.open(image_path) as src:
            orig = src.convert("RGB")
    except Exception as e:
        # UI boundary: show the failure in the preview instead of raising.
        fallback = Image.new("RGB", (canvas, canvas), (200, 200, 200))
        ImageDraw.Draw(fallback).text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    # 1. Resize & center-crop to 512×512
    base = resize_and_center_crop(orig, canvas)

    # 2. Determine the four box sizes
    scale_int = int(scale_option.replace("x", ""))  # e.g. "2x" -> 2
    if scale_int == 1:
        sizes = [canvas] * 4
    else:
        # scale=2 -> [256, 128, 64, 32]; scale=4 -> [128, 64, 32, 16]
        sizes = [canvas // (scale_int * (2 ** i)) for i in range(4)]

    # 3. Draw the outlines, cycling through the colors
    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]
    width = 3  # thickness of each rectangle's outline
    for idx, s in enumerate(sizes):
        if s < 1:
            # A box that rounds down to nothing cannot be drawn.
            continue
        x0 = y0 = (canvas - s) // 2
        # Pillow's rectangle includes both corners, so the far corner is
        # s - 1 pixels away; using s would push the "1x" box off the canvas.
        x1 = y1 = x0 + s - 1
        draw.rectangle(
            [(x0, y0), (x1, y1)],
            outline=colors[idx % len(colors)],
            width=width,
        )
    return base
# ------------------------------------------------------------------
# HELPER FUNCTIONS FOR INFERENCE & CAPTION
# ------------------------------------------------------------------
def run_with_upload(uploaded_image_path, upscale_option):
    """Run inference_coz.py on one uploaded image and return its outputs.

    1) Clear INPUT_DIR (creating it if needed).
    2) Save the uploaded file as input.png in INPUT_DIR.
    3) Read `upscale_option` (e.g. "1x", "2x", "4x") and strip the "x".
    4) Call inference_coz.py with `--upscale <that_value>`.
    5) Return the four output PNG paths so Gradio's Gallery can show them.

    Args:
        uploaded_image_path: Path of the uploaded image, or None when nothing
            has been uploaded.
        upscale_option: Scale label such as "2x".

    Returns:
        The paths 1.png .. 4.png under OUTPUT_DIR/per-sample/input, or an
        empty list when there is no upload, the image cannot be read or
        saved, the subprocess fails, or any expected output is missing.
    """
    import sys  # local import: only needed to locate the running interpreter

    os.makedirs(INPUT_DIR, exist_ok=True)
    for fn in os.listdir(INPUT_DIR):
        full_path = os.path.join(INPUT_DIR, fn)
        try:
            if os.path.isfile(full_path) or os.path.islink(full_path):
                os.remove(full_path)
            elif os.path.isdir(full_path):
                shutil.rmtree(full_path)
        except OSError as e:
            # Best effort: a leftover file should not block the new run.
            print(f"Warning: could not delete {full_path}: {e}")

    if uploaded_image_path is None:
        return []

    try:
        # Close the source file as soon as the RGB copy has been made.
        with Image.open(uploaded_image_path) as src:
            pil_img = src.convert("RGB")
    except Exception as e:
        print(f"Error: could not open uploaded image: {e}")
        return []

    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save as PNG: {e}")
        return []

    upscale_value = upscale_option.replace("x", "")  # e.g. "2x" -> "2"
    cmd = [
        # Use the interpreter running this app so the child process sees the
        # same installed packages; fall back to PATH lookup if it is unknown.
        sys.executable or "python", "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # OSError covers a launch failure (e.g. interpreter not found),
        # which would otherwise escape and crash the callback.
        print("Inference failed:", err)
        return []

    per_sample_dir = os.path.join(OUTPUT_DIR, "per-sample", "input")
    expected_files = [
        os.path.join(per_sample_dir, f"{i}.png") for i in range(1, 5)
    ]
    for fp in expected_files:
        if not os.path.isfile(fp):
            print(f"Warning: expected file not found: {fp}")
            return []
    return expected_files
def get_caption(src_gallery, evt: gr.SelectData):
    """Return the caption text for the gallery image that was clicked.

    The image "<k>.png" is matched with "<k-1>.txt" inside
    OUTPUT_DIR/per-sample/input/txt.

    Args:
        src_gallery: Current gallery value; each item is indexed as
            `item[0]` to obtain the image path.
        evt: Gradio selection event; `evt.index` picks the clicked item.

    Returns:
        The stripped caption text, or a short message when no caption can be
        located or read.
    """
    no_caption = "No caption available."
    if not src_gallery:
        return no_caption

    try:
        selected_image_path = src_gallery[evt.index][0]
        if not os.path.isfile(selected_image_path):
            return no_caption
    except (IndexError, TypeError):
        # Selection that does not map to a usable path in the gallery value.
        return no_caption

    base = os.path.basename(selected_image_path)  # e.g. "2.png"
    stem = os.path.splitext(base)[0]  # e.g. "2"
    try:
        caption_index = int(stem) - 1  # "2.png" pairs with "1.txt"
    except ValueError:
        # File name is not one of the numbered outputs.
        return no_caption

    txt_name = f"{caption_index}.txt"
    txt_path = os.path.join(OUTPUT_DIR, "per-sample", "input", "txt", txt_name)
    if not os.path.isfile(txt_path):
        return f"Caption file not found: {txt_name}"

    try:
        with open(txt_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
    except (OSError, UnicodeError) as e:
        return f"Error reading caption: {e}"
    return caption if caption else "(Caption file is empty.)"
# ------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE (with updated callbacks)
# ------------------------------------------------------------------
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Chain-of-Zoom</h1>
            <p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://github.com/bryanswkim/Chain-of-Zoom">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
        </div>
        """
    )

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                # 1) Image upload component
                upload_image = gr.Image(
                    label="Upload your input image",
                    type="filepath",
                )
                # 2) Radio for choosing 1x / 2x / 4x upscaling
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False,
                )
                # 3) Button to launch inference
                run_button = gr.Button("Chain-of-Zoom it")
                # 4) Show the 512×512 preview with four centered rectangles
                preview_with_box = gr.Image(
                    label="Preview (512×512 with centered boxes)",
                    type="pil",  # the callback returns a PIL.Image
                    interactive=False,
                )
            with gr.Column():
                # 5) Gallery to display multiple output images
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2],
                    rows=[2],
                )
                # 6) Textbox under the gallery for showing captions
                caption_text = gr.Textbox(
                    label="Caption",
                    lines=4,
                    placeholder="Click on any image above to see its caption here.",
                )

    # ------------------------------------------------------------------
    # CALLBACK #1: Whenever the user uploads or changes the radio, update preview
    # ------------------------------------------------------------------
    def update_preview(img_path, scale_opt):
        """Return the boxed 512×512 preview, or None when nothing is uploaded.

        Returning None leaves the preview component blank.
        """
        if img_path is None:
            return None
        return make_preview_with_boxes(img_path, scale_opt)

    # Redraw the preview when a new file is uploaded and when the user
    # switches 1x/2x/4x after uploading; both events share the same wiring.
    for trigger in (upload_image, upscale_radio):
        trigger.change(
            fn=update_preview,
            inputs=[upload_image, upscale_radio],
            outputs=[preview_with_box],
        )

    # ------------------------------------------------------------------
    # CALLBACK #2: When "Chain-of-Zoom it" is clicked, run inference
    # ------------------------------------------------------------------
    run_button.click(
        fn=run_with_upload,
        inputs=[upload_image, upscale_radio],
        outputs=[output_gallery],
    )

    # ------------------------------------------------------------------
    # CALLBACK #3: When an image in the gallery is clicked, show its caption
    # ------------------------------------------------------------------
    output_gallery.select(
        fn=get_caption,
        inputs=[output_gallery],
        outputs=[caption_text],
    )
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
# 1) Turn the global queue on: at most one running worker per event,
#    and up to 20 jobs allowed to wait.
demo.queue(default_concurrency_limit=1, max_size=20)
# 2) Launch as usual.
demo.launch(share=True)