Commit 7302472
Parent(s): 459a0bd
[Safety Checker] Add Safety Checker Module
Former-commit-id: d0c714ae4afa1c011269a956d6f260f84f77025e
- scripts/txt2img.py +24 -1
scripts/txt2img.py
CHANGED
@@ -16,12 +16,29 @@ from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+
+feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True)
+safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True)
 
 def chunk(it, size):
     it = iter(it)
     return iter(lambda: tuple(islice(it, size)), ())
 
 
+def numpy_to_pil(images):
+    """
+    Convert a numpy image or a batch of images to a PIL image.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    pil_images = [Image.fromarray(image) for image in images]
+
+    return pil_images
+
+
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
@@ -220,7 +237,9 @@ def main():
     if opt.fixed_code:
         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
 
+    print("start code", start_code.abs().sum())
     precision_scope = autocast if opt.precision=="autocast" else nullcontext
+    precision_scope = nullcontext
     with torch.no_grad():
         with precision_scope("cuda"):
             with model.ema_scope():
@@ -269,7 +288,11 @@ def main():
                     Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                     grid_count += 1
 
-                toc = time.time()
+                image = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+
+                # run safety checker
+                safety_checker_input = pipe.feature_extractor(numpy_to_pil(image), return_tensors="pt")
+                image, has_nsfw_concept = pipe.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
 
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f" \nEnjoy.")
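
Note on the new hunk in main(): it calls pipe.feature_extractor and pipe.safety_checker, but no pipe object is defined in this script; the objects loaded at module level are named feature_extractor and safety_checker. A minimal standalone sketch of the intended check, assuming the module-level objects added by this commit and a NumPy batch of decoded samples in [0, 1] (the run_safety_checker name and the final comment are illustrative, not part of the commit):

import numpy as np
from PIL import Image
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# Same (gated) checkpoint id the commit uses; requires an authenticated Hugging Face login.
# Depending on library versions, these weights may instead need to be loaded with
# subfolder="safety_checker" or from a standalone safety-checker checkpoint.
model_id = "CompVis/stable-diffusion-v-1-3"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, use_auth_token=True)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_id, use_auth_token=True)


def numpy_to_pil(images):
    """Convert an (N, H, W, 3) float array in [0, 1] to a list of PIL images."""
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]


def run_safety_checker(image):
    """Return the (possibly blacked-out) images and a per-image NSFW flag."""
    # CLIP-style preprocessing of the decoded samples for the checker.
    safety_checker_input = feature_extractor(numpy_to_pil(image), return_tensors="pt")
    checked_image, has_nsfw_concept = safety_checker(
        images=image, clip_input=safety_checker_input.pixel_values
    )
    return checked_image, has_nsfw_concept


# In txt2img.py this would replace the pipe.* calls, with
# image = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() as the input batch.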