Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import io | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import spaces | |
| from PIL import Image | |
| from functools import lru_cache | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| from torchvision.transforms.functional import normalize | |
| import glob | |
| import traceback | |
| from restormerRFR_arch import RestormerRFR | |
| from dino_feature_extractor import DinoFeatureModule | |
# Hugging Face Hub model repository that hosts the RAM++ checkpoints.
WEIGHT_REPO_ID: str = "233zzl/RAM_plus_plus"
# Checkpoint file inside WEIGHT_REPO_ID that is downloaded and loaded into the model.
WEIGHT_FILENAME: str = "7task/RestormerRFR.pth"
# Human-readable name of the restoration architecture.
MODEL_NAME: str = "RestormerRFR"
def get_device():
    """Pick the torch device for inference: CUDA if torch reports it available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def warmup():
    """Fetch (or refresh the local Hub cache of) the model files ahead of inference.

    Downloads the RestormerRFR checkpoint and a full snapshot of the
    ``facebook/dinov2-giant`` model repository, both from the ``main`` revision.
    Nothing is returned; the point is to populate the Hugging Face cache so a
    later request does not pay the download cost.
    NOTE(review): presumably the DINOv2 snapshot is what DinoFeatureModule loads;
    confirm in dino_feature_extractor.
    """
    hub_kwargs = {"repo_type": "model", "revision": "main"}
    hf_hub_download(repo_id=WEIGHT_REPO_ID, filename=WEIGHT_FILENAME, **hub_kwargs)
    snapshot_download(repo_id="facebook/dinov2-giant", **hub_kwargs)
def build_model():
    """Construct a RestormerRFR network with the demo's fixed hyper-parameters.

    No weights are loaded here; the caller is responsible for loading a
    checkpoint whose layout matches these settings.
    """
    arch_kwargs = dict(
        inp_channels=3,
        out_channels=3,
        dim=48,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        finetune_type=None,
        img_size=128,
    )
    return RestormerRFR(**arch_kwargs)
def get_dino_extractor(device):
    """Build a DinoFeatureModule on ``device`` in eval mode.

    A brand-new module is constructed on every call; nothing is cached here.
    """
    extractor_module = DinoFeatureModule()
    extractor_module = extractor_module.to(device)
    return extractor_module.eval()
def get_model_and_device():
    """Build RestormerRFR, load the RAM++ checkpoint into it and place it on the inference device.

    The checkpoint is fetched with ``hf_hub_download`` (cached by the Hub client)
    and read onto the CPU first, then the model is moved to ``get_device()``.

    Returns:
        tuple: ``(model, device)`` where ``model`` is in eval mode on ``device``.

    Warns:
        UserWarning: when the checkpoint's keys do not match the model's parameters.
            Loading is non-strict, so mismatched layers keep their random initialization.
    """
    import warnings

    device = get_device()
    model = build_model()
    weight_path = hf_hub_download(
        repo_id=WEIGHT_REPO_ID,
        filename=WEIGHT_FILENAME,
    )
    ckpt = torch.load(weight_path, map_location="cpu")
    # Checkpoints may nest the weights under "params"; otherwise treat the file as a bare state dict.
    state_dict = ckpt["params"] if "params" in ckpt else ckpt
    # strict=False tolerates key mismatches, but silently ignoring them would run the demo
    # on partly untrained weights -- surface any mismatch instead of swallowing it.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"{MODEL_NAME} checkpoint '{WEIGHT_FILENAME}' does not match the model: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            stacklevel=2,
        )
    model.eval().to(device)
    return model, device
@spaces.GPU
def restore_image(pil_img: Image.Image) -> Image.Image:
    """Restore one image with RestormerRFR conditioned on DINO features.

    Intended to match the RAM++ RestormerRFR + DINO-feature inference pipeline.
    Decorated with ``spaces.GPU`` so that, on a ZeroGPU Space, a GPU is attached
    for the duration of the call.

    Args:
        pil_img: Image from the ``gr.Image(type="pil")`` input; ``None`` when the
            user clicks Run without uploading anything.

    Returns:
        The restored image as an 8-bit RGB ``PIL.Image``.

    Raises:
        gr.Error: when no image was supplied, or wrapping any exception raised
            while loading the models or running inference (the message carries
            the traceback).
    """
    if pil_img is None:
        raise gr.Error("Please upload an image first.")
    try:
        model, device = get_model_and_device()
        dino_extractor = get_dino_extractor(device)

        # convert("RGB") guarantees exactly 3 channels for RGBA, grayscale or palette uploads.
        rgb = np.asarray(pil_img.convert("RGB"), dtype=np.float32) / 255.0  # (H, W, 3), RGB, [0, 1]
        img = torch.from_numpy(np.ascontiguousarray(rgb.transpose(2, 0, 1)))  # (3, H, W), RGB
        img = img.unsqueeze(0).to(device)  # (1, 3, H, W)
        # NOTE(review): Restormer-style encoders usually need H and W divisible by 8;
        # confirm RestormerRFR copes with arbitrary upload sizes.

        # The network input is normalized with ImageNet mean/std; the output is
        # de-normalized with the same statistics below.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        normalize(img, mean, std, inplace=True)

        with torch.no_grad():
            dino_features = dino_extractor(img)
            output = model(img, dino_features)

        # Undo the normalization: (x - (-mean/std)) / (1/std) == x * std + mean.
        output = normalize(output, -1 * mean / std, 1 / std)
        output = output[0].detach().float().cpu().clamp_(0, 1).numpy()  # (3, H, W), RGB
        # The tensor is already RGB, so only move channels last. The previous
        # [2, 1, 0] flip handed BGR data to PIL and swapped red/blue in the result.
        output = np.transpose(output, (1, 2, 0))  # (H, W, 3), RGB
        output = (output * 255.0).round().astype(np.uint8)
        return Image.fromarray(output)
    except Exception as e:
        raise gr.Error(f"{e}\n{traceback.format_exc()}") from e
DESCRIPTION = """
# RAM++: Robust Representation Learning via Adaptive Mask for All-in-One Image Restoration
"""

# Gradio UI: one input image + Run button on the left, the restored image on the right.
with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="load picture(JPEG/PNG)")
            btn = gr.Button("Run (ZeroGPU)")
        with gr.Column():
            out = gr.Image(type="pil", label="output")

    # Offer every png/jpg/jpeg/bmp directly under ./examples as a clickable example
    # (non-recursive; glob matching is case-sensitive on POSIX filesystems).
    ex_files = []
    for ext in ("*.png", "*.jpg", "*.jpeg", "*.bmp"):
        ex_files.extend(glob.glob(os.path.join("examples", ext)))
    ex_files = sorted(ex_files)
    if ex_files:
        gr.Examples(examples=ex_files, inputs=inp, label="Examples")

    btn.click(restore_image, inputs=inp, outputs=out, api_name="run")

    gr.Markdown("""
**Tips**
- If the queue is long or you hit the quota, please try again later, or upgrade to Pro for a higher ZeroGPU quota and priority.
""")

    # Pre-fetch the model files whenever a browser session loads the page.
    demo.load(fn=warmup, inputs=None, outputs=None)

if __name__ == "__main__":
    demo.launch()