toshas committed on
Commit
23d9a29
·
1 Parent(s): b303828

implement a custom image component with settings

Browse files

change examples to route through the custom image rather than the slider
cleanup
fourth cat

examples/cat4.jpg ADDED

Git LFS Details

  • SHA256: 8f23ecf9d6751def35125772f404ce5fa2b3becf2a180e45bf133cfc2f8241b4
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
gradio_dualvision/app_template.py CHANGED
@@ -28,12 +28,12 @@ import re
28
 
29
  import gradio as gr
30
  import spaces
31
- from PIL import Image
32
- from gradio.components.base import Component
33
 
34
  from .gradio_patches.examples import Examples
35
  from .gradio_patches.gallery import Gallery
36
- from .gradio_patches.imageslider import ImageSlider
37
  from .gradio_patches.radio import Radio
38
  from .version import __version__
39
 
@@ -281,7 +281,7 @@ class DualVisionApp(gr.Blocks):
281
  with self:
282
  self.make_interface()
283
 
284
- def process(self, image_in: Image.Image, **kwargs):
285
  """
286
  Process an input image into multiple modalities using the provided arguments or default settings.
287
  Returns two dictionaries: one containing the modalities and another with the actual settings.
@@ -321,9 +321,9 @@ class DualVisionApp(gr.Blocks):
321
  if os.path.isfile(image_settings_path):
322
  with open(image_settings_path, "r") as f:
323
  image_settings = json.load(f)
324
- image_in = Image.open(image_in).convert("RGB")
325
  else:
326
- if not isinstance(image_in, Image.Image):
327
  raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
328
  image_in = image_in.convert("RGB")
329
  image_settings.update(kwargs)
@@ -345,7 +345,7 @@ class DualVisionApp(gr.Blocks):
345
  raise gr.Error(
346
  f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
347
  )
348
- if not isinstance(v, Image.Image):
349
  raise gr.Error(
350
  f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
351
  )
@@ -417,6 +417,21 @@ class DualVisionApp(gr.Blocks):
417
  image_in, modality_selector_left, modality_selector_right, **input_dict
418
  )
419
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  def on_process_subsequent(
421
  self, results_state, modality_selector_left, modality_selector_right, *args
422
  ):
@@ -449,6 +464,10 @@ class DualVisionApp(gr.Blocks):
449
 
450
  results_state = Gallery(visible=False)
451
 
 
 
 
 
452
  image_slider = self.make_slider()
453
 
454
  if self.left_selector_visible or not self.advanced_settings_can_be_half_width:
@@ -469,7 +488,7 @@ class DualVisionApp(gr.Blocks):
469
  )
470
 
471
  self.make_examples(
472
- image_slider,
473
  [
474
  results_state,
475
  image_slider,
@@ -580,15 +599,13 @@ class DualVisionApp(gr.Blocks):
580
  raise gr.Error("Not all example paths are valid files")
581
  examples_dirname = os.path.basename(os.path.normpath(self.examples_path))
582
  return Examples(
583
- examples=[
584
- (e, e) for e in examples
585
- ],
586
  inputs=inputs,
587
  outputs=outputs,
588
  examples_per_page=self.examples_per_page,
589
  cache_examples=True,
590
  cache_mode=self.examples_cache,
591
- fn=self.on_process_first,
592
  directory_name=examples_dirname,
593
  )
594
 
 
28
 
29
  import gradio as gr
30
  import spaces
31
+ from PIL import Image as PILImage
32
+ from gradio import Component, ImageSlider
33
 
34
  from .gradio_patches.examples import Examples
35
  from .gradio_patches.gallery import Gallery
36
+ from .gradio_patches.image import Image
37
  from .gradio_patches.radio import Radio
38
  from .version import __version__
39
 
 
281
  with self:
282
  self.make_interface()
283
 
284
+ def process(self, image_in: PILImage.Image, **kwargs):
285
  """
