from functools import partial
from typing import Callable

import clip
import gradio as gr
import numpy as np
import rerun as rr
import torch
import trimesh
from gradio_rerun import Rerun

from src.datasets.multimodal_dataset import MultimodalDataset
from src.diffuser import Diffuser
from utils.common_viz import get_batch, init
from utils.random_utils import set_random_seed
from utils.rerun import log_sample
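# Note: `utils.*` and `src.*` are local modules shipped alongside this demo
# (the DIRECTOR codebase linked in HEADER below), not pip-installable packages.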
# ------------------------------------------------------------------------------------- #

batch_size, num_cams, num_verts = None, None, None

SAMPLE_IDS = [
    "2011_KAeAqaA0Llg_00005_00001",
    "2011_F_EuMeT2wBo_00014_00001",
    "2011_MCkKihQrNA4_00014_00000",
]
LABEL_TO_IDS = {
    "right": 0,
    "static": 1,
    "complex": 2,
}
EXAMPLES = [
    "While the character moves right, the camera trucks right.",
    "While the character moves right, the camera performs a push in.",
    "While the character moves right, the camera performs a pull out.",
    "While the character stays static, the camera performs a boom bottom.",
    "While the character stays static, the camera performs a boom top.",
    "While the character moves to the right, the camera trucks right alongside them. Once the character comes to a stop, the camera remains static.",  # noqa
    "While the character moves to the right, the camera remains static. Once the character comes to a stop, the camera pushes in.",  # noqa
]
DEFAULT_TEXT = [
    "While the character moves right, the camera [...].",
    "While the character remains static, [...].",
    "While the character moves to the right, the camera [...]. "
    "Once the character comes to a stop, the camera [...].",
]
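# Each dropdown label in the UI ("right", "static", "complex") maps through
# LABEL_TO_IDS to a pre-extracted character trajectory in SAMPLE_IDS and to a
# matching prompt template in DEFAULT_TEXT.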
HEADER = """
<div align="center">
<h1 style='text-align: center'>E.T. the Exceptional Trajectories</h1>
<a href="https://robincourant.github.io/info/"><strong>Robin Courant</strong></a>
·
<a href="https://nicolas-dufour.github.io/"><strong>Nicolas Dufour</strong></a>
·
<a href="https://triocrossing.github.io/"><strong>Xi Wang</strong></a>
·
<a href="http://people.irisa.fr/Marc.Christie/"><strong>Marc Christie</strong></a>
·
<a href="https://vicky.kalogeiton.info/"><strong>Vicky Kalogeiton</strong></a>
</div>
<div align="center">
<a href="https://www.lix.polytechnique.fr/vista/projects/2024_et_courant/" class="button"><b>[Webpage]</b></a>
<a href="https://github.com/robincourant/DIRECTOR" class="button"><b>[DIRECTOR]</b></a>
<a href="https://github.com/robincourant/CLaTr" class="button"><b>[CLaTr]</b></a>
<a href="https://github.com/robincourant/the-exceptional-trajectories" class="button"><b>[Data]</b></a>
</div>
<br/>
"""
# ------------------------------------------------------------------------------------- #


def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    """Compute per-frame, per-vertex normals for an animated character mesh."""
    num_frames, num_faces = vertices.shape[0], faces.shape[-2]
    # Broadcast the (shared) face indices to one copy per frame.
    faces = faces.expand(num_frames, num_faces, 3)
    normals = [
        trimesh.Trimesh(vertices=v, faces=f, process=False).vertex_normals
        for v, f in zip(vertices, faces)
    ]
    normals = torch.from_numpy(np.stack(normals))
    return normals
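# Minimal shape sketch for `get_normals` (hypothetical tensors, kept commented
# out so nothing runs at import time): a single-frame triangle yields one
# normal per vertex.
#   _v = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # (1, 3, 3)
#   _f = torch.tensor([[[0, 1, 2]]])  # (1, 1, 3)
#   assert get_normals(_v, _f).shape == (1, 3, 3)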
def generate(
    prompt: str,
    seed: int,
    guidance_weight: float,
    sample_label: str,
    # Fixed arguments, bound via `functools.partial` in the __main__ block --- #
    dataset: MultimodalDataset,
    device: torch.device,
    diffuser: Diffuser,
    clip_model: clip.model.CLIP,
) -> str:
    # Set arguments
    set_random_seed(seed)
    diffuser.gen_seeds = np.array([seed])
    diffuser.guidance_weight = guidance_weight

    # Inference
    sample_id = SAMPLE_IDS[LABEL_TO_IDS[sample_label]]
    seq_feat = diffuser.net.model.clip_sequential
    batch = get_batch(prompt, sample_id, clip_model, dataset, seq_feat, device)
    with torch.no_grad():
        out = diffuser.predict_step(batch, 0)

    # Run visualization: drop padded frames, then log the generated camera
    # trajectory and the character mesh to a Rerun recording.
    padding_mask = out["padding_mask"][0].to(bool).cpu()
    padded_traj = out["gen_samples"][0].cpu()
    traj = padded_traj[padding_mask]
    padded_vertices = out["char_raw"]["char_vertices"][0]
    vertices = padded_vertices[padding_mask]
    faces = out["char_raw"]["char_faces"][0]
    normals = get_normals(vertices, faces)
    fx, fy, cx, cy = out["intrinsics"][0].cpu().numpy()
    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
    caption = out["caption_raw"][0]

    rr.init(sample_id)
    rr.save(".tmp_gr.rrd")
    log_sample(
        root_name="world",
        traj=traj.numpy(),
        K=K,
        vertices=vertices.numpy(),
        faces=faces.numpy(),
        normals=normals.numpy(),
        caption=caption,
        mesh_masks=None,
    )
    return "./.tmp_gr.rrd"
# ------------------------------------------------------------------------------------- #


def main(gen_fn: Callable):
    theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")

    with gr.Blocks(theme=theme) as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Column(scale=2):
                    sample_str = gr.Dropdown(
                        choices=["static", "right", "complex"],
                        label="Character trajectory",
                        value="right",
                        interactive=True,
                    )
                    text = gr.Textbox(
                        placeholder="Type the camera motion you want to generate",
                        show_label=True,
                        label="Text prompt",
                        value=DEFAULT_TEXT[LABEL_TO_IDS[sample_str.value]],
                    )
                    seed = gr.Number(value=33, label="Seed")
                    guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
                with gr.Column(scale=1):
                    btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=2):
                examples = gr.Examples(
                    examples=[[x, None, None] for x in EXAMPLES],
                    inputs=[text],
                )

        with gr.Row():
            output = Rerun()

        def load_example(example_id):
            processed_example = examples.non_none_processed_examples[example_id]
            return gr.utils.resolve_singleton(processed_example)

        def change_fn(change):
            sample_index = LABEL_TO_IDS[change]
            return gr.update(value=DEFAULT_TEXT[sample_index])

        sample_str.change(fn=change_fn, inputs=[sample_str], outputs=[text])

        inputs = [text, seed, guidance, sample_str]
        examples.dataset.click(
            load_example,
            inputs=[examples.dataset],
            outputs=examples.inputs_with_examples,
            show_progress=False,
            postprocess=False,
            queue=False,
        ).then(fn=gen_fn, inputs=inputs, outputs=[output])
        btn.click(fn=gen_fn, inputs=inputs, outputs=[output])
        text.submit(fn=gen_fn, inputs=inputs, outputs=[output])

    demo.queue().launch(share=False)
# ------------------------------------------------------------------------------------- #

if __name__ == "__main__":
    # Initialize the models and dataset
    diffuser, clip_model, dataset, device = init("config")
    generate_sample = partial(
        generate,
        dataset=dataset,
        device=device,
        diffuser=diffuser,
        clip_model=clip_model,
    )
    main(generate_sample)
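# Run locally with `python app.py` (filename assumed). `init("config")` is
# expected to load the pretrained diffuser, the CLIP text encoder, and the
# demo dataset from the project's config files; see the DIRECTOR repository
# linked in HEADER for the exact checkpoint and data setup.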