tristan-deep commited on
Commit
4e14677
·
1 Parent(s): 3646605
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -1,9 +1,13 @@
1
  import os
2
 
 
 
3
  import gradio as gr
4
  import jax
 
5
  import numpy as np
6
  import spaces
 
7
  from PIL import Image
8
 
9
  from main import Config, init, run
@@ -28,9 +32,12 @@ def initialize_model():
28
 
29
 
30
  @spaces.GPU(duration=30)
31
-
32
- # Generator function for status updates
33
  def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
 
 
 
 
 
34
  if input_img is None:
35
  yield (
36
  gr.update(
 
1
  import os
2
 
3
+ os.environ["KERAS_BACKEND"] = "jax"
4
+
5
  import gradio as gr
6
  import jax
7
+ import keras
8
  import numpy as np
9
  import spaces
10
+ import tensorflow as tf
11
  from PIL import Image
12
 
13
  from main import Config, init, run
 
32
 
33
 
34
  @spaces.GPU(duration=30)
 
 
35
  def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
36
+ print(f"Keras backend: {os.environ['KERAS_BACKEND']} and {keras.backend.backend()}")
37
+ print(f"Keras version: {keras.__version__}, JAX version: {jax.__version__}")
38
+ print(f"JAX cuda: {jax.devices()}")
39
+ print(f"Tensorflow devices: {tf.config.list_physical_devices()}")
40
+
41
  if input_img is None:
42
  yield (
43
  gr.update(