Spaces:
Runtime error
Runtime error
| import os, random, re, torch | |
| from typing import List, Tuple | |
| from PIL import Image, ImageDraw, ImageFont | |
| import gradio as gr | |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
# --------------------
# Config
# --------------------
# Hugging Face Hub repo id of the Stable Diffusion 1.5 weights; override with
# the MODEL_ID environment variable.
# NOTE: the old default "runwayml/stable-diffusion-v1-5" no longer exists on
# the Hub, so loading it fails at startup; the community-maintained mirror
# below hosts the same SD 1.5 checkpoint.
MODEL_ID = os.getenv("MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
# Use the GPU whenever torch can see one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU only; the CPU path loads the model in float32.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# Simple prompt guardrail (blocks obvious NSFW attempts).
# Case-insensitive keyword filter: each entry is a regex, joined into one
# alternation below. Most terms are anchored with \b on both sides so they
# only match as whole words (e.g. "bass" does not trigger "ass").
NSFW_TERMS = [
    r"\bnsfw\b", r"\bnude\b", r"\bnudity\b", r"\bsex\b", r"\bexplicit\b", r"\bporn\b",
    r"\bboobs\b", r"\bbutt\b", r"\bass\b", r"\bnaked\b", r"\btits\b",
    # "+" is not a word character, so a trailing \b here would require a
    # letter/digit right after the "+" and miss "18+ " or a final "18+".
    r"\b18\+", r"\berotic\b", r"\bfetish\b",
]
NSFW_REGEX = re.compile("|".join(NSFW_TERMS), flags=re.IGNORECASE)
# --------------------
# Load pipeline
# --------------------
# Load the MODEL_ID weights in the precision selected by DTYPE.
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)

# Swap in DPMSolverMultistepScheduler, built from the pipeline's original
# scheduler config so the shared settings carry over.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# DEVICE is "cuda" exactly when torch reports an available GPU, else "cpu".
pipe = pipe.to(DEVICE)
if DEVICE == "cuda":
    # diffusers memory-saving options, only turned on for the GPU path.
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
# --------------------
# Helpers
# --------------------
def blocked_tile(reason: str, width=512, height=512) -> Image.Image:
    """Render a dark placeholder image that announces blocked output.

    Args:
        reason: Short explanation drawn on the line below "BLOCKED".
        width: Tile width in pixels.
        height: Tile height in pixels.

    Returns:
        A new RGB image of size (width, height) with white, centered text.
    """
    img = Image.new("RGB", (width, height), (20, 20, 24))
    draw = ImageDraw.Draw(img)
    text = f"BLOCKED\n{reason}"
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 28)
    except (OSError, ImportError):
        # Font file not found (OSError) or Pillow built without FreeType
        # (ImportError): fall back to Pillow's built-in default font.
        font = ImageFont.load_default()
    # The bbox can start at a non-zero (left, top) offset caused by font
    # bearings, so measure the real extent and subtract that offset when
    # centering instead of treating (right, bottom) as the size.
    left, top, right, bottom = draw.multiline_textbbox(
        (0, 0), text, font=font, align="center"
    )
    x = (width - (right - left)) // 2 - left
    y = (height - (bottom - top)) // 2 - top
    draw.multiline_text((x, y), text, fill=(255, 255, 255), font=font, align="center")
    return img
def is_prompt_nsfw(prompt: str) -> bool:
    """Return True if any NSFW_REGEX pattern occurs in `prompt`.

    A None or empty prompt is treated as clean text.
    """
    text = prompt if prompt else ""
    return NSFW_REGEX.search(text) is not None
