Spaces:
Configuration error
Configuration error
| import gradio as gr | |
| import os | |
| from PIL import Image | |
| import subprocess | |
| from gradio_model4dgs import Model4DGS | |
| import numpy | |
| import hashlib | |
| import shlex | |
| subprocess.run(shlex.split("pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl")) | |
| # subprocess.run(shlex.split("pip install xformers==0.0.23 --no-deps --index-url https://download.pytorch.org/whl/cu118")) | |
| import rembg | |
| import glob | |
| import cv2 | |
| import numpy as np | |
| from diffusers import StableVideoDiffusionPipeline | |
| from scripts.gen_vid import * | |
| import sys | |
| sys.path.append('lgm') | |
| from safetensors.torch import load_file | |
| from kiui.cam import orbit_camera | |
| from core.options import config_defaults, Options | |
| from core.models import LGM | |
| from mvdream.pipeline_mvdream import MVDreamPipeline | |
| from infer_demo import process as process_lgm | |
| from main_4d_demo import process as process_dg4d | |
| import spaces | |
| from huggingface_hub import hf_hub_download | |
# Download (or reuse the cached copy of) the pretrained LGM checkpoint.
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors")

# Front-end hook handed to gr.Blocks(js=...): reloads the page with
# `__theme=light` in the query string unless it is already set.
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""

# NOTE(review): `torch` is not imported explicitly in this file; it presumably
# arrives through `from scripts.gen_vid import *` -- confirm.
device = torch.device('cuda')
# device = torch.device('cpu')

# Background-removal session used by preprocess().
session = rembg.new_session(model_name='u2net')

# Image-to-video pipeline used by gen_vid().
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.to(device)

# LGM model, configured with the 'big' preset and the downloaded checkpoint.
opt = config_defaults['big']
opt.resume = ckpt_path
model = LGM(opt)

if opt.resume is None:
    print(f'[WARN] model randomly initialized, are you sure?')
else:
    # Weights are always loaded on CPU first; the model is moved to `device` below.
    ckpt = (
        load_file(opt.resume, device='cpu')
        if opt.resume.endswith('safetensors')
        else torch.load(opt.resume, map_location='cpu')
    )
    model.load_state_dict(ckpt, strict=False)
    print(f'[INFO] Loaded checkpoint from {opt.resume}')

# Half precision on the target device, inference mode.
model = model.half().to(device)
model.eval()
rays_embeddings = model.prepare_default_rays(device)

# ImageDream multi-view diffusion pipeline (remote weights).
pipe_mvdream = MVDreamPipeline.from_pretrained(
    "ashawkey/imagedream-ipmv-diffusers",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # local_files_only=True,
)
pipe_mvdream = pipe_mvdream.to(device)

# Zero123 guidance passed to the 4D stage.
from guidance.zero123_utils import Zero123
guidance_zero123 = Zero123(device, model_key='ashawkey/stable-zero123-diffusers')
def preprocess(path, recenter=True, size=256, border_ratio=0.2):
    """Remove the background of an image and optionally recenter the object.

    The result is written next to the input as ``<base>_rgba.png``, where
    ``<base>`` is the input file name up to its first dot.

    Args:
        path: Path of the image file to process.
        recenter: When true, crop the foreground to its bounding box, scale
            it so its longer side is ``size * (1 - border_ratio)`` pixels and
            paste it centered on a transparent ``size`` x ``size`` canvas.
            When false, the carved image is written unchanged.
        size: Side length in pixels of the square output canvas.
        border_ratio: Fraction of ``size`` left as empty border.

    Raises:
        ValueError: If the image cannot be read, or if background removal
            leaves no foreground pixel to recenter on.
    """
    out_dir = os.path.dirname(path)
    out_base = os.path.basename(path).split('.')[0]
    out_rgba = os.path.join(out_dir, out_base + '_rgba.png')

    # load image
    print(f'[INFO] loading image {path}...')
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f'could not read image file: {path}')

    # carve background
    print(f'[INFO] background removal...')
    carved_image = rembg.remove(image, session=session)  # [H, W, 4]
    mask = carved_image[..., -1] > 0

    if recenter:
        print(f'[INFO] recenter...')
        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            raise ValueError(f'background removal left no foreground in: {path}')

        # Exclusive upper bounds, so the last foreground row/column is kept
        # by the slice below and the box is never zero-sized.
        x_min, x_max = int(rows.min()), int(rows.max()) + 1
        y_min, y_max = int(cols.min()), int(cols.max()) + 1
        h = x_max - x_min
        w = y_max - y_min

        desired_size = int(size * (1 - border_ratio))
        scale = desired_size / max(h, w)
        # Clamp to one pixel: a very thin object must not produce a
        # zero-sized resize target.
        h2 = max(1, int(h * scale))
        w2 = max(1, int(w * scale))

        x2_min = (size - h2) // 2
        x2_max = x2_min + h2
        y2_min = (size - w2) // 2
        y2_max = y2_min + w2

        final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
        final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(
            carved_image[x_min:x_max, y_min:y_max],
            (w2, h2),
            interpolation=cv2.INTER_AREA,
        )
    else:
        final_rgba = carved_image

    # write image
    cv2.imwrite(out_rgba, final_rgba)
