Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import spaces | |
| import os | |
| import gradio as gr | |
| from diffusers.utils import load_image | |
| from diffusers.hooks import apply_group_offloading | |
| from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL | |
| from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig | |
| from transformers import T5EncoderModel | |
| from transformers import LlavaForConditionalGeneration, TextIteratorStreamer, AutoProcessor, AutoTokenizer | |
| from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig | |
| from liger_kernel.transformers import apply_liger_kernel_to_llama | |
| from PIL import Image | |
| from threading import Thread | |
| from typing import Generator | |
| from peft import PeftModel, PeftConfig | |
# --- Runtime configuration ---------------------------------------------------
# Access token passed to the FLUX `from_pretrained` calls below. The original
# (misspelled) variable name is read first so existing deployments keep
# working; the correctly spelled name is accepted as a fallback.
huggingface_token = os.getenv("HUGGINFACE_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

# Default text for the "System Prompt for Captioning" textbox. Falls back to an
# empty string so downstream `.strip()` calls never receive None.
sys_prompt = os.getenv("SYS", "")

# Upper bound of the seed slider in the UI.
MAX_SEED = 1000000

# --- Captioning model (image -> text) ----------------------------------------
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
cap_processor = AutoProcessor.from_pretrained(MODEL_PATH)
cap_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype="bfloat16",
    device_map=0,
)
assert isinstance(cap_model, LlavaForConditionalGeneration), f"Expected LlavaForConditionalGeneration, got {type(cap_model)}"
cap_model.eval()
apply_liger_kernel_to_llama(model=cap_model.language_model)

# --- Upscaling pipeline (FLUX + ControlNet) ----------------------------------
# The T5 text encoder is loaded separately in bfloat16 and handed to the
# pipeline in place of whatever the pipeline repository would load itself.
text_encoder_2_unquant = T5EncoderModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="text_encoder_2",
    torch_dtype=torch.bfloat16,
    token=huggingface_token,
)
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1M-8step_upscaler-cnet",
    torch_dtype=torch.bfloat16,
    text_encoder_2=text_encoder_2_unquant,
    token=huggingface_token,
)
pipe.to("cuda")
def caption(input_image: Image.Image, prompt: str, temperature: float, top_p: float, max_new_tokens: int, log_prompt: bool) -> Generator[str, None, None]:
    """Stream a caption for ``input_image`` from the captioning model.

    Generation runs once, on a background thread, while this generator relays
    the text chunks produced by the streamer. Each yielded string is an
    *incremental* chunk; callers that want the full caption so far must
    concatenate the chunks themselves.

    Args:
        input_image: Image to describe. ``None`` yields a single
            "No image provided" message and stops.
        prompt: User-turn text sent to the model (stripped; ``None`` is
            treated as empty).
        temperature: Sampling temperature; values ``<= 0`` disable sampling.
        top_p: Nucleus-sampling threshold, only used when sampling.
        max_new_tokens: Upper bound on generated tokens.
        log_prompt: When true, print the prompt before generating.

    Yields:
        Text chunks as they are decoded.
    """
    torch.cuda.empty_cache()

    if input_image is None:
        yield "No image provided. Please upload an image."
        return

    if log_prompt:
        print(f"PromptLog: {repr(prompt)}")

    convo = [
        {
            "role": "system",
            "content": "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions.",
        },
        {
            "role": "user",
            # The UI textbox may hand over None when it was never filled in.
            "content": (prompt or "").strip(),
        },
    ]

    convo_string = cap_processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)

    inputs = cap_processor(text=[convo_string], images=[input_image], return_tensors="pt").to('cuda')
    inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)

    streamer = TextIteratorStreamer(cap_processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    sampling = temperature > 0
    generate_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),  # UI sliders may deliver floats
        do_sample=sampling,
        suppress_tokens=None,
        use_cache=True,
        temperature=temperature if sampling else None,
        top_k=None,
        top_p=top_p if sampling else None,
        streamer=streamer,
    )

    # `generate` blocks until the whole sequence is finished, so it has to run
    # off-thread for the streamer to be drained while tokens are produced.
    # (Previously it was called twice, synchronously, and the streamer was
    # never read, so nothing was ever yielded.)
    worker = Thread(target=cap_model.generate, kwargs=generate_kwargs, daemon=True)
    worker.start()

    # Iterating the streamer raises if no text arrives within its `timeout`,
    # which also prevents hanging forever if the worker thread dies.
    for text in streamer:
        if text:
            yield text

    worker.join()
