"""Gradio demo that turns a single 2D image into a 3D mesh with OpenAI Shap-E.

The ``image300M`` model samples a latent conditioned on the uploaded image, and
the ``transmitter`` model decodes that latent into a triangle mesh, which is
written out as a Wavefront ``.obj`` file for download.
"""

import tempfile

import gradio as gr
import torch
from PIL import Image
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.diffusion.sample import sample_latents
from shap_e.models.download import load_config, load_model
from shap_e.util.notebooks import decode_latent_mesh

# pick device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Shap-E models once at import time so every request reuses them.
print("Loading Shap-E models (this may take a bit)...")
transmitter_model = load_model("transmitter", device=device)
image_model = load_model("image300M", device=device)
# sample_latents runs its sampler through this diffusion process object;
# it cannot be None.
diffusion = diffusion_from_config(load_config("diffusion"))


def generate_3d(image: Image.Image) -> str:
    """Generate a 3D mesh from an uploaded image.

    Args:
        image: The image supplied by the Gradio ``Image`` component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        Path to a newly written ``.obj`` file containing the generated mesh.

    Raises:
        gr.Error: If no image was provided (Gradio shows the message in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    img = image.convert("RGB")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Sampler settings follow the official Shap-E image-to-3D example:
        # the Karras sampler with 64 steps rather than the full ancestral loop.
        latents = sample_latents(
            batch_size=1,
            model=image_model,
            diffusion=diffusion,
            guidance_scale=3.0,
            model_kwargs=dict(images=[img]),
            clip_denoised=True,
            use_fp16=device.type == "cuda",  # half precision only pays off on GPU
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
            device=device,
        )

        # decode_latent_mesh returns a TorchMesh; convert it to a TriMesh,
        # which is the type that provides write_obj.
        mesh = decode_latent_mesh(transmitter_model, latents[0]).tri_mesh()

    # Write to a unique file per request so concurrent users do not clobber
    # each other's results (a shared fixed filename would race).
    with tempfile.NamedTemporaryFile(mode="w", suffix=".obj", delete=False) as f:
        mesh.write_obj(f)
        output_path = f.name
    return output_path


# Gradio interface
demo = gr.Interface(
    fn=generate_3d,
    inputs=gr.Image(type="pil"),
    outputs=gr.File(file_types=[".obj"]),
    title="Shap-E: 2D → 3D Model",
    description="Upload a 2D image and download a generated 3D model (.obj)",
)

if __name__ == "__main__":
    # NOTE(review): server_name="0.0.0.0" listens on all interfaces and
    # share=True asks Gradio for a public share link — confirm this exposure
    # is intended.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)