gaparmar commited on
Commit
6213d31
·
1 Parent(s): 49b7596

adding utils:

Browse files
app.py CHANGED
@@ -9,22 +9,13 @@ import torch.nn.functional as F
9
  from diffusers import FluxPipeline, AutoencoderTiny, FluxKontextPipeline
10
  from transformers import CLIPProcessor, CLIPModel, AutoModel
11
  from transformers.models.clip.modeling_clip import _get_vector_norm
12
- from nunchaku import NunchakuFluxTransformer2dModel
13
- from nunchaku.utils import get_precision
14
  from my_utils.group_inference import run_group_inference
15
  from my_utils.default_values import apply_defaults
16
- from diffusers.hooks import apply_group_offloading
17
- from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
18
- from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
19
-
20
  import argparse
21
 
22
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
23
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
24
- pipe.enable_model_cpu_offload()
25
-
26
-
27
- # pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
28
 
29
  m_clip = CLIPModel.from_pretrained("multimodalart/clip-vit-base-patch32").to("cuda")
30
  prep_clip = CLIPProcessor.from_pretrained("multimodalart/clip-vit-base-patch32")
@@ -283,6 +274,22 @@ with gr.Blocks(css=custom_css, js=js_func, theme=gr.themes.Soft(), elem_id="main
283
  binary_term = gr.Dropdown(choices=["diversity_dino", "diversity_clip", "dino_cls_pairwise"], value=default_args.binary_term,
284
  container=False, show_label=False, scale=3)
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  with gr.Row(scale=1):
287
  generate_btn = gr.Button("Generate", variant="primary")
288
 
@@ -295,4 +302,19 @@ with gr.Blocks(css=custom_css, js=js_func, theme=gr.themes.Soft(), elem_id="main
295
  outputs=[output_gallery_group]
296
  )
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  demo.launch()
 
9
  from diffusers import FluxPipeline, AutoencoderTiny, FluxKontextPipeline
10
  from transformers import CLIPProcessor, CLIPModel, AutoModel
11
  from transformers.models.clip.modeling_clip import _get_vector_norm
 
 
12
  from my_utils.group_inference import run_group_inference
13
  from my_utils.default_values import apply_defaults
 
 
 
 
14
  import argparse
15
 
16
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
17
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
18
+ # pipe.enable_model_cpu_offload()
 
 
 
19
 
20
  m_clip = CLIPModel.from_pretrained("multimodalart/clip-vit-base-patch32").to("cuda")
21
  prep_clip = CLIPProcessor.from_pretrained("multimodalart/clip-vit-base-patch32")
 
274
  binary_term = gr.Dropdown(choices=["diversity_dino", "diversity_clip", "dino_cls_pairwise"], value=default_args.binary_term,
275
  container=False, show_label=False, scale=3)
276
 
277
+
278
+ # Instructions for users
279
+ gr.HTML(
280
+ """
281
+ <div style="margin: 15px 0; padding: 10px; background-color: #f0f0f0; border-radius: 5px; font-size: 14px;">
282
+ <strong>Tips:</strong>
283
+ <ul style="margin: 5px 0; padding-left: 20px;">
284
+ <li>Try out the (cached) examples below first! </li>
285
+ <li>Higher lambda → more diverse outputs (no added runtime cost)</li>
286
+ <li>Lower lambda → improved quality and text-adherence (no added runtime cost)</li>
287
+ <li>More starting candidates → better quality and diversity (slower runtime)</li>
288
+ </ul>
289
+ </div>
290
+ """
291
+ )
292
+
293
  with gr.Row(scale=1):
294
  generate_btn = gr.Button("Generate", variant="primary")
295
 
 
302
  outputs=[output_gallery_group]
303
  )
304
 