def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale, seed, guidance_end):
    """Upscale ``control_image`` with the FLUX ControlNet pipeline.

    The control image is trimmed to a multiple of 32 pixels per side, scaled
    by ``scale``, and trimmed to a multiple of 32 again, because a fractional
    ``scale`` can break the alignment (e.g. 32 * 1.25 = 40). For integer scales
    the second trim is a no-op, so those results are unchanged.
    NOTE(review): the multiple of 32 comes from the original code; confirm it
    against the pipeline's actual divisibility requirement.

    Args:
        prompt: Text prompt conditioning the generation.
        scale: Upscale factor applied to the trimmed image size.
        steps: Number of inference steps.
        control_image: Image (or anything ``load_image`` accepts) to upscale.
        controlnet_conditioning_scale: ControlNet conditioning strength.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the random generator.
        guidance_end: Value passed as ``control_guidance_end``.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: If no control image is given, or it is too small to leave
            a non-empty image after trimming.
    """
    size_multiple = 32

    if control_image is None:
        raise ValueError("No control image provided. Please upload an image.")

    # UI sliders may deliver floats; manual_seed() needs an int.
    generator = torch.Generator().manual_seed(int(seed))

    # Load control image
    control_image = load_image(control_image)
    w, h = control_image.size
    w -= w % size_multiple
    h -= h % size_multiple

    target_w = int(w * scale)
    target_h = int(h * scale)
    # Re-align after scaling: fractional scales can break the multiple-of-32.
    target_w -= target_w % size_multiple
    target_h -= target_h % size_multiple

    if target_w <= 0 or target_h <= 0:
        raise ValueError(
            f"Control image is too small: each side must be at least {size_multiple} pixels."
        )

    control_image = control_image.resize((target_w, target_h), resample=2)  # Resample.BILINEAR
    print("Size to: " + str(control_image.size[0]) + ", " + str(control_image.size[1]))
    print("Cond Prompt: " + str(prompt))

    with torch.inference_mode():
        image = pipe(
            generator=generator,
            prompt=prompt,
            control_image=control_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale,
            height=control_image.size[1],
            width=control_image.size[0],
            control_guidance_start=0.0,
            control_guidance_end=guidance_end,
        ).images[0]
    return image
def process_image(control_image, user_prompt, system_prompt, scale, steps,
                  controlnet_conditioning_scale, guidance_scale, seed,
                  guidance_end, temperature, top_p, max_new_tokens, log_prompt):
    """Caption (if needed) and upscale an image, streaming status updates.

    If ``user_prompt`` is empty, a caption is generated with ``caption()`` and
    used as the prompt; otherwise the user's prompt is used as-is. The image
    is then produced by ``generate_image()``.

    Args:
        control_image: Image to upscale; ``None`` yields a message and stops.
        user_prompt: Optional prompt; blank or ``None`` triggers captioning.
        system_prompt: Instruction text forwarded to ``caption()``.
        scale, steps, controlnet_conditioning_scale, guidance_scale, seed,
        guidance_end: Forwarded unchanged to ``generate_image()``.
        temperature, top_p, max_new_tokens, log_prompt: Forwarded unchanged
            to ``caption()``.

    Yields:
        ``(status_text, image_or_None)`` tuples. Only the final successful
        update carries an image.

    Raises:
        Exception: Anything raised by ``generate_image()`` is re-raised after
            an ``"Error: ..."`` status has been yielded.
    """
    if control_image is None:
        # Without this guard the "no image" message from caption() would be
        # used as the generation prompt.
        yield "No image provided. Please upload an image.", None
        return

    final_prompt = (user_prompt or "").strip()

    # If no user prompt provided, generate a caption first
    if not final_prompt:
        generated_caption = ""
        caption_chunks = caption(
            input_image=control_image,
            prompt=system_prompt,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            log_prompt=log_prompt,
        )
        for chunk in caption_chunks:
            generated_caption += chunk
            yield generated_caption, None  # Update caption in real-time
        final_prompt = generated_caption.strip()
        yield f"Using caption: {final_prompt}", None

    # Show the final prompt being used
    yield f"Generating with: {final_prompt}", None

    # Only the pipeline call is guarded. The old debug prints of caption-only
    # locals lived inside this try block and raised UnboundLocalError whenever
    # the user had typed a prompt, turning every such run into an "Error".
    try:
        image = generate_image(
            prompt=final_prompt,
            scale=scale,
            steps=steps,
            control_image=control_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            guidance_scale=guidance_scale,
            seed=seed,
            guidance_end=guidance_end,
        )
    except Exception as e:
        yield f"Error: {str(e)}", None
        raise

    yield f"Completed! Used prompt: {final_prompt}", image
