import torch
import spaces
import os
import diffusers
import PIL
import gc
import gradio as gr
from PIL import Image
from diffusers.utils import load_image
from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel
from accelerate import dispatch_model, infer_auto_device_map

# Corrected and optimized FluxControlNet implementation
huggingface_token = os.getenv("HUGGINFACE_TOKEN")
device = "cuda"
torch_dtype = torch.bfloat16
MAX_SEED = 1000000

def self_attention_slicing(module, slice_size=3):
    """Modified from Diffusers' original for Flux compatibility.

    Wraps a callable so that its tensor inputs are processed in chunks along
    `dim` and the partial outputs concatenated, trading speed for peak memory.
    """
    def sliced_attention(*args, **kwargs):
        dim = kwargs.pop("dim", 1)
        if slice_size == "auto":
            # Automatic slicing based on Flux architecture: defer to a single call
            return module(*args, **kwargs)

        def _chunk(value, start):
            # Slice tensors along `dim`; pass non-tensor arguments through unchanged
            if torch.is_tensor(value):
                length = min(slice_size, value.shape[dim] - start)
                return value.narrow(dim, start, length)
            return value

        return torch.cat([
            module(
                *[_chunk(arg, i) for arg in args],
                **{k: _chunk(v, i) for k, v in kwargs.items()},
            )
            for i in range(0, args[0].shape[dim], slice_size)
        ], dim=dim)

    return sliced_attention
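
# Intended usage (mirrors the fallback applied further below): wrap a
# memory-heavy callable so its inputs are processed in chunks. The slice size
# of 2 is illustrative only, copied from that fallback rather than tuned.
#     pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)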

# 1. Load the 8-bit text encoder and transformer plus a full-precision VAE
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)

good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic mapping
    token=huggingface_token,
).to(device)

# 2. Main pipeline initialization with the VAE already in scope
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16,
    ),
    vae=good_vae,  # Defined above
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token,
)
pipe.to(device)

# 3. Strict order for optimization steps
# A. Apply CPU offloading first
#### pipe.enable_sequential_cpu_offload()  # No arguments for new API
# Then apply custom VAE slicing
if getattr(pipe, "vae", None) is not None:
    # Method 1: Use the official implementation if available
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        # Method 2: Apply manual slicing for Flux compatibility (see pipeline_flux_controlnet.py)
        print("Falling back to manual attention slicing.")
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)
pipe.enable_attention_slicing(1)

# B. Enable memory optimizations
# pipe.enable_vae_tiling()
# pipe.enable_xformers_memory_efficient_attention()

# C. Unified precision handling
# for comp in [pipe.transformer, pipe.vae, pipe.controlnet]:
#     comp.to(dtype=torch.bfloat16)

print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
def generate_image(prompt, scale, steps, seed, control_image, controlnet_conditioning_scale, guidance_scale, guidance_start, guidance_end):
    print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
    # Load the control image and upscale it, snapping dimensions to multiples of 8
    control_image = load_image(control_image)
    w, h = control_image.size
    w = int(w * scale) - int(w * scale) % 8
    h = int(h * scale) - int(h * scale) % 8
    control_image = control_image.resize((w, h))
    print(f"Size to: {control_image.size[0]}, {control_image.size[1]}")
    generator = torch.Generator().manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        height=h,
        width=w,
        control_guidance_start=guidance_start,
        control_guidance_end=guidance_end,
        generator=generator,
    ).images[0]
    return image
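
# Minimal local smoke test, kept commented out; the file name and values are
# hypothetical and not part of the Space UI.
# out = generate_image(
#     prompt="restore fine detail, high quality photograph",
#     scale=1, steps=8, seed=42,
#     control_image="input.png",
#     controlnet_conditioning_scale=0.6,
#     guidance_scale=3.5, guidance_start=0.0, guidance_end=1.0,
# )
# out.save("upscaled.png")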

# Create Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(1, 3, value=1, label="Scale"),
        gr.Slider(2, 20, value=8, label="Steps"),
        gr.Slider(0, MAX_SEED, value=42, label="Seed"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
        gr.Slider(0, 1, value=0.0, label="Control Guidance Start"),
        gr.Slider(0, 1, value=1.0, label="Control Guidance End"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
| print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}") | |
| gc.enable() | |
| gc.collect() | |
| # Launch the app | |
| iface.launch(show_error=True, share=True) |