CJ Hauser commited on
Commit
3916b4c
·
verified ·
1 Parent(s): 8cfc694

Update app.py

Browse files

change app.py req

Files changed (1) hide show
  1. app.py +23 -15
app.py CHANGED
@@ -1,49 +1,57 @@
1
  import torch
2
  import gradio as gr
3
- from shap_e.diffusion.sample import sample_latents
4
  from shap_e.models.download import load_model
 
5
  from shap_e.util.notebooks import decode_latent_mesh
6
  from PIL import Image
7
 
8
  # pick device
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
- # load models once
12
  print("Loading Shap-E models (this may take a bit)...")
13
- xm = load_model('transmitter', device=device)
14
- im_model = load_model('image300M', device=device)
15
 
16
- def generate_3d(image):
17
- """Takes an uploaded image and returns path to generated 3D model"""
18
  img = image.convert("RGB")
19
 
20
- # generate latents
21
  latents = sample_latents(
22
  batch_size=1,
23
- model=im_model,
24
- guidance_scale=3.0,
25
  model_kwargs=dict(images=[img]),
 
 
 
 
 
 
 
 
 
26
  device=device
27
  )
28
 
29
- # decode into mesh
30
- mesh = decode_latent_mesh(xm, latents[0])
31
 
32
- # save model
33
  output_path = "output.obj"
34
  with open(output_path, "w") as f:
35
  mesh.write_obj(f)
36
 
37
  return output_path
38
 
39
- # Gradio UI
40
  demo = gr.Interface(
41
  fn=generate_3d,
42
  inputs=gr.Image(type="pil"),
43
- outputs=gr.File(),
44
  title="Shap-E: 2D → 3D Model",
45
  description="Upload a 2D image and download a generated 3D model (.obj)"
46
  )
47
 
48
  if __name__ == "__main__":
49
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import torch
2
  import gradio as gr
 
3
  from shap_e.models.download import load_model
4
+ from shap_e.diffusion.sample import sample_latents
5
  from shap_e.util.notebooks import decode_latent_mesh
6
  from PIL import Image
7
 
8
  # pick device
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
+ # Load Shap-E models
12
  print("Loading Shap-E models (this may take a bit)...")
13
+ transmitter_model = load_model("transmitter", device=device)
14
+ image_model = load_model("image300M", device=device)
15
 
16
+ def generate_3d(image: Image.Image):
17
+ """Takes an uploaded image and returns path to generated 3D model (.obj)"""
18
  img = image.convert("RGB")
19
 
20
+ # Sample latents (updated for latest Shap-E API)
21
  latents = sample_latents(
22
  batch_size=1,
23
+ model=image_model,
 
24
  model_kwargs=dict(images=[img]),
25
+ diffusion=None,
26
+ clip_denoised=True,
27
+ use_fp16=False,
28
+ use_karras=True,
29
+ karras_steps=64,
30
+ sigma_min=0.002,
31
+ sigma_max=80,
32
+ s_churn=0.0,
33
+ guidance_scale=3.0,
34
  device=device
35
  )
36
 
37
+ # Decode into mesh
38
+ mesh = decode_latent_mesh(transmitter_model, latents[0])
39
 
40
+ # Save output
41
  output_path = "output.obj"
42
  with open(output_path, "w") as f:
43
  mesh.write_obj(f)
44
 
45
  return output_path
46
 
47
+ # Gradio interface
48
  demo = gr.Interface(
49
  fn=generate_3d,
50
  inputs=gr.Image(type="pil"),
51
+ outputs=gr.File(file_types=[".obj"]),
52
  title="Shap-E: 2D → 3D Model",
53
  description="Upload a 2D image and download a generated 3D model (.obj)"
54
  )
55
 
56
  if __name__ == "__main__":
57
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)