Commit
·
4e14677
1
Parent(s):
3646605
test
Browse files
app.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
| 1 |
import os
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import jax
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import spaces
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
|
| 9 |
from main import Config, init, run
|
|
@@ -28,9 +32,12 @@ def initialize_model():
|
|
| 28 |
|
| 29 |
|
| 30 |
@spaces.GPU(duration=30)
|
| 31 |
-
|
| 32 |
-
# Generator function for status updates
|
| 33 |
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
if input_img is None:
|
| 35 |
yield (
|
| 36 |
gr.update(
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
+
os.environ["KERAS_BACKEND"] = "jax"
|
| 4 |
+
|
| 5 |
import gradio as gr
|
| 6 |
import jax
|
| 7 |
+
import keras
|
| 8 |
import numpy as np
|
| 9 |
import spaces
|
| 10 |
+
import tensorflow as tf
|
| 11 |
from PIL import Image
|
| 12 |
|
| 13 |
from main import Config, init, run
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
@spaces.GPU(duration=30)
|
|
|
|
|
|
|
| 35 |
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
|
| 36 |
+
print(f"Keras backend: {os.environ['KERAS_BACKEND']} and {keras.backend.backend()}")
|
| 37 |
+
print(f"Keras version: {keras.__version__}, JAX version: {jax.__version__}")
|
| 38 |
+
print(f"JAX cuda: {jax.devices()}")
|
| 39 |
+
print(f"Tensorflow devices: {tf.config.list_physical_devices()}")
|
| 40 |
+
|
| 41 |
if input_img is None:
|
| 42 |
yield (
|
| 43 |
gr.update(
|