def generate(
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: int,
    batch_size: int
) -> Tuple[List[Image.Image], str]:
    """Run the Stable Diffusion pipeline for one UI request.

    Args:
        prompt: Text prompt; None or whitespace-only returns no images.
        negative_prompt: Optional negative prompt; empty/None is passed as None.
        steps: Number of inference steps.
        guidance: Guidance scale (the "CFG" value).
        width: Output width in pixels.
        height: Output height in pixels.
        seed: RNG seed; a negative value or None picks a random seed.
        batch_size: Number of images to generate for the prompt.

    Returns:
        ``(images, message)``. If either prompt matches the keyword filter, a
        single default-size "BLOCKED" tile is returned instead of generating.
        Images reported in the pipeline's ``nsfw_content_detected`` output are
        replaced by "BLOCKED" tiles of the same size. ``message`` reports the
        seed, the image count and, when flags are available, the flagged count.
    """
    prompt = prompt or ""
    if not prompt.strip():
        return [], "Add a prompt to get rolling."
    # Hard block obvious NSFW prompts before hitting the model.
    # NOTE(review): this also blocks terms placed in the *negative* prompt,
    # where they steer generation away from that content — confirm intended.
    if is_prompt_nsfw(prompt) or is_prompt_nsfw(negative_prompt or ""):
        img = blocked_tile("NSFW prompt detected")
        return [img], "Blocked: NSFW prompt."
    # UI widgets may deliver floats or None (e.g. a cleared seed box); the
    # pipeline and torch.Generator need plain ints.
    steps = int(steps)
    width = int(width)
    height = int(height)
    batch_size = int(batch_size)
    seed = -1 if seed is None else int(seed)
    if seed < 0:
        seed = random.randint(0, 2**31 - 1)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        num_inference_steps=steps,
        guidance_scale=guidance,
        width=width,
        height=height,
        num_images_per_prompt=batch_size,
        generator=generator
    )
    images = list(out.images)
    # None when the pipeline reports no flags; otherwise one entry per image.
    flags = getattr(out, "nsfw_content_detected", None)
    # If the underlying safety checker flags NSFW, block it (no blur), keeping
    # the tile the same size as the image it replaces.
    if flags:
        for i, flagged in enumerate(flags):
            if flagged:
                tile_w, tile_h = images[i].size
                images[i] = blocked_tile("NSFW content flagged", tile_w, tile_h)
    msg = f"Seed: {seed} • Images: {len(images)}"
    if flags is not None:
        msg += f" • Flagged: {sum(1 for f in flags if f)}"
    return images, msg
# --------------------
# UI
# --------------------
# Gradio front end: a header, then a Row holding a controls column (scale=3)
# and an output column (scale=5). The Generate button calls generate().
with gr.Blocks(title="VibeForge — Clean Image Generator") as demo:
    # Header text shown at the top of the page.
    gr.Markdown(
"""
# VibeForge ⚒️
**Clean, creative image generation.**
NSFW inputs are blocked. Keep it classy and go wild on style, lighting, composition, mood.
"""
    )
    with gr.Row():
        # Controls column.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="a cinematic photo of a vintage motorcycle by the ocean at sunset, golden hour, soft rim light, 50mm"
            )
            negative = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, watermark, jpeg artifacts")
            with gr.Row():
                # Step count 10-50 (default 28) and guidance scale 1.0-12.0 (default 7.5).
                steps = gr.Slider(10, 50, value=28, step=1, label="Steps")
                guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.1, label="CFG")
            with gr.Row():
                # Fixed size choices; all are multiples of 64.
                width = gr.Dropdown(choices=[384, 448, 512, 640, 768], value=512, label="Width")
                height = gr.Dropdown(choices=[384, 448, 512, 640, 768], value=512, label="Height")
            with gr.Row():
                # generate() replaces any negative seed with a random one.
                seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                batch = gr.Slider(1, 4, value=1, step=1, label="Batch")
            go = gr.Button("Generate", variant="primary")
        # Output column.
        with gr.Column(scale=5):
            gallery = gr.Gallery(label="Output", columns=2, height=512)
            # Status line filled with generate()'s message string.
            info = gr.Markdown()
    # Inputs are listed in the same order as generate()'s parameters
    # (prompt, negative_prompt, steps, guidance, width, height, seed,
    # batch_size); its (images, message) result fills gallery and info.
    go.click(
        fn=generate,
        inputs=[prompt, negative, steps, guidance, width, height, seed, batch],
        outputs=[gallery, info]
    )

if __name__ == "__main__":
    # Launch the Gradio app when run as a script.
    demo.launch()