Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -88,18 +88,22 @@ def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
|
|
| 88 |
|
| 89 |
return normals
|
| 90 |
|
|
|
|
| 91 |
@spaces.GPU
|
| 92 |
def generate(
|
| 93 |
prompt: str,
|
| 94 |
seed: int,
|
| 95 |
guidance_weight: float,
|
| 96 |
sample_label: str,
|
| 97 |
-
# -----------------------
|
| 98 |
dataset: MultimodalDataset,
|
| 99 |
device: torch.device,
|
| 100 |
diffuser: Diffuser,
|
| 101 |
clip_model: clip.model.CLIP,
|
| 102 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
| 103 |
# Set arguments
|
| 104 |
set_random_seed(seed)
|
| 105 |
diffuser.gen_seeds = np.array([seed])
|
|
@@ -206,9 +210,6 @@ def launch_app(gen_fn: Callable):
|
|
| 206 |
# ------------------------------------------------------------------------------------- #
|
| 207 |
|
| 208 |
diffuser, clip_model, dataset, device = init("config")
|
| 209 |
-
diffuser.to("cuda")
|
| 210 |
-
clip_model.to("cuda")
|
| 211 |
-
|
| 212 |
generate_sample = partial(
|
| 213 |
generate,
|
| 214 |
dataset=dataset,
|
|
@@ -216,5 +217,4 @@ generate_sample = partial(
|
|
| 216 |
diffuser=diffuser,
|
| 217 |
clip_model=clip_model,
|
| 218 |
)
|
| 219 |
-
|
| 220 |
launch_app(generate_sample)
|
|
|
|
| 88 |
|
| 89 |
return normals
|
| 90 |
|
| 91 |
+
|
| 92 |
@spaces.GPU
|
| 93 |
def generate(
|
| 94 |
prompt: str,
|
| 95 |
seed: int,
|
| 96 |
guidance_weight: float,
|
| 97 |
sample_label: str,
|
| 98 |
+
# ----------------------- #
|
| 99 |
dataset: MultimodalDataset,
|
| 100 |
device: torch.device,
|
| 101 |
diffuser: Diffuser,
|
| 102 |
clip_model: clip.model.CLIP,
|
| 103 |
) -> Dict[str, Any]:
|
| 104 |
+
diffuser.to(device)
|
| 105 |
+
clip_model.to(device)
|
| 106 |
+
|
| 107 |
# Set arguments
|
| 108 |
set_random_seed(seed)
|
| 109 |
diffuser.gen_seeds = np.array([seed])
|
|
|
|
| 210 |
# ------------------------------------------------------------------------------------- #
|
| 211 |
|
| 212 |
diffuser, clip_model, dataset, device = init("config")
|
|
|
|
|
|
|
|
|
|
| 213 |
generate_sample = partial(
|
| 214 |
generate,
|
| 215 |
dataset=dataset,
|
|
|
|
| 217 |
diffuser=diffuser,
|
| 218 |
clip_model=clip_model,
|
| 219 |
)
|
|
|
|
| 220 |
launch_app(generate_sample)
|