toshas commited on
Commit
62b8cbf
·
1 Parent(s): d562bf5

formatting

Browse files
gradio_dualvision/app_template.py CHANGED
@@ -27,6 +27,14 @@ import os
27
  import re
28
 
29
  import gradio as gr
 
 
 
 
 
 
 
 
30
  import spaces
31
  from PIL import Image as PILImage
32
  from gradio import Component, ImageSlider
@@ -35,7 +43,6 @@ from .gradio_patches.examples import Examples
35
  from .gradio_patches.gallery import Gallery
36
  from .gradio_patches.image import Image
37
  from .gradio_patches.radio import Radio
38
- from .version import __version__
39
 
40
 
41
  class DualVisionApp(gr.Blocks):
@@ -82,12 +89,6 @@ class DualVisionApp(gr.Blocks):
82
  gallery_thumb_min_size: Min size of the gallery thumbnail (Default: `96px`).
83
  **kwargs: Any other arguments that Gradio Blocks class can take.
84
  """
85
- if __version__ != gr.__version__:
86
- raise gr.Error(
87
- f"gradio version ({gr.__version__}) must match gradio-dualvision version ({__version__}). "
88
- f"Check the metadata of the README.md in your demo (sdk_version field)."
89
- )
90
-
91
  squeeze_viewport_height_pct = int(squeeze_viewport_height_pct)
92
  if not 50 <= squeeze_viewport_height_pct <= 100:
93
  raise gr.Error(
@@ -113,8 +114,9 @@ class DualVisionApp(gr.Blocks):
113
  self.process_components = spaces.GPU(
114
  self.process_components, duration=spaces_zero_gpu_duration
115
  )
116
- self.head = ""
117
- self.head += """
 
118
  <script>
