Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -44,19 +44,21 @@ titok_generator = titok_generator.to(device)
|
|
| 44 |
|
| 45 |
|
| 46 |
@spaces.GPU
|
| 47 |
-
def demo_infer(
|
|
|
|
|
|
|
| 48 |
class_label, seed):
|
| 49 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 50 |
-
|
| 51 |
-
|
| 52 |
n = 4
|
| 53 |
class_labels = [class_label for _ in range(n)]
|
| 54 |
torch.manual_seed(seed)
|
| 55 |
torch.cuda.manual_seed(seed)
|
| 56 |
t1 = time.time()
|
| 57 |
generated_image = demo_util.sample_fn(
|
| 58 |
-
generator=
|
| 59 |
-
tokenizer=
|
| 60 |
labels=class_labels,
|
| 61 |
guidance_scale=guidance_scale,
|
| 62 |
randomize_temperature=randomize_temperature,
|
|
@@ -90,6 +92,7 @@ with gr.Blocks() as demo:
|
|
| 90 |
with gr.Column():
|
| 91 |
output = gr.Gallery(label='Generated Images', height=700)
|
| 92 |
button.click(demo_infer, inputs=[
|
|
|
|
| 93 |
guidance_scale, randomize_temperature, num_sample_steps,
|
| 94 |
i1k_class, seed],
|
| 95 |
outputs=[output])
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
@spaces.GPU
|
| 47 |
+
def demo_infer(tokenizer,
|
| 48 |
+
generator,
|
| 49 |
+
guidance_scale, randomize_temperature, num_sample_steps,
|
| 50 |
class_label, seed):
|
| 51 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 52 |
+
tokenizer = tokenizer.to(device)
|
| 53 |
+
generator = generator.to(device)
|
| 54 |
n = 4
|
| 55 |
class_labels = [class_label for _ in range(n)]
|
| 56 |
torch.manual_seed(seed)
|
| 57 |
torch.cuda.manual_seed(seed)
|
| 58 |
t1 = time.time()
|
| 59 |
generated_image = demo_util.sample_fn(
|
| 60 |
+
generator=generator,
|
| 61 |
+
tokenizer=tokenizer,
|
| 62 |
labels=class_labels,
|
| 63 |
guidance_scale=guidance_scale,
|
| 64 |
randomize_temperature=randomize_temperature,
|
|
|
|
| 92 |
with gr.Column():
|
| 93 |
output = gr.Gallery(label='Generated Images', height=700)
|
| 94 |
button.click(demo_infer, inputs=[
|
| 95 |
+
titok_tokenizer, titok_generator,
|
| 96 |
guidance_scale, randomize_temperature, num_sample_steps,
|
| 97 |
i1k_class, seed],
|
| 98 |
outputs=[output])
|