286
  Process an input image into multiple modalities using the provided arguments or default settings.
287
  Returns two dictionaries: one containing the modalities and another with the actual settings.
 
321
  if os.path.isfile(image_settings_path):
322
  with open(image_settings_path, "r") as f:
323
  image_settings = json.load(f)
324
+ image_in = PILImage.open(image_in).convert("RGB")
325
  else:
326
+ if not isinstance(image_in, PILImage.Image):
327
  raise gr.Error(f"Input must be a PIL image, got {type(image_in)}")
328
  image_in = image_in.convert("RGB")
329
  image_settings.update(kwargs)
 
345
  raise gr.Error(
346
  f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input"
347
  )
348
+ if not isinstance(v, PILImage.Image):
349
  raise gr.Error(
350
  f"Value for key '{k}' must be a PIL Image, got type {type(v)}"
351
  )
 
417
  image_in, modality_selector_left, modality_selector_right, **input_dict
418
  )
419
 
420
+ def on_process_example(
421
+ self,
422
+ dummy_image_input,
423
+ modality_selector_left=None,
424
+ modality_selector_right=None,
425
+ *args,
426
+ ):
427
+ image_in = dummy_image_input
428
+ input_dict = {}
429
+ if len(args) > 0:
430
+ input_dict = {k: v for k, v in zip(self.input_keys, args)}
431
+ return self.process_components(
432
+ image_in, modality_selector_left, modality_selector_right, **input_dict
433
+ )
434
+
435
  def on_process_subsequent(
436
  self, results_state, modality_selector_left, modality_selector_right, *args
437
  ):
 
464
 
465
  results_state = Gallery(visible=False)
466
 
467
+ dummy_image_input = Image(
468
+ visible=False,
469
+ type="filepath",
470
+ )
471
  image_slider = self.make_slider()
472
 
473
  if self.left_selector_visible or not self.advanced_settings_can_be_half_width:
 
488
  )
489
 