119
  let observerFooterButtons = new MutationObserver((mutationsList, observer) => {
120
  const origButtonShowAPI = document.querySelector(".show-api");
@@ -158,6 +160,8 @@ class DualVisionApp(gr.Blocks):
158
  observerFooterButtons.observe(document.body, { childList: true, subtree: true });
159
  </script>
160
  """
 
 
161
  if kwargs.get("analytics_enabled") is not False:
162
  self.head += f"""
163
  <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
 
27
  import re
28
 
29
  import gradio as gr
30
+ from .version import __version__
31
+
32
+ if __version__ != gr.__version__:
33
+ raise gr.Error(
34
+ f"gradio version ({gr.__version__}) must match gradio-dualvision version ({__version__}). "
35
+ f"Check the metadata of the README.md in your demo (sdk_version field)."
36
+ )
37
+
38
  import spaces
39
  from PIL import Image as PILImage
40
  from gradio import Component, ImageSlider
 
43
  from .gradio_patches.gallery import Gallery
44
  from .gradio_patches.image import Image
45
  from .gradio_patches.radio import Radio
 
46
 
47
 
48
  class DualVisionApp(gr.Blocks):
 
89
  gallery_thumb_min_size: Min size of the gallery thumbnail (Default: `96px`).
90
  **kwargs: Any other arguments that Gradio Blocks class can take.
91
  """
 
 
 
 
 
 
92
  squeeze_viewport_height_pct = int(squeeze_viewport_height_pct)
93
  if not 50 <= squeeze_viewport_height_pct <= 100:
94
  raise gr.Error(
 
114
  self.process_components = spaces.GPU(
115
  self.process_components, duration=spaces_zero_gpu_duration
116
  )
117
+ # fmt: off
118
+ self.head = (
119
+ """
120
  <script>
121
  let observerFooterButtons = new MutationObserver((mutationsList, observer) => {
122
  const origButtonShowAPI = document.querySelector(".show-api");
 
160
  observerFooterButtons.observe(document.body, { childList: true, subtree: true });
161
  </script>
162
  """
163
+ )
164
+ # fmt: on
165
  if kwargs.get("analytics_enabled") is not False:
166
  self.head += f"""
167
  <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
gradio_dualvision/gradio_patches/gallery.py CHANGED
@@ -29,7 +29,13 @@ from pathlib import Path
29
  from urllib.parse import quote, urlparse
30
 
31
  from gradio import FileData, image_utils, processing_utils, utils, wasm_utils
32
- from gradio.components.gallery import GalleryData, GalleryImage, GalleryMediaType, CaptionedGalleryMediaType, GalleryVideo
 
 
 
 
 
 
33
  from gradio_client import utils as client_utils
34
  from gradio_client.utils import is_http_url_like
35
  from gradio.data_classes import ImageData
@@ -68,7 +74,9 @@ class Gallery(gradio.Gallery):
68
  )
69
  file_path = str(utils.abspath(file))
70
  elif isinstance(img, PIL.Image.Image):
71
- format = "png" if img.mode == "I;16" else self.format # Patch 1: change format based on the inbound dtype
 
 
72
  file = processing_utils.save_pil_to_cache(
73
  img, cache_dir=self.GRADIO_CACHE, format=format
74
  )
 
29
  from urllib.parse import quote, urlparse
30
 
31
  from gradio import FileData, image_utils, processing_utils, utils, wasm_utils
32
+ from gradio.components.gallery import (
33
+ GalleryData,
34
+ GalleryImage,
35
+ GalleryMediaType,
36
+ CaptionedGalleryMediaType,
37
+ GalleryVideo,
38
+ )
39
  from gradio_client import utils as client_utils
40
  from gradio_client.utils import is_http_url_like
41
  from gradio.data_classes import ImageData
 
74
  )
75
  file_path = str(utils.abspath(file))
76
  elif isinstance(img, PIL.Image.Image):
77
+ format = (
78
+ "png" if img.mode == "I;16" else self.format
79
+ ) # Patch 1: change format based on the inbound dtype
80
  file = processing_utils.save_pil_to_cache(
81
  img, cache_dir=self.GRADIO_CACHE, format=format
82
  )
gradio_dualvision/gradio_patches/image.py CHANGED
@@ -39,7 +39,16 @@ class Image(gradio.Image):
39
  self,
40
  value,
41
  ) -> ImageData:
42
- fn_format_selector = lambda x: "png" if (isinstance(x, np.ndarray) and x.dtype == np.uint16 and x.squeeze().ndim == 2) or (isinstance(x, PILImage.Image) and x.mode == "I;16") else self.format
 
 
 
 
 
 
 
 
 
43
  format = fn_format_selector(value)
44
 
45
  out = image_utils.postprocess_image(
@@ -57,7 +66,9 @@ class Image(gradio.Image):
57
 
58
  return out
59
 
60
- def preprocess(self, payload: ImageData) -> str | PIL.Image.Image | np.ndarray | None:
 
 
61
  out = super().preprocess(payload)
62
 
63
  if "settings" in payload.meta:
@@ -78,11 +89,11 @@ class Image(gradio.Image):
78
  right = left + min_side
79
  bottom = top + min_side
80
  img = img.crop((left, top, right, bottom))
81
-
82
  img.thumbnail((max_dim, max_dim))
83
  temp_file = tempfile.NamedTemporaryFile(suffix=".webp", delete=False)
84
  img.save(temp_file.name, "WEBP")
85
-
86
  return temp_file.name
87
 
88
  def process_example(
 
39
  self,
40
  value,
41
  ) -> ImageData:
42
+ fn_format_selector = lambda x: (
43
+ "png"
44
+ if (
45
+ isinstance(x, np.ndarray)
46
+ and x.dtype == np.uint16
47
+ and x.squeeze().ndim == 2
48
+ )
49
+ or (isinstance(x, PILImage.Image) and x.mode == "I;16")
50
+ else self.format
51
+ )
52
  format = fn_format_selector(value)
53
 
54
  out = image_utils.postprocess_image(
 
66
 
67
  return out
68
 
69
+ def preprocess(
70
+ self, payload: ImageData
71
+ ) -> str | PIL.Image.Image | np.ndarray | None:
72
  out = super().preprocess(payload)
73
 
74
  if "settings" in payload.meta:
 
89
  right = left + min_side
90
  bottom = top + min_side
91
  img = img.crop((left, top, right, bottom))
92
+
93
  img.thumbnail((max_dim, max_dim))
94
  temp_file = tempfile.NamedTemporaryFile(suffix=".webp", delete=False)
95
  img.save(temp_file.name, "WEBP")
96
+
97
  return temp_file.name
98
 
99
  def process_example(