Spaces:
Sleeping
Sleeping
| import spaces | |
| import os | |
| import gradio as gr | |
| from gradio_imageslider import ImageSlider | |
| import torch | |
| from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype | |
| import numpy as np | |
| from SUPIR.util import create_SUPIR_model, load_QF_ckpt | |
| from PIL import Image | |
| from llava.llava_agent import LLavaAgent | |
| from CKPT_PTH import LLAVA_MODEL_PATH | |
| import einops | |
| import copy | |
| import time | |
# Initialize devices
# SUPIR always runs on the first CUDA device. LLaVA gets its own GPU when a
# second one exists and otherwise shares cuda:0 with SUPIR.
if torch.cuda.device_count() < 1:
    raise ValueError('Currently support CUDA only.')
SUPIR_device = 'cuda:0'
LLaVA_device = 'cuda:1' if torch.cuda.device_count() >= 2 else 'cuda:0'
# NOTE(review): `spaces` is imported at the top of the file but no function is
# decorated with `@spaces.GPU`; if this Space runs on ZeroGPU hardware that
# decorator is likely required — confirm.

# Load SUPIR model
model, default_setting = create_SUPIR_model(
    'options/SUPIR_v0.yaml',
    SUPIR_sign='Q',
    load_default_setting=True,
)
model = model.to(SUPIR_device)
# Attach an independent deep copy of the first-stage denoise encoder under the
# name `denoise_encoder_s1` (presumably consumed by the stage-1 path — TODO
# confirm against SUPIR's model code).
model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(
    model.first_stage_model.denoise_encoder
)
# Name of the checkpoint variant currently loaded; stage2_process compares
# against it to decide whether weights need to be swapped.
model.current_model = 'v0-Q'
# Keep both checkpoint variants around so the UI can switch between them.
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')

