Spaces: Running on Zero
| import spaces | |
| import os | |
| import random | |
| from PIL import Image | |
| import torch | |
| import gradio as gr | |
| import dotenv | |
| from adapter import load_ip_adapter_model, get_file_path | |
| from example import EXAMPLES | |
# Pull local overrides (if any) into os.environ before reading configuration.
dotenv.load_dotenv(".env.local")


def _require_env(name: str) -> str:
    """Return environment variable *name*, raising if it is not set.

    An explicit exception is used instead of ``assert`` because assertions
    are stripped when Python runs with ``-O``. A stripped assert would let
    ``None`` reach ``get_file_path`` and fail later with a less helpful error.

    Raises:
        RuntimeError: if *name* is absent from the environment (after
            ``.env.local`` has been loaded).
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(
            f"Required environment variable {name!r} is not set; "
            "define it in the environment or in .env.local"
        )
    return value


# Location of the IP-Adapter weights and config (required).
ADAPTER_REPO_ID = _require_env("ADAPTER_REPO_ID")
ADAPTER_MODEL_PATH = _require_env("ADAPTER_MODEL_PATH")
ADAPTER_CONFIG_PATH = _require_env("ADAPTER_CONFIG_PATH")

# Location of the base checkpoint (optional; falls back to the defaults below).
BASE_MODEL_REPO_ID = os.environ.get(
    "BASE_MODEL_REPO_ID", "p1atdev/animagine-xl-4.0-bnb-nf4"
)
BASE_MODEL_PATH = os.environ.get(
    "BASE_MODEL_PATH", "animagine-xl-4.0-opt.bnb_nf4.safetensors"
)

# Initial value of the "Number of images to generate" slider in the UI.
INITIAL_BATCH_SIZE = int(os.environ.get("INITIAL_BATCH_SIZE", 1))

# Resolve a local path for each artifact. get_file_path is project-defined;
# presumably it downloads from the Hub repo or returns a cached copy (verify).
adapter_model_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_MODEL_PATH)
adapter_config_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_CONFIG_PATH)
base_model_path = get_file_path(BASE_MODEL_REPO_ID, BASE_MODEL_PATH)

# Load the base model together with the adapter once, at import time, so every
# request handled by on_generate reuses the same instance.
model = load_ip_adapter_model(
    model_path=base_model_path,
    config_path=adapter_config_path,
    adapter_path=adapter_model_path,
)
model.to("cuda:0")
@spaces.GPU
def on_generate(
    prompt: str,
    negative_prompt: str,
    image: Image.Image | None,
    width: int,
    height: int,
    steps: int,
    cfg_scale: float,
    seed: int,
    randomize_seed: bool = True,
    num_images: int = 4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate ``num_images`` images for *prompt* with the loaded model.

    Gradio callback for the "Generate" button. The ``@spaces.GPU`` decorator
    requests a GPU for the duration of the call when running on ZeroGPU
    hardware; ``import spaces`` at the top of the file exists for this.

    Args:
        prompt: Positive prompt, repeated once per requested image.
        negative_prompt: Negative prompt forwarded to ``model.generate``.
        image: Optional reference image; converted to RGB before use.
        width: Output width in pixels.
        height: Output height in pixels.
        steps: Number of inference steps.
        cfg_scale: Guidance scale forwarded as ``cfg_scale``.
        seed: Seed used when ``randomize_seed`` is False.
        randomize_seed: If True, draw a fresh seed in ``[0, 2147483647]``.
        num_images: How many images to generate in one batch.
        progress: Gradio progress tracker (``track_tqdm=True``); not used
            directly, Gradio injects it.

    Returns:
        ``(images, seed)``: the result of ``model.generate`` and the seed that
        was actually used, so the UI's seed slider can display it.
    """
    # Slider values may arrive as floats (e.g. 4.0). List repetition and
    # pixel/step counts need real ints.
    num_images = int(num_images)
    width, height, steps = int(width), int(height), int(steps)
    seed = int(seed)

    if image is not None:
        image = image.convert("RGB")
    if randomize_seed:
        seed = random.randint(0, 2147483647)

    try:
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            images = model.generate(
                prompt=[prompt] * num_images,  # one prompt per image in the batch
                negative_prompt=negative_prompt,
                reference_image=image,
                num_inference_steps=steps,
                cfg_scale=cfg_scale,
                width=width,
                height=height,
                seed=seed,
                do_offloading=False,
                device="cuda:0",
                max_token_length=225,
                execution_dtype=torch.bfloat16,
            )
    finally:
        # Return cached, unused CUDA memory to the driver even if generation
        # failed, so a failed large request does not hold it for the next one.
        torch.cuda.empty_cache()
    return images, seed
