Spaces:
Sleeping
Sleeping
File size: 4,284 Bytes
3bfd811 7318bea 3bfd811 7318bea 3bfd811 7318bea 3bfd811 7318bea 3bfd811 7318bea 3bfd811 7318bea 3bfd811 7318bea 3bfd811 7318bea 3bfd811 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os
import io
import cv2
import gradio as gr
import numpy as np
import torch
import spaces
from PIL import Image
from functools import lru_cache
from huggingface_hub import hf_hub_download, snapshot_download
from torchvision.transforms.functional import normalize
import glob
import traceback
from restormerRFR_arch import RestormerRFR
from dino_feature_extractor import DinoFeatureModule
# Architecture name; not referenced elsewhere in the visible part of this file.
MODEL_NAME = "RestormerRFR"

# Hugging Face Hub model repository and the in-repo path of the RAM++
# RestormerRFR checkpoint that this demo downloads and loads.
WEIGHT_REPO_ID = "233zzl/RAM_plus_plus"
WEIGHT_FILENAME = "7task/RestormerRFR.pth"
def get_device():
    """Return the CUDA device when torch can see one, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def warmup():
    """Fetch the RAM++ checkpoint and the DINOv2-giant snapshot from the Hub.

    Both downloads target the ``main`` revision of a model repository. The
    returned local paths are discarded; the point is only to have the files
    present before the first restoration request asks for them.
    """
    hub_options = {"repo_type": "model", "revision": "main"}
    hf_hub_download(repo_id=WEIGHT_REPO_ID, filename=WEIGHT_FILENAME, **hub_options)
    # NOTE(review): presumably the backbone that DinoFeatureModule loads — confirm
    # against dino_feature_extractor.
    snapshot_download(repo_id="facebook/dinov2-giant", **hub_options)
def build_model():
    """Instantiate RestormerRFR with the fixed hyper-parameters used by this demo.

    No weights are loaded here; get_model_and_device loads the checkpoint.
    """
    arch_kwargs = dict(
        inp_channels=3,
        out_channels=3,
        dim=48,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        finetune_type=None,
        img_size=128,
    )
    return RestormerRFR(**arch_kwargs)
@lru_cache(maxsize=1)
def get_dino_extractor(device):
    """Return a DinoFeatureModule moved to *device* and switched to eval mode.

    Memoized with maxsize=1, so only the module for the most recently
    requested device is kept.
    """
    module = DinoFeatureModule()
    module = module.to(device)
    module.eval()
    return module
@lru_cache(maxsize=1)
def get_model_and_device():
    """Build RestormerRFR, load the RAM++ checkpoint into it, and memoize the result.

    Returns:
        A ``(model, device)`` tuple: the model in eval mode on ``get_device()``,
        and that device.

    Weights are loaded with ``strict=False``. Any missing or unexpected
    parameter names are reported with ``warnings.warn``, so a checkpoint layout
    mismatch does not silently leave the model at random initialization.
    """
    import warnings

    device = get_device()
    model = build_model()
    weight_path = hf_hub_download(
        repo_id=WEIGHT_REPO_ID,
        filename=WEIGHT_FILENAME,
    )
    ckpt = torch.load(weight_path, map_location="cpu")

    # BasicSR-style checkpoints nest the weights under "params" (preferred, as
    # before) or "params_ema"; otherwise treat the file as a bare state dict.
    state_dict = ckpt
    for key in ("params", "params_ema"):
        if key in ckpt:
            state_dict = ckpt[key]
            break

    # Weights saved from a DataParallel/DDP wrapper carry a "module." prefix on
    # every key; strip it only when it is uniform so real names are untouched.
    prefix = "module."
    if state_dict and all(k.startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"RestormerRFR checkpoint mismatch: {len(result.missing_keys)} missing "
            f"and {len(result.unexpected_keys)} unexpected keys "
            f"(e.g. missing={list(result.missing_keys)[:5]}, "
            f"unexpected={list(result.unexpected_keys)[:5]})"
        )

    model.eval().to(device)
    return model, device
# ImageNet channel statistics (RGB order) used to normalize the model input
# and to undo that normalization on the model output.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def _pil_to_tensor(pil_img: Image.Image, device: torch.device) -> torch.Tensor:
    """Turn a PIL image into a normalized (1, 3, H, W) float RGB tensor on *device*."""
    # Force 3-channel RGB. RGBA/LA PNGs, grayscale ("L") and palette ("P")
    # images would otherwise have the wrong channel count.
    rgb = np.asarray(pil_img.convert("RGB"), dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    img = torch.from_numpy(np.ascontiguousarray(rgb.transpose(2, 0, 1)))  # (3, H, W)
    img = img.unsqueeze(0).to(device)  # (1, 3, H, W)
    normalize(img, _IMAGENET_MEAN, _IMAGENET_STD, inplace=True)
    return img


def _tensor_to_pil(output: torch.Tensor) -> Image.Image:
    """Undo the ImageNet normalization on a (1, 3, H, W) RGB output and return an RGB PIL image."""
    # normalize(x, -mean/std, 1/std) computes (x + mean/std) * std == x * std + mean.
    output = normalize(output, -1 * _IMAGENET_MEAN / _IMAGENET_STD, 1 / _IMAGENET_STD)
    # Take the batch item explicitly: squeeze() would also drop H or W when either is 1.
    arr = output[0].detach().float().cpu().clamp_(0, 1).numpy()  # (3, H, W), RGB
    # The model was fed RGB, so its output is RGB. Only move channels last; a
    # [2, 1, 0] reorder here would hand BGR data to PIL and swap red/blue.
    arr = np.transpose(arr, (1, 2, 0))  # (H, W, 3), RGB
    arr = (arr * 255.0).round().astype(np.uint8)
    # PIL infers "RGB" from a uint8 (H, W, 3) array.
    return Image.fromarray(arr)


@spaces.GPU(duration=240)
def restore_image(pil_img: Image.Image) -> Image.Image:
    """Restore one image with RAM++ (RestormerRFR conditioned on DINO features).

    Args:
        pil_img: The uploaded image, in any PIL mode (converted to RGB), or None
            when the user pressed Run without uploading.

    Returns:
        The restored image as an RGB PIL image.

    Raises:
        gr.Error: If no image was supplied, or wrapping any failure during
            model loading or inference (message includes the traceback).
    """
    if pil_img is None:
        raise gr.Error("Please upload an image first.")
    try:
        model, device = get_model_and_device()
        dino_extractor = get_dino_extractor(device)
        # NOTE(review): Restormer-style encoders usually need H and W divisible
        # by 8, and the DINO extractor may impose its own patch multiple; odd
        # sizes may fail here. Confirm, and pad/crop if needed.
        img = _pil_to_tensor(pil_img, device)
        with torch.no_grad():
            dino_features = dino_extractor(img)
            output = model(img, dino_features)
        return _tensor_to_pil(output)
    except Exception as e:
        # Show the full traceback in the UI; this is a debugging-friendly demo.
        raise gr.Error(f"{e}\n{traceback.format_exc()}") from e
DESCRIPTION = """
# RAM++: Robust Representation Learning via Adaptive Mask for All-in-One Image Restoration
"""

# Image extensions offered as clickable examples. They are compared
# case-insensitively so files such as "photo.JPG" are listed too.
_EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")


def _list_example_files(folder="examples"):
    """Return sorted paths of the image files directly inside *folder*.

    Returns an empty list when *folder* is missing or has no matching files.
    """
    return sorted(
        path
        for path in glob.glob(os.path.join(folder, "*"))
        if os.path.isfile(path) and path.lower().endswith(_EXAMPLE_EXTENSIONS)
    )


with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Upload an image (JPEG/PNG)")
            btn = gr.Button("Run (ZeroGPU)")
        with gr.Column():
            out = gr.Image(type="pil", label="output")

    ex_files = _list_example_files()
    if ex_files:
        gr.Examples(examples=ex_files, inputs=inp, label="Examples")

    btn.click(restore_image, inputs=inp, outputs=out, api_name="run")

    gr.Markdown("""
**Tips**
- If the queue is long or you hit the quota, please try again later, or upgrade to Pro for a higher ZeroGPU quota and priority.
""")

    # demo.load fires on every page visit. Memoizing the warm-up means the Hub
    # is queried only until the first successful run; a failed run is not
    # cached, so the next visit retries it.
    demo.load(fn=lru_cache(maxsize=1)(warmup), inputs=None, outputs=None)

if __name__ == "__main__":
    demo.launch()
|