Spaces:
Running
Running
| import spaces | |
| import argparse | |
| import gradio as gr | |
| import os | |
| import torch | |
| import trimesh | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import json | |
| from datetime import datetime | |
| pathdir = Path(__file__).parent / 'cube' | |
| sys.path.append(pathdir.as_posix()) | |
| # print(__file__) | |
| # print(os.listdir()) | |
| # print(os.listdir('cube')) | |
| # print(pathdir.as_posix()) | |
| from cube3d.inference.engine import EngineFast, Engine | |
| from cube3d.inference.utils import normalize_bbox | |
| from pathlib import Path | |
| import uuid | |
| import shutil | |
| from huggingface_hub import snapshot_download | |
| from cube3d.mesh_utils.postprocessing import ( | |
| PYMESHLAB_AVAILABLE, | |
| create_pymeshset, | |
| postprocess_mesh, | |
| save_mesh, | |
| ) | |
# Module-level mutable dict used as shared state by the functions in this file.
# Keys stored in it here include "config_path" and "SAVE_DIR" (assigned in
# generate) and "engine_fast" (the inference engine, cached the first time a
# prompt is handled).
| GLOBAL_STATE = {} | |
def gen_save_folder(max_size=200):
    """Create and return a fresh, uniquely named folder under the save dir.

    The save directory is read from ``GLOBAL_STATE["SAVE_DIR"]`` and created
    if missing. Before the new folder is added, the oldest sub-folders (by
    ``st_ctime``) are deleted so that at most ``max_size`` sub-folders exist
    once the new one has been created.

    Args:
        max_size: Maximum number of sub-folders to keep, including the one
            created by this call.

    Returns:
        Path of the newly created folder, as a string.
    """
    save_root = Path(GLOBAL_STATE["SAVE_DIR"])
    save_root.mkdir(parents=True, exist_ok=True)

    # Oldest first, so the front of the list is what gets evicted.
    dirs = sorted(
        (f for f in save_root.iterdir() if f.is_dir()),
        key=lambda x: x.stat().st_ctime,
    )

    # Evict however many folders are needed to leave room for the new one.
    # Removing just a single folder would never bring an over-full directory
    # back under the cap.
    excess = len(dirs) - max_size + 1
    for oldest_dir in dirs[:max(excess, 0)]:
        shutil.rmtree(oldest_dir)
        print(f"Removed the oldest folder: {oldest_dir}")

    new_folder = os.path.join(GLOBAL_STATE["SAVE_DIR"], str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")
    return new_folder
def _log_prompt(debug_info, data_dir="/data"):
    """Append one request record as a JSON line to ``<data_dir>/prompt_log.jsonl``.

    Logging is best-effort: if the directory cannot be created or written
    (for example when it is not writable), a message is printed and the
    request carries on instead of failing.
    """
    try:
        os.makedirs(data_dir, exist_ok=True)
        prompt_file = os.path.join(data_dir, "prompt_log.jsonl")
        with open(prompt_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(debug_info) + "\n")
    except OSError as err:
        print(f"Could not write prompt log to {data_dir}: {err}")


def _get_engine():
    """Return the EngineFast cached in GLOBAL_STATE, building it on first use.

    The config path comes from ``GLOBAL_STATE["config_path"]``. Checkpoint
    paths come from ``GLOBAL_STATE["gpt_ckpt_path"]`` and
    ``GLOBAL_STATE["shape_ckpt_path"]`` when those keys are present, and fall
    back to the files under ``./model_weights`` otherwise.
    """
    engine = GLOBAL_STATE.get("engine_fast")
    if engine is None:
        engine = EngineFast(
            GLOBAL_STATE["config_path"],
            GLOBAL_STATE.get("gpt_ckpt_path", "./model_weights/shape_gpt.safetensors"),
            GLOBAL_STATE.get("shape_ckpt_path", "./model_weights/shape_tokenizer.safetensors"),
            device=torch.device("cuda"),
        )
        GLOBAL_STATE["engine_fast"] = engine
    return engine


def handle_text_prompt(input_prompt, use_bbox=True, bbox_x=1.0, bbox_y=1.0, bbox_z=1.0, hi_res=False):
    """Generate a mesh from a text prompt and export it as a GLB file.

    Args:
        input_prompt: Text prompt passed to the engine.
        use_bbox: When true, ``[bbox_x, bbox_y, bbox_z]`` is normalized with
            ``normalize_bbox`` and passed to the engine; otherwise no bounding
            box is passed.
        bbox_x, bbox_y, bbox_z: Bounding-box extents, used only if ``use_bbox``.
        hi_res: Selects ``resolution_base`` 9.0 instead of 8.0.

    Returns:
        Path of the exported ``output.glb`` inside a freshly created save folder.

    Raises:
        gr.Error: If the prompt is missing or contains only whitespace.
    """
    # Reject empty prompts up front rather than handing them to the engine.
    if not input_prompt or not input_prompt.strip():
        raise gr.Error("Please enter a prompt.")

    _log_prompt(
        {
            "timestamp": datetime.now().isoformat(),
            "prompt": input_prompt,
            "use_bbox": use_bbox,
            "bbox_x": bbox_x,
            "bbox_y": bbox_y,
            "bbox_z": bbox_z,
            "hi_res": hi_res,
        }
    )
    print(f"prompt: {input_prompt}, use_bbox: {use_bbox}, bbox_x: {bbox_x}, bbox_y: {bbox_y}, bbox_z: {bbox_z}, hi_res: {hi_res}")

    engine = _get_engine()

    # With the bounding box disabled, None is passed through to the engine.
    normalized_bbox = normalize_bbox([bbox_x, bbox_y, bbox_z]) if use_bbox else None
    resolution_base = 9.0 if hi_res else 8.0
    mesh_v_f = engine.t2s(
        [input_prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        bounding_box_xyz=normalized_bbox,
    )

    # One prompt was submitted, so only the first result is used.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]

    # NOTE(review): PYMESHLAB_AVAILABLE is imported by this file but never
    # checked before these calls -- confirm the helpers cope when pymeshlab
    # is not installed.
    ms = create_pymeshset(vertices, faces)
    # Aim for 10% of the generated face count, but never fewer than 10000.
    target_face_num = max(10000, int(faces.shape[0] * 0.1))
    print(f"Postprocessing mesh to {target_face_num} faces")
    postprocess_mesh(ms, target_face_num)

    processed = ms.current_mesh()
    scene = trimesh.scene.Scene()
    scene.add_geometry(
        trimesh.Trimesh(
            vertices=processed.vertex_matrix(),
            faces=processed.face_matrix(),
        )
    )

    output_path = os.path.join(gen_save_folder(), "output.glb")
    scene.export(output_path)
    return output_path
# Builds the Gradio Blocks UI and returns it without launching it. The UI has a
# prompt textbox, a "Use Bounding Box" checkbox with three sliders
# (Length / Height / Depth), a "Hi-Res" checkbox, a Submit button and a
# Model3D output viewer.
# NOTE(review): indentation is not visible in this copy of the file, so the
# exact nesting of the layout containers below should be confirmed against
# the original source before restructuring anything here.
| def build_interface(): | |
| """Build UI for gradio app | |
| """ | |
| title = "Cube 3D" | |
| with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface: | |
| gr.Markdown( | |
| f""" | |
| # {title} | |
| **Disclaimer:** Content generated through the Hugging Face integration is not moderated by Roblox safety systems. Use of this integration is subject to the applicable license terms, and users are solely responsible for outputs generated. | |
| # Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine! | |
| """ | |
| ) | |
# Two columns: scale=2 holds the input controls, scale=3 holds the viewer.
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| with gr.Group(): | |
| input_text_box = gr.Textbox( | |
| value=None, | |
| label="Prompt", | |
| lines=2, | |
| ) | |
| use_bbox = gr.Checkbox(label="Use Bounding Box", value=False) | |
# bbox_group is bound here but not referenced again in this function.
# The sliders start with interactive=False, matching the checkbox's initial
# value of False.
| with gr.Group() as bbox_group: | |
| bbox_x = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Length", interactive=False) | |
| bbox_y = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Height", interactive=False) | |
| bbox_z = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Depth", interactive=False) | |
| # Enable/disable bbox sliders based on use_bbox checkbox | |
# Returns one slider update per bbox slider, each setting `interactive` to the
# checkbox value; the order matches the `outputs` list of use_bbox.change below.
| def toggle_bbox_interactivity(use_bbox): | |
| return ( | |
| gr.Slider(interactive=use_bbox), | |
| gr.Slider(interactive=use_bbox), | |
| gr.Slider(interactive=use_bbox) | |
| ) | |
| use_bbox.change( | |
| toggle_bbox_interactivity, | |
| inputs=[use_bbox], | |
| outputs=[bbox_x, bbox_y, bbox_z] | |
| ) | |
| hi_res = gr.Checkbox(label="Hi-Res", value=False) | |
| with gr.Row(): | |
| submit_button = gr.Button("Submit", variant="primary") | |
| with gr.Column(scale=3): | |
| model3d = gr.Model3D( | |
| label="Output", height="45em", interactive=False | |
| ) | |
# The inputs are listed in the same order as handle_text_prompt's parameters
# (input_prompt, use_bbox, bbox_x, bbox_y, bbox_z, hi_res); the path it returns
# is sent to the Model3D viewer.
| submit_button.click( | |
| handle_text_prompt, | |
| inputs=[ | |
| input_text_box, | |
| use_bbox, | |
| bbox_x, | |
| bbox_y, | |
| bbox_z, | |
| hi_res | |
| ], | |
| outputs=[ | |
| model3d | |
| ] | |
| ) | |
| return interface | |
def generate(args):
    """Record run settings in GLOBAL_STATE, then build and launch the Gradio app.

    Args:
        args: Parsed command-line namespace. ``config_path`` and ``save_dir``
            are required. ``gpt_ckpt_path`` and ``shape_ckpt_path`` are also
            recorded when the namespace provides them.
    """
    GLOBAL_STATE["config_path"] = args.config_path
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir

    # Keep the checkpoint paths next to the other settings so the parsed
    # values are not discarded. getattr keeps this tolerant of namespaces
    # that do not define these options.
    for key in ("gpt_ckpt_path", "shape_ckpt_path"):
        value = getattr(args, key, None)
        if value is not None:
            GLOBAL_STATE[key] = value

    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)

    demo = build_interface()
    demo.queue(default_concurrency_limit=1)
    demo.launch()
if __name__ == "__main__":
    # (flag, default, help text) for every command-line option; all of them
    # are strings. --save_dir has no help text.
    cli_options = (
        (
            "--config_path",
            "cube/cube3d/configs/open_model_v0.5.yaml",
            "Path to the config file",
        ),
        (
            "--gpt_ckpt_path",
            "model_weights/shape_gpt.safetensors",
            "Path to the gpt ckpt path",
        ),
        (
            "--shape_ckpt_path",
            "model_weights/shape_tokenizer.safetensors",
            "Path to the shape ckpt path",
        ),
        (
            "--save_dir",
            "gradio_save_dir",
            None,
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, type=str, help=help_text, default=default_value)
    args = parser.parse_args()

    # Download the model repository into ./model_weights before starting the app.
    snapshot_download(repo_id="Roblox/cube3d-v0.5", local_dir="./model_weights")

    generate(args)