# Default text for the "Negative Prompt" box.
_DEFAULT_NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry"

# Upper bound of the "Number of images to generate" slider.
_MAX_NUM_IMAGES = 8


def _build_demo() -> gr.Blocks:
    """Build the Gradio UI and wire the Generate button to ``on_generate``.

    Component creation order inside the ``with`` blocks defines the page
    layout, so the statements below must stay in this order.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: all inputs.
            with gr.Column():
                prompt = gr.TextArea(
                    label="Prompt",
                    value="masterpiece, best quality",
                    placeholder="masterpiece, best quality",
                    interactive=True,
                )
                input_image = gr.Image(
                    label="Reference Image",
                    type="pil",
                    height=600,
                )
                with gr.Accordion("Negative Prompt", open=False):
                    negative_prompt = gr.TextArea(
                        label="Negative Prompt",
                        show_label=False,
                        value=_DEFAULT_NEGATIVE_PROMPT,
                        interactive=True,
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=896,
                        interactive=True,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=1152,
                        interactive=True,
                    )
                with gr.Accordion("Advanced options", open=False):
                    num_images = gr.Slider(
                        label="Number of images to generate",
                        minimum=1,
                        maximum=_MAX_NUM_IMAGES,
                        step=1,
                        # Clamp so a misconfigured INITIAL_BATCH_SIZE cannot put
                        # the slider outside its own range.
                        value=min(max(INITIAL_BATCH_SIZE, 1), _MAX_NUM_IMAGES),
                        interactive=True,
                    )
                    with gr.Row():
                        # Seed is both an input and an output: on_generate
                        # returns the seed it actually used.
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=2147483647,
                            step=1,
                            value=0,
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                            interactive=True,
                            scale=1,
                        )
                    steps = gr.Slider(
                        label="Inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=25,
                        interactive=True,
                    )
                    cfg_scale = gr.Slider(
                        label="CFG scale",
                        minimum=3.0,
                        maximum=8.0,
                        step=0.5,
                        value=5.0,
                        interactive=True,
                    )
            # Right column: trigger and results.
            with gr.Column():
                generate_button = gr.Button(
                    "Generate",
                    variant="primary",
                )
                output_image = gr.Gallery(
                    label="Generated images",
                    type="pil",
                    rows=2,
                    height="768px",
                    preview=True,
                    show_label=True,
                )
                # Hidden component used only as the fifth column of EXAMPLES.
                comment = gr.Markdown(
                    label="Comment",
                    visible=False,
                )
        gr.Examples(
            examples=EXAMPLES,
            inputs=[input_image, prompt, width, height, comment],
            cache_examples=False,
        )
        # The input order must match on_generate's positional parameters.
        gr.on(
            triggers=[generate_button.click],
            fn=on_generate,
            inputs=[
                prompt,
                negative_prompt,
                input_image,
                width,
                height,
                steps,
                cfg_scale,
                seed,
                randomize_seed,
                num_images,
            ],
            outputs=[output_image, seed],
        )
    return demo


def main():
    """Build the demo UI and start the Gradio server."""
    _build_demo().launch()


if __name__ == "__main__":
    main()