490
  self.make_examples(
491
+ dummy_image_input,
492
  [
493
  results_state,
494
  image_slider,
 
599
  raise gr.Error("Not all example paths are valid files")
600
  examples_dirname = os.path.basename(os.path.normpath(self.examples_path))
601
  return Examples(
602
+ examples=examples,
 
 
603
  inputs=inputs,
604
  outputs=outputs,
605
  examples_per_page=self.examples_per_page,
606
  cache_examples=True,
607
  cache_mode=self.examples_cache,
608
+ fn=self.on_process_example,
609
  directory_name=examples_dirname,
610
  )
611
 
gradio_dualvision/gradio_patches/gallery.py CHANGED
@@ -1,35 +1,52 @@
1
- from __future__ import annotations
2
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from concurrent.futures import ThreadPoolExecutor
4
-
5
- from gradio_client import utils as client_utils
6
- from gradio.components.gallery import GalleryImage, GalleryData, GalleryMediaType, CaptionedGalleryMediaType, GalleryVideo
7
  from pathlib import Path
8
  from urllib.parse import quote, urlparse
9
 
10
- import gradio
11
- import numpy as np
12
- import PIL.Image
13
  from gradio_client.utils import is_http_url_like
14
-
15
- from gradio import processing_utils, utils, wasm_utils, image_utils
16
- from gradio.data_classes import FileData, ImageData
17
 
18
 
19
  class Gallery(gradio.Gallery):
 
20
  def postprocess(
21
- self,
22
- value: list[GalleryMediaType | CaptionedGalleryMediaType] | None,
23
  ) -> GalleryData:
24
  """
25
- This is a patched version of the original function, wherein the format for PIL is computed based on the data type:
26
- format = "png" if img.mode == "I;16" else "webp"
27
-
28
- Parameters:
29
- value: Expects the function to return a `list` of images or videos, or `list` of (media, `str` caption) tuples. Each image can be a `str` file path, a `numpy` array, or a `PIL.Image` object. Each video can be a `str` file path.
30
- Returns:
31
- a list of images or videos, or list of (media, caption) tuples
32
- """
33
  if value is None:
34
  return GalleryData(root=[])
35
  if isinstance(value, str):
@@ -51,7 +68,7 @@ class Gallery(gradio.Gallery):
51
  )
52
  file_path = str(utils.abspath(file))
53
  elif isinstance(img, PIL.Image.Image):
54
- format = "png" if img.mode == "I;16" else "webp" # Patch 1: change format based on the inbound dtype
55
  file = processing_utils.save_pil_to_cache(
56
  img, cache_dir=self.GRADIO_CACHE, format=format
57
  )
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.
3
+ # See https://creativecommons.org/licenses/by-sa/4.0/ for details.
4
+ # --------------------------------------------------------------------------
5
+ # DualVision is a Gradio template app for image processing. It was developed
6
+ # to support the Marigold project. If you find this code useful, we kindly
7
+ # ask you to cite our most relevant papers.
8
+ # More information about Marigold:
9
+ # https://marigoldmonodepth.github.io
10
+ # https://marigoldcomputervision.github.io
11
+ # Efficient inference pipelines are now part of diffusers:
12
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
13
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
14
+ # Examples of trained models and live demos:
15
+ # https://huggingface.co/prs-eth
16
+ # Related projects:
17
+ # https://marigolddepthcompletion.github.io/
18
+ # https://rollingdepth.github.io/
19
+ # Citation (BibTeX):
20
+ # https://github.com/prs-eth/Marigold#-citation
21
+ # https://github.com/prs-eth/Marigold-DC#-citation
22
+ # https://github.com/prs-eth/rollingdepth#-citation
23
+ # --------------------------------------------------------------------------
24
+ import gradio
25
+ import numpy as np
26
+ import PIL.Image
27
  from concurrent.futures import ThreadPoolExecutor
 
 
 
28
  from pathlib import Path
29
  from urllib.parse import quote, urlparse
30
 
31
+ from gradio import FileData, image_utils, processing_utils, utils, wasm_utils
32
+ from gradio.components.gallery import GalleryData, GalleryImage, GalleryMediaType, CaptionedGalleryMediaType, GalleryVideo
33
+ from gradio_client import utils as client_utils
34
  from gradio_client.utils import is_http_url_like
35
+ from gradio.data_classes import ImageData
 
 
36
 
37
 
38
  class Gallery(gradio.Gallery):
39
+
40
  def postprocess(
41
+ self,
42
+ value: list[GalleryMediaType | CaptionedGalleryMediaType] | None,
43
  ) -> GalleryData:
44
  """
45
+ Parameters:
46
+ value: Expects the function to return a `list` of images or videos, or `list` of (media, `str` caption) tuples. Each image can be a `str` file path, a `numpy` array, or a `PIL.Image` object. Each video can be a `str` file path.
47
+ Returns:
48
+ a list of images or videos, or list of (media, caption) tuples
49
+ """
 
 
 
50
  if value is None:
51
  return GalleryData(root=[])
52
  if isinstance(value, str):
 
68
  )
69
  file_path = str(utils.abspath(file))
70
  elif isinstance(img, PIL.Image.Image):
71
+ format = "png" if img.mode == "I;16" else self.format # Patch 1: change format based on the inbound dtype
72
  file = processing_utils.save_pil_to_cache(
73
  img, cache_dir=self.GRADIO_CACHE, format=format
74
  )
gradio_dualvision/gradio_patches/{imageslider.py → image.py} RENAMED
@@ -25,88 +25,51 @@ import json
25
  import os.path
26
  import tempfile
27
  from pathlib import Path
28
- from typing import Union, Tuple, Optional
29
 
 
30
  import gradio
31
  import numpy as np
32
- from PIL import Image
33
  from gradio import image_utils
34
- from gradio_client import utils as client_utils
35
- from gradio.components.imageslider import image_tuple
36
- from gradio.data_classes import GradioRootModel, JsonData, ImageData
37
 
38
 
39
- class ImageSliderPlusData(GradioRootModel):
40
- root: Union[
41
- Tuple[ImageData | None, ImageData | None, JsonData | None],
42
- Tuple[ImageData | None, ImageData | None],
43
- None,
44
- ]
45
-
46
-
47
- class ImageSlider(gradio.ImageSlider):
48
- data_model = ImageSliderPlusData
49
-
50
  def postprocess(
51
  self,
52
- value: image_tuple,
53
- ) -> ImageSliderPlusData:
54
- if value is None:
55
- return ImageSliderPlusData(root=(None, None, None))
56
 
57
- settings = None
58
- if type(value[0]) is str:
59
- settings_candidate_path = value[0] + ".settings.json"
 
 
 
 
 
60
  if os.path.isfile(settings_candidate_path):
61
  with open(settings_candidate_path, "r") as fp:
62
  settings = json.load(fp)
 
63
 
64
- fn_format_selector = lambda x: "png" if (isinstance(x, np.ndarray) and x.dtype == np.uint16 and x.squeeze().ndim == 2) or (isinstance(x, Image.Image) and x.mode == "I;16") else self.format
65
- format_0 = fn_format_selector(value[0])
66
- format_1 = fn_format_selector(value[1])
67
 
68
- return ImageSliderPlusData(
69
- root=(
70
- image_utils.postprocess_image(
71
- value[0], cache_dir=self.GRADIO_CACHE, format=format_0
72
- ),
73
- image_utils.postprocess_image(
74
- value[1], cache_dir=self.GRADIO_CACHE, format=format_1
75
- ),
76
- JsonData(settings),
77
- ),
78
- )
79
 
80
- def preprocess(self, payload: ImageSliderPlusData) -> image_tuple:
81
- if payload is None:
82
- return None
83
- if payload.root is None:
84
- raise ValueError("Payload is None.")
85
 
86
- out_0 = image_utils.preprocess_image(
87
- payload.root[0],
88
- cache_dir=self.GRADIO_CACHE,
89
- format=self.format,
90
- image_mode=self.image_mode,
91
- type=self.type,
92
- )
93
- out_1 = image_utils.preprocess_image(
94
- payload.root[1],
95
- cache_dir=self.GRADIO_CACHE,
96
- format=self.format,
97
- image_mode=self.image_mode,
98
- type=self.type,
99
- )
100
-
101
- if len(payload.root) > 2 and payload.root[2] is not None:
102
- with open(out_0 + ".settings.json", "w") as fp:
103
- json.dump(payload.root[2].root, fp)
104
-
105
- return out_0, out_1
106
 
107
  @staticmethod
108
  def resize_and_save(image_path: str, max_dim: int, square: bool = False) -> str:
109
- img = Image.open(image_path).convert("RGB")
 
110
  if square:
111
  width, height = img.size
112
  min_side = min(width, height)
@@ -115,31 +78,16 @@ class ImageSlider(gradio.ImageSlider):
115
  right = left + min_side
116
  bottom = top + min_side
117
  img = img.crop((left, top, right, bottom))
 
118
  img.thumbnail((max_dim, max_dim))
119
  temp_file = tempfile.NamedTemporaryFile(suffix=".webp", delete=False)
120
  img.save(temp_file.name, "WEBP")
 
121
  return temp_file.name
122
 
123
- def process_example_dims(
124
- self, input_data: tuple[str | Path | None] | None, max_dim: Optional[int] = None, square: bool = False
125
- ) -> image_tuple:
126
- if input_data is None:
127
- return None
128
- input_data = (str(input_data[0]), str(input_data[1]))
129
- if self.proxy_url or client_utils.is_http_url_like(input_data[0]):
130
- return input_data[0]
131
- if max_dim is not None:
132
- input_data = (
133
- self.resize_and_save(input_data[0], max_dim, square),
134
- self.resize_and_save(input_data[1], max_dim, square),
135
- )
136
- return (
137
- self.move_resource_to_block_cache(input_data[0]),
138
- self.move_resource_to_block_cache(input_data[1]),
139
- )
140
-
141
  def process_example(
142
- self, input_data: tuple[str | Path | None] | None
143
- ) -> image_tuple:
144
- return self.process_example_dims(input_data, 256, True)
145
-
 
 
25
  import os.path
26
  import tempfile
27
  from pathlib import Path
 
28
 
29
+ import PIL
30
  import gradio
31
  import numpy as np
32
+ from PIL import Image as PILImage
33
  from gradio import image_utils
34
+ from gradio.data_classes import ImageData
 
 
35
 
36
 
37
+ class Image(gradio.Image):
 
 
 
 
 
 
 
 
 
 
38
  def postprocess(
39
  self,
40
+ value,
41
+ ) -> ImageData:
42
+ fn_format_selector = lambda x: "png" if (isinstance(x, np.ndarray) and x.dtype == np.uint16 and x.squeeze().ndim == 2) or (isinstance(x, PILImage.Image) and x.mode == "I;16") else self.format
43
+ format = fn_format_selector(value)
44
 
45
+ out = image_utils.postprocess_image(
46
+ value,
47
+ cache_dir=self.GRADIO_CACHE,
48
+ format=format,
49
+ )
50
+
51
+ if type(value) is str:
52
+ settings_candidate_path = value + ".settings.json"
53
  if os.path.isfile(settings_candidate_path):
54
  with open(settings_candidate_path, "r") as fp:
55
  settings = json.load(fp)
56
+ out.meta["settings"] = settings
57
 
58
+ return out
 
 
59
 
60
+ def preprocess(self, payload: ImageData) -> str | PIL.Image.Image | np.ndarray | None:
61
+ out = super().preprocess(payload)
 
 
 
 
 
 
 
 
 
62
 
63
+ if "settings" in payload.meta:
64
+ with open(out + ".settings.json", "w") as fp:
65
+ json.dump(payload.meta["settings"], fp)
 
 
66
 
67
+ return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  @staticmethod
70
  def resize_and_save(image_path: str, max_dim: int, square: bool = False) -> str:
71
+ img = PILImage.open(image_path).convert("RGB")
72
+
73
  if square:
74
  width, height = img.size
75
  min_side = min(width, height)
 
78
  right = left + min_side
79
  bottom = top + min_side
80
  img = img.crop((left, top, right, bottom))
81
+
82
  img.thumbnail((max_dim, max_dim))
83
  temp_file = tempfile.NamedTemporaryFile(suffix=".webp", delete=False)
84
  img.save(temp_file.name, "WEBP")
85
+
86
  return temp_file.name
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def process_example(
89
+ self, input_data: str | Path | None
90
+ ) -> str | PIL.Image.Image | np.ndarray | None:
91
+ thumbnail = self.resize_and_save(input_data, 256, True)
92
+ out = super().process_example(thumbnail)
93
+ return out
gradio_dualvision/gradio_patches/radio.py CHANGED
@@ -23,13 +23,9 @@
23
  # --------------------------------------------------------------------------
24
  import gradio
25
  from gradio import components
26
- from gradio.components.base import Component
27
- from gradio.data_classes import (
28
- GradioModel,
29
- GradioRootModel,
30
- )
31
-
32
  from gradio.blocks import BlockContext
 
 
33
 
34
 
35
  def patched_postprocess_update_dict(
 
23
  # --------------------------------------------------------------------------
24
  import gradio
25
  from gradio import components
 
 
 
 
 
 
26
  from gradio.blocks import BlockContext
27
+ from gradio.components.base import Component
28
+ from gradio.data_classes import GradioModel, GradioRootModel
29
 
30
 
31
  def patched_postprocess_update_dict(