def handle_outputs(outputs):
    """Unwrap an ``update_caption`` message into a ``(caption, None)`` pair.

    A dict whose ``"__type__"`` entry equals ``"update_caption"`` is turned
    into ``(outputs["caption"], None)``; any other value is returned as-is.
    """
    is_caption_update = (
        isinstance(outputs, dict)
        and outputs.get("__type__") == "update_caption"
    )
    if not is_caption_update:
        return outputs
    return outputs["caption"], None
def _stream_caption_text(*caption_args):
    """Adapt ``caption()`` for direct binding to a textbox.

    Concatenates whatever ``caption()`` yields and re-yields the running
    total, so the bound textbox always shows the whole caption so far rather
    than only the most recent chunk.
    """
    text_so_far = ""
    for chunk in caption(*caption_args):
        text_so_far += chunk
        yield text_so_far


# NOTE(review): `spaces` is imported at the top of the file but no handler in
# the visible code is wrapped with it; confirm whether the hosting hardware
# needs the handlers below to be decorated.
with gr.Blocks(title="FLUX Turbo Upscaler", fill_height=True) as iface:
    gr.Markdown("⚠️ WIP SPACE - UNFINISHED & BUGGY")

    # Input image on the left, result on the right.
    with gr.Row():
        control_image = gr.Image(type="pil", label="Control Image", show_label=False)
        generated_image = gr.Image(type="pil", label="Generated Image", format="png", show_label=False)

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(lines=4, placeholder="Enter your prompt here...", label="Prompt", interactive=True)
            output_caption = gr.Textbox(label="Caption")
            scale = gr.Slider(1, 3, value=1, label="Scale", step=0.25)
            generate_button = gr.Button("Generate Image", variant="primary")
            caption_button = gr.Button("Generate Caption", variant="secondary")
        with gr.Column(scale=1):
            seed = gr.Slider(0, MAX_SEED, value=42, label="Seed", step=1)
            steps = gr.Slider(2, 16, value=8, label="Steps", step=1)
            controlnet_conditioning_scale = gr.Slider(0, 1, value=0.6, label="ControlNet Scale")
            guidance_scale = gr.Slider(1, 30, value=3.5, label="Guidance Scale")
            guidance_end = gr.Slider(0, 1, value=1.0, label="Guidance End")

    with gr.Row():
        with gr.Accordion("Generation settings", open=False):
            system_prompt = gr.Textbox(
                lines=4,
                value=sys_prompt,
                label="System Prompt for Captioning",
                visible=True,
            )
            temperature_slider = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.6, step=0.05,
                label="Temperature",
                info="Higher values make the output more random, lower values make it more deterministic.",
                visible=True,
            )
            top_p_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.9, step=0.01,
                label="Top-p",
                visible=True,
            )
            # Hidden controls: they still feed their default values to the
            # handlers but are not shown to the user.
            max_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=368, step=1,
                label="Max New Tokens",
                info="Maximum number of tokens to generate. The model will stop generating if it reaches this limit.",
                visible=False,
            )
            log_prompt = gr.Checkbox(value=True, label="Log", visible=False)

    gr.Markdown("**Tips:** 8 steps is all you need!")

    # Status/caption text goes to the read-only "Caption" box. It used to be
    # written into the `prompt` input, so the next click reused a status line
    # such as "Completed! Used prompt: ..." as the prompt.
    generate_button.click(
        fn=process_image,
        inputs=[
            control_image, prompt, system_prompt, scale, steps,
            controlnet_conditioning_scale, guidance_scale, seed,
            guidance_end, temperature_slider, top_p_slider, max_tokens_slider, log_prompt,
        ],
        outputs=[output_caption, generated_image],
    )

    # Each yield replaces the textbox content, so the accumulating adapter is
    # bound here rather than `caption` itself.
    caption_button.click(
        fn=_stream_caption_text,
        inputs=[control_image, system_prompt, temperature_slider, top_p_slider, max_tokens_slider, log_prompt],
        outputs=output_caption,
    )

iface.launch()