File size: 4,284 Bytes
3bfd811
 
 
 
 
 
 
 
 
 
 
 
7318bea
3bfd811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7318bea
3bfd811
 
 
 
7318bea
 
 
3bfd811
 
7318bea
 
 
3bfd811
 
7318bea
 
 
3bfd811
7318bea
 
 
3bfd811
 
7318bea
 
 
 
 
 
 
 
3bfd811
 
7318bea
3bfd811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import io
import cv2
import gradio as gr
import numpy as np
import torch
import spaces
from PIL import Image
from functools import lru_cache
from huggingface_hub import hf_hub_download, snapshot_download
from torchvision.transforms.functional import normalize
import glob
import traceback


from restormerRFR_arch import RestormerRFR
from dino_feature_extractor import DinoFeatureModule

# Hugging Face Hub model repo and file path of the RestormerRFR checkpoint
# fetched by warmup() and get_model_and_device().
WEIGHT_REPO_ID: str = "233zzl/RAM_plus_plus"
WEIGHT_FILENAME: str = "7task/RestormerRFR.pth"
# Architecture name; matches the RestormerRFR class instantiated in build_model().
MODEL_NAME: str = "RestormerRFR"

def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def warmup():
    """Pre-fetch the remote files the app needs into the local Hugging Face cache.

    Downloads the RestormerRFR checkpoint and a snapshot of the
    ``facebook/dinov2-giant`` repository (presumably the backbone that
    DinoFeatureModule loads — confirm), so the first inference call does not
    pay the download cost. Returns nothing; only the cache side effect matters.
    """
    # Both downloads target the "model" repo type at the "main" revision.
    hub_options = {"repo_type": "model", "revision": "main"}
    hf_hub_download(repo_id=WEIGHT_REPO_ID, filename=WEIGHT_FILENAME, **hub_options)
    snapshot_download(repo_id="facebook/dinov2-giant", **hub_options)


def build_model():
    """Construct a RestormerRFR network with this demo's fixed hyper-parameters.

    The returned module carries freshly initialised weights; the caller is
    responsible for loading a checkpoint into it.
    """
    architecture = dict(
        inp_channels=3,
        out_channels=3,
        dim=48,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        finetune_type=None,
        img_size=128,
    )
    return RestormerRFR(**architecture)

@lru_cache(maxsize=1)
def get_dino_extractor(device):
    """Return a DinoFeatureModule moved to *device* and switched to eval mode.

    Memoised with a single-entry cache, so repeated calls with the same device
    reuse one instance instead of rebuilding the extractor.
    """
    return DinoFeatureModule().to(device).eval()

def _extract_state_dict(ckpt):
    """Return the parameter mapping stored inside a loaded checkpoint object.

    Checkpoints commonly wrap the weights under "params" (checked first, as
    before), "params_ema" or "state_dict"; a bare state dict is used as-is.
    A leading "module." prefix (left by DataParallel/DDP wrappers) is stripped
    so the names line up with an unwrapped model.
    """
    state_dict = ckpt
    if isinstance(ckpt, dict):
        for key in ("params", "params_ema", "state_dict"):
            if key in ckpt:
                state_dict = ckpt[key]
                break
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }


@lru_cache(maxsize=1)
def get_model_and_device():
    """Build RestormerRFR, load the Hub checkpoint into it, and return ``(model, device)``.

    The result is memoised, so the checkpoint is downloaded and loaded at most
    once per process.

    Raises:
        RuntimeError: if the checkpoint shares no parameter names with the
            model. With ``strict=False`` this case would otherwise load
            nothing and silently run on random weights.
    """
    device = get_device()
    model = build_model()

    weight_path = hf_hub_download(repo_id=WEIGHT_REPO_ID, filename=WEIGHT_FILENAME)
    # NOTE(review): torch.load unpickles the file. That is acceptable only
    # because the checkpoint comes from a fixed, trusted repo; prefer
    # weights_only=True where the installed torch version supports it.
    ckpt = torch.load(weight_path, map_location="cpu")
    state_dict = _extract_state_dict(ckpt)

    model_keys = set(model.state_dict())
    matched = model_keys & set(state_dict)
    if not matched:
        raise RuntimeError(
            f"Checkpoint '{WEIGHT_FILENAME}' shares no parameter names with "
            f"{MODEL_NAME}; refusing to run with randomly initialised weights."
        )
    # strict=False is kept so extra/auxiliary checkpoint entries are tolerated,
    # but partial coverage is reported instead of being silently ignored.
    model.load_state_dict(state_dict, strict=False)
    missing = len(model_keys - matched)
    if missing:
        print(f"[{MODEL_NAME}] warning: {missing} model tensors were not found in the checkpoint")

    model.eval().to(device)
    return model, device