def gen_vid(input_path, seed, bg='white'):
    """Generate a video from a single image with the video diffusion pipeline.

    Writes ``<name>_generated.mp4`` and a ``<name>_frames/`` directory of
    numbered PNG frames into the directory of ``input_path``, where
    ``<name>`` is the input file name up to its first dot.

    Args:
        input_path: Path of the conditioning image.
        seed: Seed passed to ``torch.manual_seed`` for the generator.
        bg: Background argument forwarded to ``load_image``.
    """
    height, width = 512, 512
    input_dir = os.path.dirname(input_path)
    file_name = input_path.split('/')[-1]
    name = file_name.split('.')[0]

    image = load_image(input_path, width, height, bg)
    generator = torch.manual_seed(seed)
    result = pipe(image, height, width, generator=generator)
    frames = result.frames[0]

    # Full clip first, then every frame as its own PNG.
    imageio.mimwrite(f"{input_dir}/{name}_generated.mp4", frames, fps=7)
    os.makedirs(f"{input_dir}/{name}_frames", exist_ok=True)
    for idx, frame in enumerate(frames):
        frame.save(f"{input_dir}/{name}_frames/{idx:03}.png")
def check_img_input(control_image):
    """Validate that an input image has been uploaded or selected.

    Raises:
        gr.Error: If ``control_image`` is ``None``.
    """
    if control_image is not None:
        return
    raise gr.Error("Please select or upload an input image")
