# Spaces:
# Runtime error
# Runtime error
#!/usr/bin/env python
# coding: utf-8
"""Latent diffusion text-to-image demo with seed reuse.

The app generates a grid of images for a text prompt. The user can pick one
grid image "as seed"; later prompts are then generated from that image's
random seed, so prompt variations can be compared from the same starting state.
"""
import random
from collections import defaultdict
from functools import partial
from itertools import zip_longest
from typing import List

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL import Image

# Text of the single choice offered by each per-image Radio button.
SELECT_LABEL = "Select as seed"
# Hugging Face Hub id of the pretrained text-to-image latent diffusion pipeline.
MODEL_ID = "CompVis/ldm-text2im-large-256"
# Sampling parameters passed to the pipeline on every call
# (``num_inference_steps``, ``eta`` and ``guidance_scale`` respectively).
STEPS = 50
ETA = 0.3
GUIDANCE_SCALE = 6

# Loaded once at startup and shared by every request.
ldm = DiffusionPipeline.from_pretrained(MODEL_ID)

# Report the available accelerator. ``torch.cuda.get_device_name()`` raises
# when no CUDA device is present, so it is only queried when CUDA is actually
# available; otherwise the app would fail at import time on CPU-only hardware.
print(f"cuda: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"device: {torch.cuda.get_device_name()}")
# Introductory text shown at the top of the page.
INTRO_MARKDOWN = """
<h1><center>Latent Diffusion Demo</center></h1>
<p>Type anything to generate a few images that represent your prompt.
Select one of the results to use as a <b>seed</b> for the next generation:
you can try variations of your prompt starting from the same state and see how it changes.
For example, <i>Labrador in the style of Vermeer</i> could be tweaked to
<i>Labrador in the style of Picasso</i> or <i>Lynx in the style of Van Gogh</i>.
If your prompts are similar, the tweaked result should also have a similar structure
but different details or style.</p>
"""


def infer_seeded_image(prompt, seed):
    """Generate and return a single image for ``prompt`` using ``seed``."""
    print(f"Prompt: {prompt}, seed: {seed}")
    images, _ = infer_grid(prompt, n=1, seeds=[seed])
    return images[0]


def infer_grid(prompt, n=6, seeds=None):
    """Generate ``n`` images for ``prompt``, one pipeline call per image.

    Args:
        prompt: Text prompt passed to the pipeline.
        n: Number of images to generate.
        seeds: Optional seeds to use, in order. At most the first ``n`` are
            used; missing or ``None`` entries are replaced by a fresh random
            seed.

    Returns:
        A ``(images, seeds)`` pair of lists of length ``n`` (or empty if
        ``n <= 0``); ``seeds[i]`` is the seed that produced ``images[i]``.
    """
    # Unfortunately we have to iterate instead of requesting all images at
    # once: a batched call gives no way to recover each image's seed, which
    # is needed so that a result can later be selected as a seed.
    requested = [] if seeds is None else list(seeds)[:n]
    requested += [None] * (n - len(requested))

    images, used_seeds = [], []
    for seed in requested:
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        print(f"Setting seed {seed}")
        torch.manual_seed(seed)
        samples = ldm(
            [prompt],
            num_inference_steps=STEPS,
            eta=ETA,
            guidance_scale=GUIDANCE_SCALE,
        )["sample"]
        images.append(samples[0])
        used_seeds.append(seed)
    return images, used_seeds


def infer(prompt, state):
    """Run generation for ``prompt`` in seeded mode or grid mode.

    If ``state["selected"]`` is a grid index (> -1), one image is generated
    from that entry of ``state["seeds"]``. Otherwise a new grid of 6 images is
    generated and their seeds are stored in ``state["seeds"]``.

    Returns, in order:
        - 6 grid images (all ``None`` in seeded mode),
        - the seeded image (``None`` in grid mode),
        - visibility updates for the grid box and the seeded box,
        - the (possibly updated) state.
    """
    if (seed_index := state["selected"]) > -1:
        seed = state["seeds"][seed_index]
        grid_images = [None] * 6
        image_with_seed = infer_seeded_image(prompt, seed)
        visible = (False, True)
    else:
        grid_images, seeds = infer_grid(prompt)
        state["seeds"] = seeds
        image_with_seed = None
        visible = (True, False)
    box_updates = [gr.Box.update(visible=v) for v in visible]
    return grid_images + [image_with_seed] + box_updates + [state]


def update_state(selected_index: int, value, state):
    """Handle a value change of the Radio at grid position ``selected_index``.

    Returns updates for the 5 other Radios followed by ``state``.
    """
    if value == SELECT_LABEL:
        # The user picked this image: remember it and clear the other Radios
        # so that at most one image is selected at a time.
        state["selected"] = selected_index
        others_value = ''
    else:
        # The Radio was changed programmatically (cleared by another Radio or
        # by "Return to Grid"), so the current selection must not change.
        # NOTE(review): ``value=None`` is presumably a no-op update in the
        # Gradio version this targets (original behavior) — confirm.
        others_value = None
    others = gr.Radio.update(value=others_value)
    return [others] * 5 + [state]


def clear_seed(state):
    """Leave seeded mode: deselect all Radios, show the grid, hide the seeded box.

    Returns 6 Radio values, the grid/seeded box visibility updates, and ``state``.
    """
    state["selected"] = -1
    return [''] * 6 + [gr.Box.update(visible=True), gr.Box.update(visible=False)] + [state]


def image_block():
    """Create the non-interactive, label-less Image used for every result."""
    return gr.Image(interactive=False, show_label=False).style(
        rounded=(True, True, False, False),
    )


def radio_block():
    """Create the single-choice "Select as seed" Radio shown under an image."""
    return gr.Radio(
        choices=[SELECT_LABEL], interactive=True, show_label=False,
    ).style(container=False)


with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    # App state: ``selected`` is the grid index chosen as seed (-1 = none,
    # i.e. grid mode); ``seeds`` holds the seed of each of the 6 grid images.
    state = gr.Variable({
        'selected': -1,
        'seeds': [random.randint(0, 2 ** 32 - 1) for _ in range(6)],
    })

    gr.Markdown(INTRO_MARKDOWN)

    # NOTE(review): layout nesting was reconstructed from a copy whose
    # indentation was lost; confirm the grid/seeded boxes belong in the Group.
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        # Open question: could an image + radio pair become one Component so
        # it can participate directly as an output?
        with (grid := gr.Box()):
            with gr.Row():
                with gr.Box().style(border=None):
                    image1 = image_block()
                    select1 = radio_block()
                with gr.Box().style(border=None):
                    image2 = image_block()
                    select2 = radio_block()
                with gr.Box().style(border=None):
                    image3 = image_block()
                    select3 = radio_block()
            with gr.Row():
                with gr.Box().style(border=None):
                    image4 = image_block()
                    select4 = radio_block()
                with gr.Box().style(border=None):
                    image5 = image_block()
                    select5 = radio_block()
                with gr.Box().style(border=None):
                    image6 = image_block()
                    select6 = radio_block()

        images = [image1, image2, image3, image4, image5, image6]
        selectors = [select1, select2, select3, select4, select5, select6]
        for i, radio in enumerate(selectors):
            # Compare by identity: all six Radios share the same configuration.
            others = [s for s in selectors if s is not radio]
            radio.change(
                partial(update_state, i),
                inputs=[radio, state],
                outputs=others + [state],
            )

        with (seeded_box := gr.Box()):
            seeded_image = image_block()
            clear_seed_button = gr.Button("Return to Grid")
        # Hidden until ``infer`` runs in seeded mode.
        seeded_box.visible = False

        clear_seed_button.click(
            clear_seed,
            inputs=[state],
            outputs=selectors + [grid, seeded_box] + [state],
        )

    all_images = images + [seeded_image]
    boxes = [grid, seeded_box]
    infer_outputs = all_images + boxes + [state]
    text.submit(infer, inputs=[text, state], outputs=infer_outputs)
    btn.click(infer, inputs=[text, state], outputs=infer_outputs)
if __name__ == "__main__":
    # Launch with Gradio's request queue enabled (gradio 3.x ``enable_queue``
    # flag); each request runs a multi-step diffusion pipeline.
    demo.launch(enable_queue=True)