# ImageNet statistics in RGB channel order, used to normalise the model input
# and to undo that normalisation on the model output.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def _pil_to_input_tensor(pil_img, device):
    """Convert a PIL image into a normalised (1, 3, H, W) float32 RGB tensor on *device*."""
    # convert("RGB") also accepts grayscale ("L"), palette ("P") and RGBA
    # uploads, which would otherwise produce 2-D or 4-channel arrays.
    rgb = np.asarray(pil_img.convert("RGB"), dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    tensor = torch.from_numpy(np.ascontiguousarray(rgb.transpose(2, 0, 1)))  # (3, H, W)
    tensor = tensor.unsqueeze(0).to(device)  # (1, 3, H, W)
    return normalize(tensor, _IMAGENET_MEAN, _IMAGENET_STD)


def _output_tensor_to_pil(output):
    """Undo the ImageNet normalisation of a (1, 3, H, W) RGB output and return an RGB PIL image."""
    # normalize(x, -mean/std, 1/std) computes x * std + mean.
    output = normalize(output, -_IMAGENET_MEAN / _IMAGENET_STD, 1.0 / _IMAGENET_STD)
    # squeeze(0) drops only the batch axis, so 1-pixel-high/wide images survive.
    arr = output.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()  # (3, H, W), RGB
    # Channels stay in RGB order: the model consumed RGB input. The previous
    # [2, 1, 0] re-indexing produced BGR, which PIL then displayed with red
    # and blue swapped.
    arr = np.transpose(arr, (1, 2, 0))  # (H, W, 3), RGB
    return Image.fromarray((arr * 255.0).round().astype(np.uint8))


@spaces.GPU(duration=240)
def restore_image(pil_img: Image.Image) -> Image.Image:
    """Restore one image with RestormerRFR conditioned on DINO features (RAM++ inference).

    Args:
        pil_img: Uploaded image; any PIL mode is converted to RGB.

    Returns:
        The restored image as an RGB PIL image.

    Raises:
        gr.Error: if no image was provided, or wrapping any failure during
            inference (message includes the traceback).
    """
    if pil_img is None:
        raise gr.Error("Please upload an image first.")
    try:
        model, device = get_model_and_device()
        dino_extractor = get_dino_extractor(device)

        img = _pil_to_input_tensor(pil_img, device)
        with torch.no_grad():
            dino_features = dino_extractor(img)
            output = model(img, dino_features)

        return _output_tensor_to_pil(output)
    except Exception as e:
        raise gr.Error(f"{e}\n{traceback.format_exc()}") from e

DESCRIPTION = """
# RAM++: Robust Representation Learning via  Adaptive Mask for All-in-One Image Restoration
"""

with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="load picture(JPEG/PNG)")
            btn = gr.Button("Run (ZeroGPU)")
        with gr.Column():
            out = gr.Image(type="pil", label="output")

    # Optional example gallery: every matching image under ./examples, sorted
    # by path. The Examples widget is omitted when the directory has none.
    ex_files = sorted(
        path
        for pattern in ("*.png", "*.jpg", "*.jpeg", "*.bmp")
        for path in glob.glob(os.path.join("examples", pattern))
    )
    if ex_files:
        gr.Examples(examples=ex_files, inputs=inp, label="Examples")

    btn.click(restore_image, inputs=inp, outputs=out, api_name="run")

    gr.Markdown("""
**Tips**
- If the queue is long or you hit the quota, please try again later, or upgrade to Pro for a higher ZeroGPU quota and priority.
""")

    # Runs warmup() on every page load to pre-fetch the checkpoints; repeat
    # calls are expected to hit the local Hub cache.
    demo.load(fn=warmup, inputs=None, outputs=None)


if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch()