tristan-deep commited on
Commit
57a5488
·
1 Parent(s): 037377e

plot updates and fixes

Browse files
Files changed (2) hide show
  1. main.py +3 -2
  2. plots.py +2 -2
main.py CHANGED
@@ -411,7 +411,7 @@ def run(
411
  num_images = hazy_images.shape[0]
412
  num_batches = (num_images + batch_size - 1) // batch_size
413
 
414
- progbar = keras.utils.Progbar(num_batches, verbose=verbose)
415
  i = 0
416
  batch_idx = 0
417
  for i in range(num_batches):
@@ -475,7 +475,7 @@ def run(
475
  def main(
476
  input_folder: str = "./assets",
477
  output_folder: str = "./temp",
478
- num_imgs_plot: int = 4,
479
  device: str = "auto:1",
480
  config: str = "configs/semantic_dps.yaml",
481
  ):
@@ -488,6 +488,7 @@ def main(
488
  seed = jax.random.PRNGKey(config.seed)
489
 
490
  paths = list(Path(input_folder).glob("*.png"))
 
491
 
492
  output_folder = Path(output_folder)
493
 
 
411
  num_images = hazy_images.shape[0]
412
  num_batches = (num_images + batch_size - 1) // batch_size
413
 
414
+ progbar = keras.utils.Progbar(num_batches, verbose=verbose, unit_name="batch")
415
  i = 0
416
  batch_idx = 0
417
  for i in range(num_batches):
 
475
  def main(
476
  input_folder: str = "./assets",
477
  output_folder: str = "./temp",
478
+ num_imgs_plot: int = 5,
479
  device: str = "auto:1",
480
  config: str = "configs/semantic_dps.yaml",
481
  ):
 
488
  seed = jax.random.PRNGKey(config.seed)
489
 
490
  paths = list(Path(input_folder).glob("*.png"))
491
+ paths = sorted(paths)
492
 
493
  output_folder = Path(output_folder)
494
 
plots.py CHANGED
@@ -142,7 +142,7 @@ def plot_batch_with_named_masks(
142
 
143
  im = axes[0].get_images()[0] if axes[0].get_images() else None
144
  cbar = fig.colorbar(im, cax=cax)
145
- cbar.set_label(r"Guidance weighting \mathbf{p}")
146
  cbar.ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))
147
  cbar.ax.yaxis.set_tick_params(labelsize=7)
148
  cbar.ax.yaxis.label.set_size(8)
@@ -189,7 +189,7 @@ def plot_dehazed_results(
189
 
190
  def plot_metrics(metrics, limits, out_path):
191
  plt.style.use("seaborn-v0_8-darkgrid")
192
- fig, axes = plt.subplots(1, len(metrics), figsize=(7.2, 2.7), dpi=600)
193
  colors = ["#0057b7", "#ffb300", "#008744", "#d62d20"]
194
  metric_labels = {
195
  "CNR": r"CNR $\uparrow$",
 
142
 
143
  im = axes[0].get_images()[0] if axes[0].get_images() else None
144
  cbar = fig.colorbar(im, cax=cax)
145
+ cbar.set_label(r"Guidance weighting $\mathbf{p}$")
146
  cbar.ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))
147
  cbar.ax.yaxis.set_tick_params(labelsize=7)
148
  cbar.ax.yaxis.label.set_size(8)
 
189
 
190
  def plot_metrics(metrics, limits, out_path):
191
  plt.style.use("seaborn-v0_8-darkgrid")
192
+ fig, axes = plt.subplots(1, len(metrics), figsize=(7.2, 2.7), dpi=200)
193
  colors = ["#0057b7", "#ffb300", "#008744", "#d62d20"]
194
  metric_labels = {
195
  "CNR": r"CNR $\uparrow$",