Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
-
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline,
|
| 5 |
from PIL import Image
|
| 6 |
import os
|
| 7 |
import time
|
|
@@ -10,6 +10,7 @@ from utils.dl_utils import dl_cn_model, dl_cn_config, dl_lora_model
|
|
| 10 |
from utils.image_utils import resize_image_aspect_ratio, base_generation
|
| 11 |
from utils.prompt_utils import remove_duplicates
|
| 12 |
|
|
|
|
| 13 |
path = os.getcwd()
|
| 14 |
cn_dir = f"{path}/controlnet"
|
| 15 |
lora_dir = f"{path}/lora"
|
|
@@ -20,6 +21,7 @@ dl_cn_model(cn_dir)
|
|
| 20 |
dl_cn_config(cn_dir)
|
| 21 |
dl_lora_model(lora_dir)
|
| 22 |
|
|
|
|
| 23 |
def load_model(lora_dir, cn_dir):
|
| 24 |
dtype = torch.float16
|
| 25 |
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
|
@@ -32,6 +34,7 @@ def load_model(lora_dir, cn_dir):
|
|
| 32 |
pipe.load_lora_weights(lora_dir, weight_name="Fixhands_anime_bdsqlsz_V1.safetensors")
|
| 33 |
return pipe
|
| 34 |
|
|
|
|
| 35 |
@spaces.GPU(duration=120)
|
| 36 |
def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
|
| 37 |
pipe = load_model(lora_dir, cn_dir)
|
|
@@ -50,7 +53,7 @@ def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
|
|
| 50 |
control_image=resize_image,
|
| 51 |
strength=1.0,
|
| 52 |
prompt=prompt,
|
| 53 |
-
negative_prompt
|
| 54 |
controlnet_conditioning_scale=float(controlnet_scale),
|
| 55 |
generator=generator,
|
| 56 |
num_inference_steps=30,
|
|
@@ -67,7 +70,6 @@ class Img2Img:
|
|
| 67 |
self.input_image_path = None
|
| 68 |
self.canny_image = None
|
| 69 |
|
| 70 |
-
|
| 71 |
def layout(self):
|
| 72 |
css = """
|
| 73 |
#intro{
|
|
@@ -78,14 +80,15 @@ class Img2Img:
|
|
| 78 |
"""
|
| 79 |
with gr.Blocks(css=css) as demo:
|
| 80 |
with gr.Row():
|
| 81 |
-
with gr.Column():
|
| 82 |
-
|
| 83 |
-
self.
|
| 84 |
-
self.
|
| 85 |
-
self.
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
| 89 |
|
| 90 |
generate_button.click(
|
| 91 |
fn=predict,
|
|
|
|
| 1 |
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
+
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
|
| 5 |
from PIL import Image
|
| 6 |
import os
|
| 7 |
import time
|
|
|
|
| 10 |
from utils.image_utils import resize_image_aspect_ratio, base_generation
|
| 11 |
from utils.prompt_utils import remove_duplicates
|
| 12 |
|
| 13 |
+
# Setup directories and download necessary models
|
| 14 |
path = os.getcwd()
|
| 15 |
cn_dir = f"{path}/controlnet"
|
| 16 |
lora_dir = f"{path}/lora"
|
|
|
|
| 21 |
dl_cn_config(cn_dir)
|
| 22 |
dl_lora_model(lora_dir)
|
| 23 |
|
| 24 |
+
# Model loading function
|
| 25 |
def load_model(lora_dir, cn_dir):
|
| 26 |
dtype = torch.float16
|
| 27 |
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
|
|
|
| 34 |
pipe.load_lora_weights(lora_dir, weight_name="Fixhands_anime_bdsqlsz_V1.safetensors")
|
| 35 |
return pipe
|
| 36 |
|
| 37 |
+
# Image prediction and processing function
|
| 38 |
@spaces.GPU(duration=120)
|
| 39 |
def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
|
| 40 |
pipe = load_model(lora_dir, cn_dir)
|
|
|
|
| 53 |
control_image=resize_image,
|
| 54 |
strength=1.0,
|
| 55 |
prompt=prompt,
|
| 56 |
+
negative_prompt=negative_prompt,
|
| 57 |
controlnet_conditioning_scale=float(controlnet_scale),
|
| 58 |
generator=generator,
|
| 59 |
num_inference_steps=30,
|
|
|
|
| 70 |
self.input_image_path = None
|
| 71 |
self.canny_image = None
|
| 72 |
|
|
|
|
| 73 |
def layout(self):
|
| 74 |
css = """
|
| 75 |
#intro{
|
|
|
|
| 80 |
"""
|
| 81 |
with gr.Blocks(css=css) as demo:
|
| 82 |
with gr.Row():
|
| 83 |
+
with gr.Column(scale=1):
|
| 84 |
+
gr.Markdown("### Stickman to Posing Doll Image Converter\nこのアプリは棒人間をポーズ人形画像に変換するアプリです。\n入力する棒人間の形状は以下のリンクを参考にしてください。\n[VRoid Hub Character Example](https://hub.vroid.com/characters/4765753841994800453/models/6738034259079048708)\nIf your stick figure resembles the linked shape, it should work reasonably well even if hand-drawn.")
|
| 85 |
+
self.input_image_path = gr.Image(label="Input Image", type='filepath')
|
| 86 |
+
self.prompt = gr.Textbox(label="Prompt", lines=3)
|
| 87 |
+
self.negative_prompt = gr.Textbox(label="Negative Prompt", lines=3, value="nsfw, nipples, bad anatomy, liquid fingers, low quality, worst quality, out of focus, ugly, error, jpeg artifacts, lowers, blurry, bokeh")
|
| 88 |
+
self.controlnet_scale = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.01, label="Controlnet Scale")
|
| 89 |
+
generate_button = gr.Button("Generate")
|
| 90 |
+
with gr.Column(scale=1):
|
| 91 |
+
self.output_image = gr.Image(type="pil", label="Output Image")
|
| 92 |
|
| 93 |
generate_button.click(
|
| 94 |
fn=predict,
|