def check_video_3d_input(image_block: "Image.Image | None"):
    """Validate that the video and 3D stages have already run for this image.

    The image is identified by the SHA-256 of its raw bytes, the same key
    the generation stages use for their output file names.

    Raises:
        gr.Error: If no image is provided, if ``tmp_data/<hash>_rgba_generated.mp4``
            does not exist, or if ``vis_data/<hash>_rgba_static.mp4`` does not exist.
    """
    # Without this guard an empty image box fails on `.tobytes()` with an
    # AttributeError instead of a message the user can act on.
    if image_block is None:
        raise gr.Error("Please select or upload an input image")

    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
        raise gr.Error("Please generate a video first")
    if not os.path.exists(os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')):
        raise gr.Error("Please generate a 3D first")
def _ensure_rgba_input(image_block: Image.Image, preprocess_chk: bool) -> str:
    """Make sure ``tmp_data/<hash>_rgba.png`` exists and return ``<hash>``.

    ``<hash>`` is the SHA-256 hex digest of the image's raw bytes. An
    existing RGBA file for the same hash is reused as is, whatever
    ``preprocess_chk`` says.
    """
    # exist_ok avoids the exists()/makedirs() race between concurrent requests.
    os.makedirs('tmp_data', exist_ok=True)

    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')
    if not os.path.exists(rgba_path):
        if preprocess_chk:
            # Save the raw upload, then let preprocess() write the RGBA file.
            raw_path = os.path.join('tmp_data', f'{img_hash}.png')
            image_block.save(raw_path)
            preprocess(raw_path)
        else:
            # The upload is used directly as the RGBA input.
            image_block.save(rgba_path)
    return img_hash


def optimize_stage_0(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Generate the driving video for the input image.

    Args:
        image_block: Input image from the UI.
        preprocess_chk: Whether to remove the background and recenter first.
        seed_slider: Seed for video generation.

    Returns:
        Path of the generated video, ``tmp_data/<hash>_rgba_generated.mp4``.
    """
    img_hash = _ensure_rgba_input(image_block, preprocess_chk)
    gen_vid(f'tmp_data/{img_hash}_rgba.png', seed_slider)
    return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')


def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Run the LGM stage to build the static 3D result for the input image.

    Args:
        image_block: Input image from the UI.
        preprocess_chk: Whether to remove the background and recenter first.
        seed_slider: Seed forwarded to the LGM inference routine.

    Returns:
        The path ``vis_data/<hash>_rgba_static.mp4``, which the LGM routine
        is expected to have written.
    """
    img_hash = _ensure_rgba_input(image_block, preprocess_chk)
    process_lgm(opt, f'tmp_data/{img_hash}_rgba.png', pipe_mvdream, model, rays_embeddings, seed_slider)
    return os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
    """Run the 4D stage and list the per-frame ``.ply`` paths.

    Args:
        image_block: Input image from the UI; its SHA-256 selects the
            ``tmp_data/<hash>_rgba.png`` produced by the earlier stages.
        seed_slider: Accepted for the UI wiring; not used by this function.

    Returns:
        A list of 28 paths, ``logs/<hash>_rgba_frames/000.ply`` through
        ``027.ply``.
    """
    digest = hashlib.sha256(image_block.tobytes()).hexdigest()
    config_path = os.path.join("configs", "4d_demo.yaml")
    rgba_path = os.path.join("tmp_data", f"{digest}_rgba.png")

    # stage 2
    process_dg4d(config_path, rgba_path, guidance_zero123)

    image_dir = os.path.join('logs', f'{digest}_rgba_frames')
    ply_paths = []
    for t in range(28):
        ply_paths.append(image_dir + f'/{t:03d}.ply')
    return ply_paths
if __name__ == "__main__":
    # Build the Gradio UI: one input column (image, seeds, options, examples,
    # three stage buttons) and one output column (video, 3D preview, 4D viewer).
    _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting'''
    _DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://jiawei-ren.github.io/projects/dreamgaussian4d/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2312.17142"><img src="https://img.shields.io/badge/2312.17142-f9f7f7?logo="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/jiawei-ren/dreamgaussian4d'><img src='https://img.shields.io/github/stars/jiawei-ren/dreamgaussian4d?style=social'/></a>
</div>
We present DreamGaussian4D, an efficient 4D generation framework that builds on Gaussian Splatting.
'''
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), click **Generate Video** and **Generate 3D**. Finally, click **Generate 4D**."

    # load images in 'data' folder as examples
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    example_fns = sorted(os.listdir(example_folder))
    examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]

    # Compose demo layout & data flow
    # NOTE(review): nesting below is reconstructed from a source whose
    # indentation was lost -- confirm against the rendered layout.
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        # Image-to-3D
        with gr.Row(variant='panel'):
            with gr.Column(scale=4):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')

                seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (Video)')
                seed_slider2 = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (3D)')
                gr.Markdown(
                    "random seed for video generation.")

                preprocess_chk = gr.Checkbox(
                    True,
                    label='Preprocess image automatically (remove background and recenter object)',
                )

                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40,
                )

                img_run_btn = gr.Button("Generate Video")
                threed_run_btn = gr.Button("Generate 3D")
                fourd_run_btn = gr.Button("Generate 4D")
                img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

            with gr.Column(scale=5):
                driving_video = gr.Video(label="video", height=290)
                obj3d = gr.Video(label="3D Model", height=290)
                obj4d = Model4DGS(label="4D Model", height=500, fps=14)

        # Each button validates its input first (unqueued) and only runs the
        # heavy stage when validation succeeds.
        img_run_btn.click(
            check_img_input, inputs=[image_block], queue=False
        ).success(
            optimize_stage_0,
            inputs=[image_block, preprocess_chk, seed_slider],
            outputs=[driving_video],
        )
        threed_run_btn.click(
            check_img_input, inputs=[image_block], queue=False
        ).success(
            optimize_stage_1,
            inputs=[image_block, preprocess_chk, seed_slider2],
            outputs=[obj3d],
        )
        fourd_run_btn.click(
            check_video_3d_input, inputs=[image_block], queue=False
        ).success(
            optimize_stage_2,
            inputs=[image_block, seed_slider],
            outputs=[obj4d],
        )

    # demo.queue().launch(share=True)
    demo.queue(max_size=10)  # at most 10 requests waiting in the queue
    demo.launch()