Spaces:
Paused
Paused
| from pathlib import Path | |
| import torch | |
| import gradio as gr | |
| from src.flux.xflux_pipeline import XFluxPipeline | |
def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    ckpt_dir: str = "",
):
    """Build (but do not launch) the Gradio UI for the XFlux adapter pipeline.

    Args:
        model_type: Model name forwarded to ``XFluxPipeline`` and shown in the
            page title.
        device: Device string forwarded to ``XFluxPipeline``. The default is
            evaluated once at import time: "cuda" if available, else "cpu".
        offload: Offload flag forwarded to ``XFluxPipeline``.
        ckpt_dir: Directory scanned (non-recursively) for ``*.safetensors``
            files; the matches populate the ControlNet, LoRA and IP-Adapter
            checkpoint dropdowns. An empty string resolves to the current
            working directory (``Path("")`` is ``Path(".")``).

    Returns:
        The constructed ``gr.Blocks`` app. Its Generate button calls
        ``xflux_pipeline.gradio_generate`` with the inputs listed below.
    """
    xflux_pipeline = XFluxPipeline(model_type, device, offload)

    # Sort as Paths (same ordering as before), then hand Gradio plain strings:
    # Dropdown choices are documented as str/int/float, and a string choice
    # also round-trips cleanly as the value sent back from the browser.
    checkpoints = [str(path) for path in sorted(Path(ckpt_dir).glob("*.safetensors"))]

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")

                with gr.Accordion("Generation Options", open=False):
                    with gr.Row():
                        width = gr.Slider(512, 2048, 1024, step=16, label="Width")
                        height = gr.Slider(512, 2048, 1024, step=16, label="Height")
                    neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        timestep_to_start_cfg = gr.Slider(
                            1, 50, 1, step=1, label="timestep_to_start_cfg"
                        )
                    with gr.Row():
                        guidance = gr.Slider(
                            1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True
                        )
                        true_gs = gr.Slider(
                            1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True
                        )
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")

                with gr.Accordion("ControlNet Options", open=False):
                    control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type")
                    control_weight = gr.Slider(
                        0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True
                    )
                    local_path = gr.Dropdown(
                        checkpoints,
                        label="Controlnet Checkpoint",
                        info="Local Path to Controlnet weights (if no, it will be downloaded from HF)",
                    )
                    controlnet_image = gr.Image(
                        label="Input Controlnet Image", visible=True, interactive=True
                    )

                with gr.Accordion("LoRA Options", open=False):
                    lora_weight = gr.Slider(
                        0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True
                    )
                    lora_local_path = gr.Dropdown(
                        checkpoints, label="LoRA Checkpoint", info="Local Path to Lora weights"
                    )

                with gr.Accordion("IP Adapter Options", open=False):
                    image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
                    ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
                    neg_image_prompt = gr.Image(
                        label="neg_image_prompt", visible=True, interactive=True
                    )
                    neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
                    ip_local_path = gr.Dropdown(
                        checkpoints,
                        label="IP Adapter Checkpoint",
                        info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)",
                    )

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                download_btn = gr.File(label="Download full-resolution")

        # Order matters: Gradio passes these positionally to gradio_generate,
        # so this list must match that method's parameter order.
        inputs = [
            prompt, image_prompt, controlnet_image, width, height, guidance,
            num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
            neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
            lora_weight, local_path, lora_local_path, ip_local_path,
        ]
        generate_btn.click(
            fn=xflux_pipeline.gradio_generate,
            inputs=inputs,
            outputs=[output_image, download_btn],
        )

    return demo
if __name__ == "__main__":
    # Command-line entry point: parse options, build the UI, then serve it.
    import argparse

    # Same default the original computed inline for --device.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser(description="Flux")
    cli.add_argument(
        "--name",
        type=str,
        default="flux-dev",
        help="Model name",
    )
    cli.add_argument(
        "--device",
        type=str,
        default=fallback_device,
        help="Device to use",
    )
    cli.add_argument(
        "--offload",
        action="store_true",
        help="Offload model to CPU when not in use",
    )
    cli.add_argument(
        "--share",
        action="store_true",
        help="Create a public link to your demo",
    )
    cli.add_argument(
        "--ckpt_dir",
        type=str,
        default=".",
        help="Folder with checkpoints in safetensors format",
    )
    options = cli.parse_args()

    app = create_demo(
        model_type=options.name,
        device=options.device,
        offload=options.offload,
        ckpt_dir=options.ckpt_dir,
    )
    app.launch(share=options.share)