# PartCrafter inference script: generates a part-decomposed 3D mesh from a
# single input image and optionally renders turntable previews.
import argparse
import os
import sys
import time
from glob import glob
from typing import Any, Union

import numpy as np
import torch
import trimesh
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
from PIL import Image

from src.models.briarmbg import BriaRMBG
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
from src.utils.image_utils import prepare_image
from src.utils.render_utils import (
    export_renderings,
    make_grid_for_images_or_videos,
    render_normal_views_around_mesh,
    render_views_around_mesh,
)
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    num_parts: int,
    rmbg_net: Any,
    seed: int,
    num_tokens: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    # int(1e9) rather than 1e9: the annotation is int, and downstream code
    # presumably expects an integer coordinate budget.
    max_num_expanded_coords: int = int(1e9),
    use_flash_decoder: bool = False,
    rmbg: bool = False,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> "tuple[list[trimesh.Trimesh], Image.Image]":
    """Generate ``num_parts`` part meshes for one image with the PartCrafter pipeline.

    Args:
        pipe: A loaded ``PartCrafterPipeline`` (or compatible diffusers-style
            pipeline exposing ``device`` and returning an object with ``.meshes``).
        image_input: Path to an image file, or an already-loaded PIL image.
            NOTE(review): when ``rmbg`` is False a path is required, since the
            input is passed to ``Image.open`` — confirm callers never pass a
            PIL image with ``rmbg=False``.
        num_parts: Number of parts to generate; the conditioning image is
            replicated once per part.
        rmbg_net: Background-removal network, used only when ``rmbg`` is True.
        seed: Seed for the per-call ``torch.Generator``.
        num_tokens: Token budget forwarded to the pipeline.
        num_inference_steps: Diffusion denoising steps.
        guidance_scale: Classifier-free guidance scale.
        max_num_expanded_coords: Decoder coordinate-expansion cap.
        use_flash_decoder: Whether the pipeline should use its flash decoder.
        rmbg: If True, strip the background and composite onto white first.
        dtype: Unused here; kept for interface compatibility with callers.
        device: Unused here; kept for interface compatibility with callers.

    Returns:
        A ``(meshes, image)`` tuple: the list of generated ``trimesh.Trimesh``
        parts (failed decodes replaced by a degenerate placeholder mesh) and
        the preprocessed conditioning image.
    """
    if rmbg:
        # Remove the background and composite onto a white background before
        # conditioning the pipeline.
        img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    else:
        img_pil = Image.open(image_input)

    start_time = time.time()
    outputs = pipe(
        image=[img_pil] * num_parts,  # one copy of the conditioning image per part
        attention_kwargs={"num_parts": num_parts},
        num_tokens=num_tokens,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        max_num_expanded_coords=max_num_expanded_coords,
        use_flash_decoder=use_flash_decoder,
    ).meshes
    print(f"Time elapsed: {time.time() - start_time:.2f} seconds")

    # If a generated mesh is None (decoding error), substitute a single-point
    # dummy mesh so downstream export/composition can treat parts uniformly.
    for i, mesh in enumerate(outputs):
        if mesh is None:
            outputs[i] = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
    return outputs, img_pil
MAX_NUM_PARTS = 16  # upper bound enforced on --num_parts

if __name__ == "__main__":
    device = "cuda"
    dtype = torch.float16

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--num_parts", type=int, required=True, help="number of parts to generate")
    parser.add_argument("--output_dir", type=str, default="./results")
    parser.add_argument("--tag", type=str, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_tokens", type=int, default=1024)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    # argparse applies type= only to command-line strings, so the default must
    # already be an int — a bare 1e9 would silently leave a float default.
    parser.add_argument("--max_num_expanded_coords", type=int, default=int(1e9))
    parser.add_argument("--use_flash_decoder", action="store_true")
    parser.add_argument("--rmbg", action="store_true")
    parser.add_argument("--render", action="store_true")
    args = parser.parse_args()

    # Validate via parser.error rather than assert: asserts are stripped
    # under `python -O`, and this is user input, not an internal invariant.
    if not 1 <= args.num_parts <= MAX_NUM_PARTS:
        parser.error(f"num_parts must be in [1, {MAX_NUM_PARTS}]")

    # Download pretrained weights (no-op if already cached locally).
    partcrafter_weights_dir = "pretrained_weights/PartCrafter"
    rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
    snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
    snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)

    # Init RMBG model for background removal.
    rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
    rmbg_net.eval()

    # Init the PartCrafter pipeline.
    pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(device, dtype)

    set_seed(args.seed)

    # Run inference.
    outputs, processed_image = run_triposg(
        pipe,
        image_input=args.image_path,
        num_parts=args.num_parts,
        rmbg_net=rmbg_net,
        seed=args.seed,
        num_tokens=args.num_tokens,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        max_num_expanded_coords=args.max_num_expanded_coords,
        use_flash_decoder=args.use_flash_decoder,
        rmbg=args.rmbg,
        dtype=dtype,
        device=device,
    )

    # Export each part plus a colored composition of all parts.
    if args.tag is None:
        args.tag = time.strftime("%Y%m%d_%H_%M_%S")
    export_dir = os.path.join(args.output_dir, args.tag)
    os.makedirs(export_dir, exist_ok=True)  # also creates output_dir if missing
    for i, mesh in enumerate(outputs):
        mesh.export(os.path.join(export_dir, f"part_{i:02}.glb"))
    merged_mesh = get_colored_mesh_composition(outputs)
    merged_mesh.export(os.path.join(export_dir, "object.glb"))
    print(f"Generated {len(outputs)} parts and saved to {export_dir}")

    if args.render:
        print("Start rendering...")
        num_views = 36  # frames in each turntable animation
        radius = 4     # camera orbit radius
        fps = 18

        rendered_images = render_views_around_mesh(
            merged_mesh,
            num_views=num_views,
            radius=radius,
        )
        rendered_normals = render_normal_views_around_mesh(
            merged_mesh,
            num_views=num_views,
            radius=radius,
        )
        # Grid rows: input image, RGB renders, normal renders.
        rendered_grids = make_grid_for_images_or_videos(
            [
                [processed_image] * num_views,
                rendered_images,
                rendered_normals,
            ],
            nrow=3,
        )

        export_renderings(
            rendered_images,
            os.path.join(export_dir, "rendering.gif"),
            fps=fps,
        )
        export_renderings(
            rendered_normals,
            os.path.join(export_dir, "rendering_normal.gif"),
            fps=fps,
        )
        export_renderings(
            rendered_grids,
            os.path.join(export_dir, "rendering_grid.gif"),
            fps=fps,
        )

        # Also save the first frame of each animation as a still image.
        rendered_image, rendered_normal, rendered_grid = rendered_images[0], rendered_normals[0], rendered_grids[0]
        rendered_image.save(os.path.join(export_dir, "rendering.png"))
        rendered_normal.save(os.path.join(export_dir, "rendering_normal.png"))
        rendered_grid.save(os.path.join(export_dir, "rendering_grid.png"))
        print("Rendering done.")