rickveloper commited on
Commit
1ed16aa
·
verified ·
1 Parent(s): 6781bff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -64
app.py CHANGED
@@ -4,55 +4,64 @@ from PIL import Image, ImageDraw, ImageFont
4
  import gradio as gr
5
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
6
 
7
- # --------------------
8
- # Config
9
- # --------------------
10
- MODEL_ID = os.getenv("MODEL_ID", "runwayml/stable-diffusion-v1-5")
 
 
 
11
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
  DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
13
 
14
- # Simple prompt guardrail (blocks obvious NSFW attempts)
15
  NSFW_TERMS = [
16
  r"\bnsfw\b", r"\bnude\b", r"\bnudity\b", r"\bsex\b", r"\bexplicit\b", r"\bporn\b",
17
- r"\bboobs\b", r"\bbutt\b", r"\bass\b", r"\bnsfw\b", r"\bnaked\b", r"\btits\b",
18
  r"\b18\+\b", r"\berotic\b", r"\bfetish\b"
19
  ]
20
  NSFW_REGEX = re.compile("|".join(NSFW_TERMS), flags=re.IGNORECASE)
21
 
22
- # --------------------
23
- # Load pipeline
24
- # --------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  pipe = StableDiffusionPipeline.from_pretrained(
26
  MODEL_ID,
27
- torch_dtype=DTYPE
 
28
  )
 
 
29
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
30
 
31
- if torch.cuda.is_available():
32
  pipe = pipe.to("cuda")
33
  pipe.enable_attention_slicing()
34
  pipe.enable_vae_slicing()
35
  else:
36
  pipe = pipe.to("cpu")
37
 