# Load LLaVA (8-bit) for image captioning.
llava_agent = LLavaAgent(
    LLAVA_MODEL_PATH,
    device=LLaVA_device,
    load_8bit=True,
    load_4bit=False,
)
def stage1_process(input_image, gamma_correction):
    """Run the SUPIR stage-1 pass on an image and gamma-correct the result.

    Args:
        input_image: Numpy image from the "Input" ``gr.Image`` component, or
            ``None`` when nothing has been uploaded.
        gamma_correction: Exponent applied to the output after normalizing it
            to [0, 1]; 1.0 leaves the output unchanged.

    Returns:
        A ``uint8`` HxWxC numpy array shown as the "Stage1 Output".

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        # Show a readable message in the UI instead of crashing inside HWC3.
        raise gr.Error('Please upload an input image first.')
    torch.cuda.set_device(SUPIR_device)
    LQ = HWC3(input_image)
    # Resize with SUPIR's fix_resize using a 512 target (exact rule lives in
    # SUPIR.util).
    LQ = fix_resize(LQ, 512)

    # stage1: map [0, 255] -> [-1, 1], HWC -> 1xCxHxW, keep the first three
    # channels, then run the model's stage-1 denoise.
    LQ = np.array(LQ) / 255 * 2 - 1
    LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    LQ = LQ.to(SUPIR_device)[:, :3, :, :]
    LQ = model.batchify_denoise(LQ, is_stage1=True)
    # Back from [-1, 1] to uint8 HxWxC for display.
    LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)

    # gamma correction, applied in normalized [0, 1] space.
    LQ = np.power(LQ / 255.0, gamma_correction) * 255.0
    return LQ.round().clip(0, 255).astype(np.uint8)
def llave_process(input_image, temperature, top_p, qs=None):
    """Generate a caption for an image with the LLaVA agent.

    Args:
        input_image: Numpy image to caption, or ``None`` if the source
            component is still empty. In the UI this is wired to the Stage1
            output, so it is ``None`` until Stage1 has been run.
        temperature: Sampling temperature forwarded to LLaVA.
        top_p: Nucleus-sampling threshold forwarded to LLaVA.
        qs: Optional question/instruction forwarded to LLaVA as ``qs``.

    Returns:
        The first (and only) caption string produced for the image.

    Raises:
        gr.Error: if there is no image to caption.
    """
    if input_image is None:
        # Avoid an opaque crash inside HWC3 and tell the user what to do.
        raise gr.Error('No image available for LLaVA; run Stage1 first.')
    torch.cuda.set_device(LLaVA_device)
    LQ = HWC3(input_image)
    LQ = Image.fromarray(LQ.astype('uint8'))
    # gen_image_caption takes a batch of images; we send one and take its caption.
    captions = llava_agent.gen_image_caption([LQ], temperature=temperature, top_p=top_p, qs=qs)
    return captions[0]
def _switch_supir_checkpoint(model_select):
    """Make ``model_select`` ('v0-Q' or 'v0-F') the active SUPIR checkpoint.

    Does nothing when that variant is already loaded or when the name is not
    one of the two known variants.
    """
    if model_select == model.current_model:
        return
    if model_select == 'v0-Q':
        model.load_state_dict(ckpt_Q, strict=False)
        model.current_model = 'v0-Q'
    elif model_select == 'v0-F':
        model.load_state_dict(ckpt_F, strict=False)
        model.current_model = 'v0-F'


def _to_supir_tensor(image, gamma_correction):
    """Gamma-correct a uint8 HxWxC image and return a [-1, 1] 1xCxHxW tensor.

    The gamma-corrected image is re-quantized to uint8 before normalization.
    Only the first three channels are kept, and the tensor is placed on
    SUPIR_device.
    """
    corrected = np.power(np.array(image) / 255.0, gamma_correction) * 255.0
    corrected = corrected.round().clip(0, 255).astype(np.uint8)
    normalized = corrected / 255 * 2 - 1
    tensor = torch.tensor(normalized, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    return tensor.to(SUPIR_device)[:, :3, :, :]


def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
                   s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
                   linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
    """Run SUPIR stage-2 sampling on the input image.

    Args:
        input_image: Numpy image from the "Input" component, or ``None``.
        prompt: Caption passed to the sampler (the UI fills it via LLaVA).
        a_prompt, n_prompt: Positive / negative prompts (``p_p`` / ``n_p``).
        num_samples: Number of samples to generate.
        upscale: Factor passed to ``upscale_image`` before sampling.
        edm_steps: Number of sampler steps (``num_steps``).
        s_stage1, s_stage2: Forwarded as ``restoration_scale`` and
            ``control_scale``.
        s_cfg, s_churn, s_noise, seed, color_fix_type: Forwarded to
            ``model.batchify_sample`` unchanged.
        diff_dtype, ae_dtype: Dtype names converted with ``convert_dtype`` for
            the diffusion model and the auto-encoder.
        gamma_correction: Exponent applied to the normalized input image.
        linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2:
            Linear-schedule switches and their start values.
        model_select: 'v0-Q' or 'v0-F'; the matching checkpoint is loaded on
            demand.

    Returns:
        ``(gallery, event_id, 3, '')``. ``gallery`` is the upscaled input
        followed by the generated samples. The last three values fill hidden
        outputs in the UI.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        # Fail fast with a UI message, before any (expensive) checkpoint swap.
        raise gr.Error('Please upload an input image first.')
    torch.cuda.set_device(SUPIR_device)
    event_id = str(time.time_ns())
    # Gradio sliders can deliver integral values as floats; these two are used
    # as counts (range() / step count), so coerce them to int.
    num_samples = int(num_samples)
    edm_steps = int(edm_steps)
    _switch_supir_checkpoint(model_select)

    input_image = HWC3(input_image)
    input_image = upscale_image(input_image, upscale, unit_resolution=32, min_size=1024)
    LQ = _to_supir_tensor(input_image, gamma_correction)

    captions = [prompt]
    model.ae_dtype = convert_dtype(ae_dtype)
    model.model.dtype = convert_dtype(diff_dtype)
    samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
                                    s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
                                    num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
                                    use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
                                    cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)

    # NCHW in [-1, 1] -> NHWC uint8 for the gallery.
    x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
        0, 255).astype(np.uint8)
    results = [x_samples[i] for i in range(num_samples)]
    return [input_image] + results, event_id, 3, ''
def load_and_reset(param_setting):
    """Return the default Stage2 control values for a named preset.

    Args:
        param_setting: Preset name, "Quality" or "Fidelity". Only the text
            guidance scale and the linear-CFG start value depend on it; both
            are read from ``default_setting``.

    Returns:
        A 13-tuple ordered exactly like the Reset button's ``outputs`` list:
        (edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt,
        n_prompt, color_fix_type, linear_CFG, linear_s_stage2,
        spt_linear_CFG, spt_linear_s_stage2).

    Raises:
        NotImplementedError: for any other preset name.
    """
    steps = default_setting.edm_steps

    # (preset name, guidance-scale attribute, linear-CFG start attribute)
    presets = (
        ("Quality", "s_cfg_Quality", "spt_linear_CFG_Quality"),
        ("Fidelity", "s_cfg_Fidelity", "spt_linear_CFG_Fidelity"),
    )
    for preset_name, cfg_attr, cfg_start_attr in presets:
        if param_setting == preset_name:
            break
    else:
        raise NotImplementedError

    guidance = getattr(default_setting, cfg_attr)
    guidance_start = getattr(default_setting, cfg_start_attr)
    return (
        steps,
        guidance,
        1.0,    # s_stage2
        -1.0,   # s_stage1
        5,      # s_churn
        1.003,  # s_noise
        'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.',
        'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth',
        'Wavelet',  # color_fix_type
        True,       # linear_CFG
        False,      # linear_s_stage2
        guidance_start,
        0.0,        # spt_linear_s_stage2
    )
