import os
import glob
import traceback
from functools import lru_cache

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from torchvision.transforms.functional import normalize

from restormerRFR_arch import RestormerRFR
from dino_feature_extractor import DinoFeatureModule

WEIGHT_REPO_ID = "233zzl/RAM_plus_plus"
WEIGHT_FILENAME = "ram_plus/7task/RestormerRFR.pth"
MODEL_NAME = "RestormerRFR"


def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def warmup():
    # Pre-download the RestormerRFR checkpoint and the DINOv2 backbone so the
    # first GPU request does not pay the download cost.
    hf_hub_download(
        repo_id=WEIGHT_REPO_ID,
        filename=WEIGHT_FILENAME,
        repo_type="model",
        revision="main",
    )
    snapshot_download(
        repo_id="facebook/dinov2-giant",
        repo_type="model",
        revision="main",
    )


def build_model():
    return RestormerRFR(
        inp_channels=3,
        out_channels=3,
        dim=48,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        finetune_type=None,
        img_size=128,
    )


@lru_cache(maxsize=1)
def get_dino_extractor(device):
    # Build the DINOv2 feature extractor once per process and reuse it.
    return DinoFeatureModule().to(device).eval()


@lru_cache(maxsize=1)
def get_model_and_device():
    # Build the restoration model once, load the released checkpoint, and cache it.
    device = get_device()
    model = build_model()
    weight_path = hf_hub_download(
        repo_id=WEIGHT_REPO_ID,
        filename=WEIGHT_FILENAME,
    )
    ckpt = torch.load(weight_path, map_location="cpu")
    state_dict = ckpt["params"] if "params" in ckpt else ckpt
    model.load_state_dict(state_dict, strict=False)
    model.eval().to(device)
    return model, device


@spaces.GPU(duration=60)
def restore_image(pil_img: Image.Image) -> Image.Image:
    try:
        model, device = get_model_and_device()
        dino_extractor = get_dino_extractor(device)

        # PIL -> float32 RGB tensor in [0, 1], shape (1, 3, H, W).
        img_np = np.array(pil_img.convert("RGB")).astype(np.float32) / 255.0
        img = torch.from_numpy(np.transpose(img_np, (2, 0, 1))).float()  # (3, H, W), RGB
        img = img.unsqueeze(0).to(device)  # (1, 3, H, W)

        # ImageNet normalization expected by the DINOv2 feature extractor.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        normalize(img, mean, std, inplace=True)

        with torch.no_grad():
            dino_features = dino_extractor(img)
            output = model(img, dino_features)

        # Undo the normalization and convert back to an 8-bit RGB image.
        output = normalize(output, -1 * mean / std, 1 / std)
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()  # (3, H, W)
        output = (output * 255.0).round().astype(np.uint8)
        output = np.transpose(output, (1, 2, 0))  # (H, W, 3)
        return Image.fromarray(output, mode="RGB")
    except Exception as e:
        raise gr.Error(f"{e}\n{traceback.format_exc()}")


DESCRIPTION = """
# RAM++

Image restoration with RestormerRFR guided by DINOv2 features (ZeroGPU demo).
"""

with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Input image (JPEG/PNG)")
            btn = gr.Button("Run (ZeroGPU)")
        with gr.Column():
            out = gr.Image(type="pil", label="Restored image")

    # Collect example images shipped with the Space, if any.
    ex_files = []
    for ext in ("*.png", "*.jpg", "*.jpeg", "*.bmp"):
        ex_files.extend(glob.glob(os.path.join("examples", ext)))
    ex_files = sorted(ex_files)
    if ex_files:
        gr.Examples(examples=ex_files, inputs=inp, label="Examples")

    btn.click(restore_image, inputs=inp, outputs=out, api_name="run")

    gr.Markdown("""
**Tips**
- If the queue is long or you hit the quota, please try again later, or upgrade to Pro for a higher ZeroGPU quota and priority.
""")

    # Pre-fetch model weights as soon as the page is loaded.
    demo.load(fn=warmup, inputs=None, outputs=None)


if __name__ == "__main__":
    demo.launch()
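
# The button above registers the handler under api_name="run", so the demo can
# also be queried programmatically with gradio_client. A minimal client-side
# sketch (run from a separate script, not inside this Space) is shown below;
# the Space id "<user>/RAM_plus_plus_demo" is a placeholder assumption and
# should be replaced with the real Space id once deployed.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("<user>/RAM_plus_plus_demo")   # hypothetical Space id
#   restored_path = client.predict(
#       handle_file("examples/degraded.png"),      # any local JPEG/PNG
#       api_name="/run",                           # matches api_name="run" above
#   )
#   print(restored_path)                           # filepath of the restored image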