support cfg-zero*
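Adds the zero-init half of the CFG-Zero* recipe to FluxPipeline.__call__: two new keyword arguments, use_zero_init (default True) and zero_steps (default 0), make the pipeline zero out the transformer's noise prediction for every denoising step i with i <= zero_steps. Zeroing the earliest prediction(s) is the CFG-Zero* safeguard against unreliable early estimates under classifier-free guidance.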
src/pipeline.py  CHANGED  (+10 -5)
@@ -526,9 +526,11 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        spatial_images=[],
-        subject_images=[],
+        spatial_images=[],
+        subject_images=[],
         cond_size=512,
+        use_zero_init: Optional[bool] = True,
+        zero_steps: Optional[int] = 0,
     ):
 
         height = height or self.default_sample_size * self.vae_scale_factor
@@ -656,7 +658,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
             guidance = guidance.expand(latents.shape[0])
         else:
             guidance = None
-
+
         ## Caching conditions
         # clean the cache
         for name, attn_processor in self.transformer.attn_processors.items():
@@ -679,7 +681,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
                 joint_attention_kwargs=self.joint_attention_kwargs,
                 return_dict=False,
             )[0]
-
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -700,6 +702,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
+
+                if (i <= zero_steps) and use_zero_init:
+                    noise_pred = noise_pred*0.
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
@@ -742,4 +747,4 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
         if not return_dict:
             return (image,)
 
-        return FluxPipelineOutput(images=image)
+        return FluxPipelineOutput(images=image)
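For reference, a minimal usage sketch of the new arguments. Everything outside the two zero_* keywords is an assumption: the import path mirrors this repo's src/pipeline.py, and the checkpoint id and prompt are placeholders, not part of the commit.

import torch
from src.pipeline import FluxPipeline  # the modified pipeline from this commit

# Placeholder base checkpoint; substitute whatever weights the Space actually loads.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

result = pipe(
    prompt="a corgi wearing sunglasses on a beach",  # illustrative prompt
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    use_zero_init=True,  # enable the CFG-Zero* zero-init added in this commit
    zero_steps=0,        # the check is `i <= zero_steps`, so 0 zeroes only step 0
)
result.images[0].save("corgi.png")

With the defaults (use_zero_init=True, zero_steps=0), only the first step's prediction is replaced by zeros; raising zero_steps widens the zeroed window, and use_zero_init=False restores the stock behavior.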