305
+ gr.Examples(
306
+ examples=[
307
+ ["Cat is sitting in a cafe and working on his laptop.", 64, 4, 0.5, 1.0, 42, "clip_text_img", "diversity_dino", "assets/cat.png"],
308
+ ["Cat is playing outside in nature.", 64, 4, 0.5, 1.0, 42, "clip_text_img", "diversity_dino", "assets/cat.png"],
309
+ ["Cat is drinking a glass of milk.", 64, 4, 0.5, 1.0, 42, "clip_text_img", "diversity_dino", "assets/cat.png"],
310
+ ["Cat is an astronaut landing on the moon.", 64, 4, 0.5, 1.0, 42, "clip_text_img", "diversity_dino", "assets/cat.png"],
311
+ ["Cat is surfing in the ocean.", 64, 4, 0.5, 1.0, 42, "clip_text_img", "diversity_dino", "assets/cat.png"],
312
+ ],
313
+ inputs=[prompt, starting_candidates, output_group_size, pruning_ratio, lambda_score, seed, unary_term, binary_term, input_image],
314
+ outputs=[output_gallery_group],
315
+ fn=generate_images,
316
+ cache_examples=True,
317
+ label="Examples"
318
+ )
319
+
320
  demo.launch()
my_utils/default_values.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DEFAULT_VALUES = {
2
+ "flux-kontext": {
3
+ "num_inference_steps": 28,
4
+ "guidance_scale": 3.5,
5
+ "starting_candidates": 32,
6
+ "output_group_size": 4,
7
+ "pruning_ratio": 0.5,
8
+ "lambda_score": 1.0,
9
+ "output_dir": "outputs/flux-kontext",
10
+ "height": 512,
11
+ "width": 512,
12
+ "unary_term": "clip_text_img",
13
+ "binary_term": "diversity_dino"
14
+ }
15
+ }
16
+
17
+ def apply_defaults(args):
18
+ model_name = args.model_name
19
+
20
+ if model_name not in DEFAULT_VALUES:
21
+ raise ValueError(f"Unknown model name: {model_name}. Available models: {list(DEFAULT_VALUES.keys())}")
22
+
23
+ defaults = DEFAULT_VALUES[model_name]
24
+
25
+ for param_name, default_value in defaults.items():
26
+ if hasattr(args, param_name) and getattr(args, param_name) is None:
27
+ setattr(args, param_name, default_value)
28
+
29
+ return args
my_utils/group_inference.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, time
2
+ import math
3
+ import torch
4
+ import spaces
5
+ import numpy as np
6
+ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
7
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
8
+
9
+ from my_utils.solvers import gurobi_solver
10
+
11
+
12
+ def get_next_size(curr_size, final_size, keep_ratio):
13
+ """Calculate next size for progressive pruning during denoising.
14
+
15
+ Args:
16
+ curr_size: Current number of candidates
17
+ final_size: Target final size
18
+ keep_ratio: Fraction of candidates to keep at each step
19
+ """
20
+ if curr_size < final_size:
21
+ raise ValueError("Current size is less than the final size!")
22
+ elif curr_size == final_size:
23
+ return curr_size
24
+ else:
25
+ next_size = math.ceil(curr_size * keep_ratio)
26
+ return max(next_size, final_size)
27
+
28
+
29
+ @torch.no_grad()
30
+ def decode_latent(z, pipe, height, width):
31
+ """Decode latent tensor to image using VAE decoder.
32
+
33
+ Args:
34
+ z: Latent tensor to decode
35
+ pipe: Diffusion pipeline with VAE
36
+ height: Image height
37
+ width: Image width
38
+ """
39
+ z = pipe._unpack_latents(z, height, width, pipe.vae_scale_factor)
40
+ z = (z / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
41
+ z = pipe.vae.decode(z, return_dict=False)[0].clamp(-1,1)
42
+ return z
43
+
44
+
45
+ @torch.no_grad()
46
+ @spaces.GPU(duration=300)
47
+ def run_group_inference(pipe, model_name=None, prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None,
48
+ true_cfg_scale=1.0, height=None, width=None, num_inference_steps=28, sigmas=None, guidance_scale=3.5,
49
+ l_generator=None, max_sequence_length=512,
50
+ # group inference arguments
51
+ unary_score_fn=None, binary_score_fn=None,
52
+ starting_candidates=None, output_group_size=None, pruning_ratio=None, lambda_score=None,
53
+ # control arguments
54
+ control_image=None,
55
+ # input image for flux-kontext
56
+ input_image=None,
57
+ skip_first_cfg=True
58
+ ):
59
+ """Run group inference with progressive pruning for diverse, high-quality image generation.
60
+
61
+ Args:
62
+ pipe: Diffusion pipeline
63
+ model_name: Model type (flux-schnell, flux-dev, flux-depth, flux-canny, flux-kontext)
64
+ prompt: Text prompt for generation
65
+ unary_score_fn: Function to compute image quality scores
66
+ binary_score_fn: Function to compute pairwise diversity scores
67
+ starting_candidates: Initial number of noise samples
68
+ output_group_size: Final number of images to generate
69
+ pruning_ratio: Fraction to prune at each denoising step
70
+ lambda_score: Weight between quality and diversity terms
71
+ control_image: Control image for depth/canny models
72
+ input_image: Input image for flux-kontext editing
73
+ """
74
+ if l_generator is None:
75
+ l_generator = [torch.Generator("cpu").manual_seed(42+_seed) for _seed in range(starting_candidates)]
76
+
77
+ # use the default height and width if not provided
78
+ height = height or pipe.default_sample_size * pipe.vae_scale_factor
79
+ width = width or pipe.default_sample_size * pipe.vae_scale_factor
80
+
81
+ pipe._guidance_scale = guidance_scale
82
+ pipe._current_timestep = None
83
+ pipe._interrupt = False
84
+ pipe._joint_attention_kwargs = {}
85
+
86
+ device = pipe._execution_device
87
+
88
+ lora_scale = None
89
+ has_neg_prompt = negative_prompt is not None
90
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
91
+
92
+ # 3. Encode prompts
93
+ prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(prompt=prompt, prompt_2=prompt_2, prompt_embeds=None, pooled_prompt_embeds=None, device=device, max_sequence_length=max_sequence_length, lora_scale=lora_scale)
94
+
95
+ if do_true_cfg:
96
+ negative_prompt_embeds, negative_pooled_prompt_embeds, _ = pipe.encode_prompt(prompt=negative_prompt, prompt_2=negative_prompt_2, prompt_embeds=None, pooled_prompt_embeds=None, device=device, max_sequence_length=max_sequence_length, lora_scale=lora_scale)
97
+
98
+ # 4. Prepare latent variables
99
+ if model_name in ["flux-depth", "flux-canny"]:
100
+ # for control models, the pipe.transformer.config.in_channels is doubled
101
+ num_channels_latents = pipe.transformer.config.in_channels // 8
102
+ else:
103
+ num_channels_latents = pipe.transformer.config.in_channels // 4
104
+
105
+ # Handle different model types
106
+ image_latents = None
107
+ image_ids = None
108
+ if model_name == "flux-kontext":
109
+ processed_image = pipe.image_processor.preprocess(input_image, height=height, width=width)
110
+ l_latents = []
111
+ for _gen in l_generator:
112
+ latents, img_latents, latent_ids, img_ids = pipe.prepare_latents(
113
+ processed_image, 1, num_channels_latents, height, width,
114
+ prompt_embeds.dtype, device, _gen
115
+ )
116
+ l_latents.append(latents)
117
+ # Use the image_latents and image_ids from the first generator
118
+ _, image_latents, latent_image_ids, image_ids = pipe.prepare_latents(
119
+ processed_image, 1, num_channels_latents, height, width,
120
+ prompt_embeds.dtype, device, l_generator[0]
121
+ )
122
+ # Combine latent_ids with image_ids
123
+ if image_ids is not None:
124
+ latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0)
125
+ else:
126
+ # For other models (flux-schnell, flux-dev, flux-depth, flux-canny)
127
+ l_latents = [pipe.prepare_latents(1, num_channels_latents, height, width, prompt_embeds.dtype, device, _gen)[0] for _gen in l_generator]
128
+ _, latent_image_ids = pipe.prepare_latents(1, num_channels_latents, height, width, prompt_embeds.dtype, device, l_generator[0])
129
+
130
+ # 4.5. Prepare control image if provided
131
+ control_latents = None
132
+ if model_name in ["flux-depth", "flux-canny"]:
133
+ control_image_processed = pipe.prepare_image(image=control_image, width=width, height=height, batch_size=1, num_images_per_prompt=1, device=device, dtype=pipe.vae.dtype,)
134
+ if control_image_processed.ndim == 4:
135
+ control_latents = pipe.vae.encode(control_image_processed).latents
136
+ control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
137
+ height_control_image, width_control_image = control_latents.shape[2:]
138
+ control_latents = pipe._pack_latents(control_latents, 1, num_channels_latents, height_control_image, width_control_image)
139
+
140
+ # 5. Prepare timesteps
141
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
142
+ image_seq_len = latent_image_ids.shape[0]
143
+ mu = calculate_shift(image_seq_len, pipe.scheduler.config.get("base_image_seq_len", 256), pipe.scheduler.config.get("max_image_seq_len", 4096), pipe.scheduler.config.get("base_shift", 0.5), pipe.scheduler.config.get("max_shift", 1.15))
144
+ timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
145
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
146
+ pipe._num_timesteps = len(timesteps)
147
+ _dtype = l_latents[0].dtype
148
+
149
+ # handle guidance
150
+ if pipe.transformer.config.guidance_embeds:
151
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(1)
152
+ else:
153
+ guidance = None
154
+ guidance_1 = torch.full([1], 1.0, device=device, dtype=torch.float32).expand(1)
155
+
156
+ # 6. Denoising loop
157
+ with pipe.progress_bar(total=num_inference_steps) as progress_bar:
158
+ for i, t in enumerate(timesteps):
159
+ if pipe.interrupt:
160
+ continue
161
+ if guidance is not None and skip_first_cfg and i == 0:
162
+ curr_guidance = guidance_1
163
+ else:
164
+ curr_guidance = guidance
165
+
166
+ pipe._current_timestep = t
167
+ timestep = t.expand(1).to(_dtype)
168
+ # ipdb.set_trace()
169
+ next_latents = []
170
+ x0_preds = []
171
+ # do 1 denoising step
172
+ for _latent in l_latents:
173
+ # prepare input for transformer based on model type
174
+ if model_name in ["flux-depth", "flux-canny"]:
175
+ # Control models: concatenate control latents along dim=2
176
+ latent_model_input = torch.cat([_latent, control_latents], dim=2)
177
+ elif model_name == "flux-kontext":
178
+ # Kontext model: concatenate image latents along dim=1
179
+ latent_model_input = torch.cat([_latent, image_latents], dim=1)
180
+ else:
181
+ # Standard models (flux-schnell, flux-dev): use latents as is
182
+ latent_model_input = _latent
183
+
184
+ noise_pred = pipe.transformer(hidden_states=latent_model_input, timestep=timestep / 1000, guidance=curr_guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=pipe.joint_attention_kwargs, return_dict=False)[0]
185
+
186
+ # For flux-kontext, we need to slice the noise_pred to match the latents size
187
+ if model_name == "flux-kontext":
188
+ noise_pred = noise_pred[:, : _latent.size(1)]
189
+
190
+ if do_true_cfg:
191
+ neg_noise_pred = pipe.transformer(hidden_states=latent_model_input, timestep=timestep / 1000, guidance=curr_guidance, pooled_projections=negative_pooled_prompt_embeds, encoder_hidden_states=negative_prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=pipe.joint_attention_kwargs, return_dict=False)[0]
192
+ if model_name == "flux-kontext":
193
+ neg_noise_pred = neg_noise_pred[:, : _latent.size(1)]
194
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
195
+ # compute the previous noisy sample x_t -> x_t-1
196
+ _latent = pipe.scheduler.step(noise_pred, t, _latent, return_dict=False)[0]
197
+ # the scheduler is not state-less, it maintains a step index that is incremented by one after each step,
198
+ # so we need to decrease it here
199
+ if hasattr(pipe.scheduler, "_step_index"):
200
+ pipe.scheduler._step_index -= 1
201
+
202
+ if type(pipe.scheduler) == FlowMatchEulerDiscreteScheduler:
203
+ dt = 0.0 - pipe.scheduler.sigmas[i]
204
+ x0_pred = _latent + dt * noise_pred
205
+ else:
206
+ raise NotImplementedError("Only Flow Scheduler is supported for now! For other schedulers, you need to manually implement the x0 prediction step.")
207
+
208
+ x0_preds.append(x0_pred)
209
+ next_latents.append(_latent)
210
+
211
+ if hasattr(pipe.scheduler, "_step_index"):
212
+ pipe.scheduler._step_index += 1
213
+
214
+ # if the size of next_latents > output_group_size, prune the latents
215
+ if len(next_latents) > output_group_size:
216
+ next_size = get_next_size(len(next_latents), output_group_size, 1 - pruning_ratio)
217
+ print(f"Pruning from {len(next_latents)} to {next_size}")
218
+ # decode the latents to pixels with tiny-vae
219
+ l_x0_decoded = [decode_latent(_latent, pipe, height, width) for _latent in x0_preds]
220
+ # compute the unary and binary scores
221
+ l_unary_scores = unary_score_fn(l_x0_decoded, target_caption=prompt)
222
+ M_binary_scores = binary_score_fn(l_x0_decoded) # upper triangular matrix
223
+ # run with Quadratic Integer Programming sover
224
+ t_start = time.time()
225
+ selected_indices = gurobi_solver(l_unary_scores, M_binary_scores, next_size, lam=lambda_score)
226
+ t_end = time.time()
227
+ print(f"Time taken for QIP: {t_end - t_start} seconds")
228
+ l_latents = [next_latents[_i] for _i in selected_indices]
229
+ else:
230
+ l_latents = next_latents
231
+
232
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
233
+ progress_bar.update()
234
+
235
+ pipe._current_timestep = None
236
+
237
+ l_images = [pipe._unpack_latents(_latent, height, width, pipe.vae_scale_factor) for _latent in l_latents]
238
+ l_images = [(latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor for latents in l_images]
239
+ l_images = [pipe.vae.decode(_image, return_dict=False)[0] for _image in l_images]
240
+ l_images_tensor = [image.clamp(-1, 1) for image in l_images] # Keep tensor version for scoring
241
+ l_images = [pipe.image_processor.postprocess(image, output_type="pil")[0] for image in l_images]
242
+
243
+ # Compute and print final scores
244
+ print(f"\n=== Final Scores for {len(l_images)} generated images ===")
245
+
246
+ # Compute unary scores
247
+ final_unary_scores = unary_score_fn(l_images_tensor, target_caption=prompt)
248
+ print(f"Unary scores (quality): {final_unary_scores}")
249
+ print(f"Mean unary score: {np.mean(final_unary_scores):.4f}")
250
+
251
+ # Compute binary scores
252
+ final_binary_scores = binary_score_fn(l_images_tensor)
253
+ print(f"Binary score matrix (diversity):")
254
+ print(final_binary_scores)
255
+
256
+ print("=" * 50)
257
+
258
+ pipe.maybe_free_model_hooks()
259
+ return l_images
260
+
my_utils/scores.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ import torch.nn as nn
6
+ import torchvision.models as models
7
+ import torchvision.transforms as transforms
8
+ import cv2
9
+
10
+ from transformers import CLIPProcessor, CLIPModel, AutoModel
11
+ from transformers.models.clip.modeling_clip import _get_vector_norm
12
+
13
+
14
+
15
+ def validate_tensor_list(tensor_list, expected_type=torch.Tensor, min_dims=None, value_range=None, tolerance=0.1):
16
+ """
17
+ Validates a list of tensors with specified requirements.
18
+
19
+ Args:
20
+ tensor_list: List to validate
21
+ expected_type: Expected type of each element (default: torch.Tensor)
22
+ min_dims: Minimum number of dimensions each tensor should have
23
+ value_range: Tuple of (min_val, max_val) for tensor values
24
+ tolerance: Tolerance for value range checking (default: 0.1)
25
+ """
26
+ if not isinstance(tensor_list, list):
27
+ raise TypeError(f"Input must be a list, got {type(tensor_list)}")
28
+
29
+ if len(tensor_list) == 0:
30
+ raise ValueError("Input list cannot be empty")
31
+
32
+ for i, item in enumerate(tensor_list):
33
+ if not isinstance(item, expected_type):
34
+ raise TypeError(f"List element [{i}] must be {expected_type}, got {type(item)}")
35
+
36
+ if min_dims is not None and len(item.shape) < min_dims:
37
+ raise ValueError(f"List element [{i}] must have at least {min_dims} dimensions, got shape {item.shape}")
38
+
39
+ if value_range is not None:
40
+ min_val, max_val = value_range
41
+ item_min, item_max = item.min().item(), item.max().item()
42
+ if item_min < (min_val - tolerance) or item_max > (max_val + tolerance):
43
+ raise ValueError(f"List element [{i}] values must be in range [{min_val}, {max_val}], got range [{item_min:.3f}, {item_max:.3f}]")
44
+
45
+
46
+
47
+ def build_score_fn(name, device="cuda"):
48
+ """Build scoring functions for image quality and diversity assessment.
49
+
50
+ Args:
51
+ name: Score function name (clip_text_img, diversity_dino, dino_cls_pairwise, diversity_clip)
52
+ device: Device to load models on (default: cuda)
53
+ """
54
+ d_score_nets = {}
55
+
56
+ if name == "clip_text_img":
57
+ m_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
58
+ prep_clip = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
59
+ score_fn = functools.partial(unary_clip_text_img_t, device=device, m_clip=m_clip, preprocess_clip=prep_clip)
60
+ d_score_nets["m_clip"] = m_clip
61
+ d_score_nets["prep_clip"] = prep_clip
62
+
63
+ elif name == "diversity_dino":
64
+ dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
65
+ score_fn = functools.partial(binary_dino_pairwise_t, device=device, dino_model=dino_model)
66
+ d_score_nets["dino_model"] = dino_model
67
+
68
+ elif name == "dino_cls_pairwise":
69
+ dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
70
+ score_fn = functools.partial(binary_dino_cls_pairwise_t, device=device, dino_model=dino_model)
71
+ d_score_nets["dino_model"] = dino_model
72
+
73
+ elif name == "diversity_clip":
74
+ m_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
75
+ prep_clip = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
76
+ score_fn = functools.partial(binary_clip_pairwise_t, device=device, m_clip=m_clip, preprocess_clip=prep_clip)
77
+ d_score_nets["m_clip"] = m_clip
78
+ d_score_nets["prep_clip"] = prep_clip
79
+
80
+ else:
81
+ raise ValueError(f"Invalid score function name: {name}")
82
+
83
+ return score_fn, d_score_nets
84
+
85
+
86
+ @torch.no_grad()
87
+ def unary_clip_text_img_t(l_images, device, m_clip, preprocess_clip, target_caption, d_cache=None):
88
+ """Compute CLIP text-image similarity scores for a list of images.
89
+
90
+ Args:
91
+ l_images: List of image tensors in range [-1, 1]
92
+ device: Device for computation
93
+ m_clip: CLIP model
94
+ preprocess_clip: CLIP processor
95
+ target_caption: Text prompt for similarity comparison
96
+ d_cache: Optional cached text embeddings
97
+ """
98
+ # validate input images, l_images should be a list of torch tensors with range [-1, 1]
99
+ validate_tensor_list(l_images, expected_type=torch.Tensor, min_dims=3, value_range=(-1, 1))
100
+
101
+ _img_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(device)
102
+ _img_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(device)
103
+ b_images = torch.cat(l_images, dim=0)
104
+ b_images = F.interpolate(b_images, size=(224, 224), mode="bilinear", align_corners=False)
105
+ # re-normalize with clip mean and std
106
+ b_images = b_images*0.5 + 0.5
107
+ b_images = (b_images - _img_mean) / _img_std
108
+
109
+ if d_cache is None:
110
+ text_encoding = preprocess_clip.tokenizer(target_caption, return_tensors="pt", padding=True).to(device)
111
+ output = m_clip(pixel_values=b_images, **text_encoding).logits_per_image /m_clip.logit_scale.exp()
112
+ _score = output.view(-1).cpu().numpy()
113
+ else:
114
+ # compute with cached text embeddings
115
+ vision_outputs = m_clip.vision_model(pixel_values=b_images, output_attentions=False, output_hidden_states=False,
116
+ interpolate_pos_encoding=False, return_dict=True,)
117
+ image_embeds = m_clip.visual_projection(vision_outputs[1])
118
+ image_embeds = image_embeds / _get_vector_norm(image_embeds)
119
+ text_embeds = d_cache["text_embeds"]
120
+ _score = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)).t().view(-1).cpu().numpy()
121
+
122
+ return _score
123
+
124
+
125
+ @torch.no_grad()
126
+ def binary_dino_pairwise_t(l_images, device, dino_model):
127
+ """Compute pairwise diversity scores using DINO patch features.
128
+
129
+ Args:
130
+ l_images: List of image tensors in range [-1, 1]
131
+ device: Device for computation
132
+ dino_model: DINO model for feature extraction
133
+ """
134
+ # validate input images, l_images should be a list of torch tensors with range [-1, 1]
135
+ validate_tensor_list(l_images, expected_type=torch.Tensor, min_dims=3, value_range=(-1, 1))
136
+
137
+ b_images = torch.cat(l_images, dim=0)
138
+ _img_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
139
+ _img_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
140
+
141
+ b_images = F.interpolate(b_images, size=(256, 256), mode="bilinear", align_corners=False)
142
+ b_images = b_images*0.5 + 0.5
143
+ b_images = (b_images - _img_mean) / _img_std
144
+ all_features = dino_model(pixel_values=b_images).last_hidden_state[:, 1:, :].cpu() # B, 324, 768
145
+
146
+ N = len(l_images)
147
+ score_matrix = np.zeros((N, N))
148
+ for i in range(N):
149
+ f1 = all_features[i]
150
+ for j in range(i+1, N):
151
+ f2 = all_features[j]
152
+ cos_sim = (1 - F.cosine_similarity(f1, f2, dim=1)).mean().item()
153
+ score_matrix[i, j] = cos_sim
154
+ return score_matrix
155
+
156
+ @torch.no_grad()
157
+ def binary_dino_cls_pairwise_t(l_images, device, dino_model):
158
+ """Compute pairwise diversity scores using DINO CLS token features.
159
+
160
+ Args:
161
+ l_images: List of image tensors in range [-1, 1]
162
+ device: Device for computation
163
+ dino_model: DINO model for feature extraction
164
+ """
165
+ # validate input images, l_images should be a list of torch tensors with range [-1, 1]
166
+ validate_tensor_list(l_images, expected_type=torch.Tensor, min_dims=3, value_range=(-1, 1))
167
+
168
+ b_images = torch.cat(l_images, dim=0)
169
+ _img_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
170
+ _img_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
171
+
172
+ b_images = F.interpolate(b_images, size=(256, 256), mode="bilinear", align_corners=False)
173
+ b_images = b_images*0.5 + 0.5
174
+ b_images = (b_images - _img_mean) / _img_std
175
+ all_features = dino_model(pixel_values=b_images).last_hidden_state[:, 0:1, :].cpu() # B, 1, 768
176
+
177
+ N = len(l_images)
178
+ score_matrix = np.zeros((N, N))
179
+ for i in range(N):
180
+ f1 = all_features[i]
181
+ for j in range(i+1, N):
182
+ f2 = all_features[j]
183
+ cos_sim = (1 - F.cosine_similarity(f1, f2, dim=1)).mean().item()
184
+ score_matrix[i, j] = cos_sim
185
+ return score_matrix
186
+
187
+ @torch.no_grad()
188
+ def binary_clip_pairwise_t(l_images, device, m_clip, preprocess_clip):
189
+ """Compute pairwise diversity scores using CLIP image embeddings.
190
+
191
+ Args:
192
+ l_images: List of image tensors in range [-1, 1]
193
+ device: Device for computation
194
+ m_clip: CLIP model
195
+ preprocess_clip: CLIP processor
196
+ """
197
+ # validate input images, l_images should be a list of torch tensors with range [-1, 1]
198
+ validate_tensor_list(l_images, expected_type=torch.Tensor, min_dims=3, value_range=(-1, 1))
199
+
200
+ _img_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(device)
201
+ _img_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(device)
202
+ b_images = torch.cat(l_images, dim=0)
203
+ b_images = F.interpolate(b_images, size=(224, 224), mode="bilinear", align_corners=False)
204
+ # re-normalize with clip mean and std
205
+ b_images = b_images*0.5 + 0.5
206
+ b_images = (b_images - _img_mean) / _img_std
207
+
208
+ vision_outputs = m_clip.vision_model(pixel_values=b_images, output_attentions=False, output_hidden_states=False,
209
+ interpolate_pos_encoding=False, return_dict=True,)
210
+ image_embeds = m_clip.visual_projection(vision_outputs[1])
211
+ image_embeds = image_embeds / _get_vector_norm(image_embeds)
212
+
213
+ N = len(l_images)
214
+ score_matrix = np.zeros((N, N))
215
+ for i in range(N):
216
+ f1 = image_embeds[i]
217
+ for j in range(i+1, N):
218
+ f2 = image_embeds[j]
219
+ cos_sim = (1 - torch.dot(f1, f2)).item()
220
+ score_matrix[i, j] = cos_sim
221
+ return score_matrix
my_utils/solvers.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gurobipy import Model, GRB, quicksum
2
+
3
+
4
+ def gurobi_solver(u, D, n_select, lam=1.0, time_limit=5.0):
5
+ """Solve quadratic integer programming problem for subset selection with unary and pairwise terms.
6
+
7
+ Args:
8
+ u: Unary scores for each item
9
+ D: Pairwise similarity matrix (upper triangular)
10
+ n_select: Number of items to select
11
+ lam: Weight for pairwise term (default: 1.0)
12
+ time_limit: Solver time limit in seconds (default: 5.0)
13
+ """
14
+ n = len(u)
15
+ model = Model()
16
+ model.Params.LogToConsole = 0
17
+ model.Params.TimeLimit = time_limit
18
+ model.Params.OutputFlag = 0
19
+
20
+ # Variables: x[i] in {0,1}
21
+ x = model.addVars(n, vtype=GRB.BINARY, name="x")
22
+ # Constraint: exactly k items selected
23
+ model.addConstr(quicksum(x[i] for i in range(n)) == n_select, name="select_k")
24
+
25
+ # Objective: sum of unary + lambda * pairwise
26
+ linear_part = quicksum(u[i] * x[i] for i in range(n))
27
+ quadratic_part = quicksum(lam * D[i, j] * x[i] * x[j] for i in range(n) for j in range(i + 1, n))
28
+
29
+ model.setObjective(linear_part + quadratic_part, GRB.MAXIMIZE)
30
+
31
+ model.optimize()
32
+ selected_indices = [i for i in range(n) if x[i].X > 0.5]
33
+ return selected_indices