38
- # --------------------
39
- # Helpers
40
- # --------------------
41
- def blocked_tile(reason: str, width=512, height=512) -> Image.Image:
42
- img = Image.new("RGB", (width, height), (20, 20, 24))
43
- draw = ImageDraw.Draw(img)
44
- text = f"BLOCKED\n{reason}"
45
- try:
46
- font = ImageFont.truetype("DejaVuSans-Bold.ttf", 28)
47
- except:
48
- font = ImageFont.load_default()
49
- tw, th = draw.multiline_textbbox((0,0), text, font=font)[2:]
50
- draw.multiline_text(((width - tw)//2, (height - th)//2), text, fill=(255,255,255), font=font, align="center")
51
- return img
52
-
53
- def is_prompt_nsfw(prompt: str) -> bool:
54
- return bool(NSFW_REGEX.search(prompt or ""))
55
-
56
  def generate(
57
  prompt: str,
58
  negative_prompt: str,
@@ -64,12 +73,16 @@ def generate(
64
  batch_size: int
65
  ) -> Tuple[List[Image.Image], str]:
66
  if not prompt.strip():
67
- return [], "Add a prompt to get rolling."
 
 
 
 
68
 
69
- # Hard block obvious NSFW prompts before hitting the model
70
- if is_prompt_nsfw(prompt) or is_prompt_nsfw(negative_prompt or ""):
71
- img = blocked_tile("NSFW prompt detected")
72
- return [img], "Blocked: NSFW prompt."
73
 
74
  # Seed
75
  if seed < 0:
@@ -78,7 +91,7 @@ def generate(
78
 
79
  out = pipe(
80
  prompt=prompt,
81
- negative_prompt=negative_prompt or None,
82
  num_inference_steps=steps,
83
  guidance_scale=guidance,
84
  width=width,
@@ -87,29 +100,20 @@ def generate(
87
  generator=generator
88
  )
89
 
90
- images = out.images
91
- flags = getattr(out, "nsfw_content_detected", None)
92
-
93
- # If the underlying safety checker flags NSFW, block it (no blur)
94
- if flags:
95
- for i, flagged in enumerate(flags):
96
- if flagged:
97
- images[i] = blocked_tile("NSFW content flagged")
98
-
99
- msg = f"Seed: {seed} • Images: {len(images)}"
100
- if flags is not None:
101
- msg += f" • Flagged: {sum(1 for f in flags if f)}"
102
- return images, msg
103
 
104
- # --------------------
105
- # UI
106
- # --------------------
107
- with gr.Blocks(title="VibeForge — Clean Image Generator") as demo:
108
  gr.Markdown(
109
  """
110
- # VibeForge ⚒️
111
- **Clean, creative image generation.**
112
- NSFW inputs are blocked. Keep it classy and go wild on style, lighting, composition, mood.
113
  """
114
  )
115
 
@@ -117,23 +121,23 @@ NSFW inputs are blocked. Keep it classy and go wild on style, lighting, composit
117
  with gr.Column(scale=3):
118
  prompt = gr.Textbox(
119
  label="Prompt",
120
- placeholder="a cinematic photo of a vintage motorcycle by the ocean at sunset, golden hour, soft rim light, 50mm"
121
  )
122
- negative = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, watermark, jpeg artifacts")
123
  with gr.Row():
124
- steps = gr.Slider(10, 50, value=28, step=1, label="Steps")
125
- guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.1, label="CFG")
126
  with gr.Row():
127
- width = gr.Dropdown(choices=[384, 448, 512, 640, 768], value=512, label="Width")
128
- height = gr.Dropdown(choices=[384, 448, 512, 640, 768], value=512, label="Height")
129
  with gr.Row():
130
  seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
131
- batch = gr.Slider(1, 4, value=1, step=1, label="Batch")
132
 
133
  go = gr.Button("Generate", variant="primary")
134
 
135
  with gr.Column(scale=5):
136
- gallery = gr.Gallery(label="Output", columns=2, height=512)
137
  info = gr.Markdown()
138
 
139
  go.click(
@@ -143,4 +147,4 @@ NSFW inputs are blocked. Keep it classy and go wild on style, lighting, composit
143
  )
144
 
145
  if __name__ == "__main__":
146
- demo.launch()
 
4
  import gradio as gr
5
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
6
 
7
+ # =========================
8
+ # SPEED PRESET
9
+ # =========================
10
+ # Use SD Turbo (1.5) – optimized for very few steps on CPU
11
+ DEFAULT_MODEL_ID = "stabilityai/sd-turbo"
12
+ MODEL_ID = os.getenv("MODEL_ID", DEFAULT_MODEL_ID)
13
+
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
  DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
16
 
17
+ # Short NSFW guardrail (block, no blur)
18
  NSFW_TERMS = [
19
  r"\bnsfw\b", r"\bnude\b", r"\bnudity\b", r"\bsex\b", r"\bexplicit\b", r"\bporn\b",
20
+ r"\bboobs\b", r"\bbutt\b", r"\bass\b", r"\bnaked\b", r"\btits\b",
21
  r"\b18\+\b", r"\berotic\b", r"\bfetish\b"
22
  ]
23
  NSFW_REGEX = re.compile("|".join(NSFW_TERMS), flags=re.IGNORECASE)
24
 
25
+ def _blocked_tile(reason: str, w=384, h=384) -> Image.Image:
26
+ img = Image.new("RGB", (w, h), (18, 20, 26))
27
+ d = ImageDraw.Draw(img)
28
+ text = f"BLOCKED\n{reason}"
29
+ try:
30
+ font = ImageFont.truetype("DejaVuSans-Bold.ttf", 26)
31
+ except:
32
+ font = ImageFont.load_default()
33
+ box = d.multiline_textbbox((0,0), text, font=font, align="center")
34
+ tw, th = box[2]-box[0], box[3]-box[1]
35
+ d.multiline_text(((w-tw)//2, (h-th)//2), text, font=font, fill=(255,255,255), align="center")
36
+ return img
37
+
38
+ def _is_nsfw(s: str) -> bool:
39
+ return bool(NSFW_REGEX.search(s or ""))
40
+
41
+ # -------------------------
42
+ # Load pipeline (fast path)
43
+ # -------------------------
44
+ torch.set_grad_enabled(False)
45
+
46
  pipe = StableDiffusionPipeline.from_pretrained(
47
  MODEL_ID,
48
+ torch_dtype=DTYPE,
49
+ safety_checker=None # let model config handle; we block explicitly on prompts
50
  )
51
+
52
+ # Turbo still benefits from DPMSolver for CPU
53
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
54
 
55
+ if DEVICE == "cuda":
56
  pipe = pipe.to("cuda")
57
  pipe.enable_attention_slicing()
58
  pipe.enable_vae_slicing()
59
  else:
60
  pipe = pipe.to("cpu")
61
 
62
+ # -------------------------
63
+ # Generate fn (kept lean)
64
+ # -------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def generate(
66
  prompt: str,
67
  negative_prompt: str,
 
73
  batch_size: int
74
  ) -> Tuple[List[Image.Image], str]:
75
  if not prompt.strip():
76
+ return [], "Add a prompt first."
77
+
78
+ # block obvious NSFW prompts
79
+ if _is_nsfw(prompt) or _is_nsfw(negative_prompt or ""):
80
+ return [_blocked_tile("NSFW prompt detected", width, height)], "Blocked: NSFW prompt."
81
 
82
+ # SD-Turbo is designed for tiny step counts + low/zero CFG
83
+ # guard rails on parameters
84
+ steps = max(1, min(int(steps), 12))
85
+ guidance = max(0.0, min(float(guidance), 2.0))
86
 
87
  # Seed
88
  if seed < 0:
 
91
 
92
  out = pipe(
93
  prompt=prompt,
94
+ negative_prompt=(negative_prompt or None),
95
  num_inference_steps=steps,
96
  guidance_scale=guidance,
97
  width=width,
 
100
  generator=generator
101
  )
102
 
103
+ imgs = out.images
104
+ # Some sd-turbo configs may not return nsfw flags; we already block on prompt
105
+ msg = f"Model: {MODEL_ID} • Seed: {seed} • Steps: {steps} • CFG: {guidance} • {width}x{height} • Batch: {batch_size}"
106
+ return imgs, msg
 
 
 
 
 
 
 
 
 
107
 
108
+ # -------------------------
109
+ # UI (defaults tuned for CPU)
110
+ # -------------------------
111
+ with gr.Blocks(title="VibeForge — Fast (CPU-friendly) Image Gen") as demo:
112
  gr.Markdown(
113
  """
114
+ # VibeForge ⚒️
115
+ **Fast, clean image generation (CPU-friendly).**
116
+ Uses **SD-Turbo** tuned for low steps. NSFW inputs are blocked.
117
  """
118
  )
119
 
 
121
  with gr.Column(scale=3):
122
  prompt = gr.Textbox(
123
  label="Prompt",
124
+ placeholder="a neon-lit lighthouse on a stormy cliff at night, cinematic, volumetric fog, high contrast"
125
  )
126
+ negative = gr.Textbox(label="Negative Prompt", placeholder="low quality, watermark, overexposed")
127
  with gr.Row():
128
+ steps = gr.Slider(1, 12, value=4, step=1, label="Steps (SD-Turbo sweet spot: 2-6)")
129
+ guidance = gr.Slider(0.0, 2.0, value=0.5, step=0.1, label="CFG (SD-Turbo likes low)")
130
  with gr.Row():
131
+ width = gr.Dropdown(choices=[384, 448, 512], value=384, label="Width")
132
+ height = gr.Dropdown(choices=[384, 448, 512], value=384, label="Height")
133
  with gr.Row():
134
  seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
135
+ batch = gr.Slider(1, 2, value=1, step=1, label="Batch (keep small on CPU)")
136
 
137
  go = gr.Button("Generate", variant="primary")
138
 
139
  with gr.Column(scale=5):
140
+ gallery = gr.Gallery(label="Output", columns=2, height=448)
141
  info = gr.Markdown()
142
 
143
  go.click(
 
147
  )
148
 
149
  if __name__ == "__main__":
150
+ demo.launch()