# Create Gradio interface
# Components are placed in the order they are created inside the nested
# `with` layout contexts below, so statement order defines the page layout.
block = gr.Blocks(title='SUPIR').queue()
with block:
    with gr.Row():
        gr.Markdown("# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**")
    with gr.Row():
        # Left column: input / Stage1 preview, prompt, and option panels.
        with gr.Column():
            with gr.Row(equal_height=True):
                with gr.Column():
                    gr.Markdown("<center>Input</center>")
                    input_image = gr.Image(type="numpy", elem_id="image-input", height=400, width=400)
                with gr.Column():
                    gr.Markdown("<center>Stage1 Output</center>")
                    # Written by stage1_process and read by llave_process.
                    denoise_image = gr.Image(type="numpy", elem_id="image-s1", height=400, width=400)
            # Filled by the LLaVA button; passed to stage2_process as `prompt`.
            prompt = gr.Textbox(label="Prompt", value="")
            with gr.Accordion("Stage1 options", open=False):
                gamma_correction = gr.Slider(label="Gamma Correction", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
            with gr.Accordion("LLaVA options", open=False):
                temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
                top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
                # Sent to LLaVA as `qs` when captioning.
                qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.")
            with gr.Accordion("Stage2 options", open=False):
                # The initial values below mirror what load_and_reset returns
                # for "Quality"; keep the two in sync.
                num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4, value=1, step=1)
                upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
                edm_steps = gr.Slider(label="Steps", minimum=1, maximum=200, value=default_setting.edm_steps, step=1)
                s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=default_setting.s_cfg_Quality, step=0.1)
                s_stage2 = gr.Slider(label="Stage2 Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
                s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
                s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
                a_prompt = gr.Textbox(label="Default Positive Prompt", value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.')
                n_prompt = gr.Textbox(label="Default Negative Prompt", value='painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth')
                with gr.Row():
                    with gr.Column():
                        linear_CFG = gr.Checkbox(label="Linear CFG", value=True)
                        spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0, maximum=9.0, value=default_setting.spt_linear_CFG_Quality, step=0.5)
                    with gr.Column():
                        linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
                        spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0., maximum=1., value=0., step=0.05)
                with gr.Row():
                    with gr.Column():
                        # Dtype names are converted with convert_dtype in stage2_process.
                        diff_dtype = gr.Radio(['fp32', 'fp16', 'bf16'], label="Diffusion Data Type", value="fp16", interactive=True)
                    with gr.Column():
                        ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16", interactive=True)
                    with gr.Column():
                        color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet", interactive=True)
                    with gr.Column():
                        # Must match the names stage2_process checks ('v0-Q' / 'v0-F').
                        model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q", interactive=True)
        # Right column: Stage2 results and the action buttons.
        with gr.Column():
            gr.Markdown("<center>Stage2 Output</center>")
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
            with gr.Row():
                with gr.Column():
                    denoise_button = gr.Button(value="Stage1 Run")
                with gr.Column():
                    llave_button = gr.Button(value="LlaVa Run")
                with gr.Column():
                    diffusion_button = gr.Button(value="Stage2 Run")
            with gr.Row():
                with gr.Column():
                    # Preset name consumed by load_and_reset.
                    param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting", value="Quality")
                with gr.Column():
                    restart_button = gr.Button(value="Reset Param", scale=2)
    # Connect the buttons to their functions
    # LLaVA captions the Stage1 output (not the raw input) and writes `prompt`.
    llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
    denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction], outputs=[denoise_image])
    # Passed positionally: order must match stage2_process's parameters.
    stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
                  s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
                  linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
    # stage2_process returns four values; the last three are written into
    # invisible components created inline here (visible=False).
    diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, gr.Textbox(visible=False), gr.Slider(visible=False), gr.Textbox(visible=False)])
    # Output order must match the tuple returned by load_and_reset.
    restart_button.click(fn=load_and_reset, inputs=[param_setting],
                         outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
                                  color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
# Launch the app
if __name__ == "__main__":
    block.launch()