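"""Img2Img batch generation extension for AI-Toolkit.

Loads a Stable Diffusion XL or PixArt Sigma model, runs every image in the
configured datasets through an img2img pass at the configured denoise
strength, and writes the generated images and their captions to an output
folder (optionally copying the inputs alongside).

A minimal config sketch, with keys inferred from the get_conf calls below;
values are illustrative only, not defaults from this file:

    output_folder: output/img2img
    copy_inputs_to: output/inputs  # optional
    device: cuda
    dtype: float16
    model:
      ...  # ModelConfig kwargs
    generate:
      sampler: ddpm
      sample_steps: 20
      guidance_scale: 7
      denoise_strength: 0.5
      ext: png
    datasets:
      - ...  # DatasetConfig kwargs
"""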
import gc
import os
import random
from collections import OrderedDict
from typing import List

import numpy as np
import torch
from PIL import Image
from diffusers import PixArtSigmaPipeline, StableDiffusionXLImg2ImgPipeline, T2IAdapter
from diffusers.utils.torch_utils import randn_tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from jobs.process import BaseExtensionProcess
from toolkit.config_modules import DatasetConfig, ModelConfig, preprocess_dataset_raw_config
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.sampler import get_sampler
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
def flush():
    """Free cached GPU memory and run garbage collection."""
    torch.cuda.empty_cache()
    gc.collect()
class GenerateConfig:
    # holds the `generate` section of the job config
    def __init__(self, **kwargs):
        self.prompts: List[str] = kwargs.get('prompts', [])
        self.sampler = kwargs.get('sampler', 'ddpm')
        self.neg = kwargs.get('neg', '')
        self.seed = kwargs.get('seed', -1)
        self.walk_seed = kwargs.get('walk_seed', False)
        self.guidance_scale = kwargs.get('guidance_scale', 7)
        self.sample_steps = kwargs.get('sample_steps', 20)
        self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
        self.ext = kwargs.get('ext', 'png')
        self.denoise_strength = kwargs.get('denoise_strength', 0.5)
        self.trigger_word = kwargs.get('trigger_word', None)
class Img2ImgGenerator(BaseExtensionProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
        self.device = self.get_conf('device', 'cuda')
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
        self.is_latents_cached = True
        raw_datasets = self.get_conf('datasets', None)
        if raw_datasets is not None and len(raw_datasets) > 0:
            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
        self.datasets = None
        self.datasets_reg = None
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.params = []
        if raw_datasets is not None and len(raw_datasets) > 0:
            for raw_dataset in raw_datasets:
                dataset = DatasetConfig(**raw_dataset)
                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
                if not is_caching:
                    self.is_latents_cached = False
                # split datasets into regularization and regular sets
                if dataset.is_reg:
                    if self.datasets_reg is None:
                        self.datasets_reg = []
                    self.datasets_reg.append(dataset)
                else:
                    if self.datasets is None:
                        self.datasets = []
                    self.datasets.append(dataset)
        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )
        print(f"Using device {self.device}")
        self.data_loader: DataLoader = None
        self.adapter: T2IAdapter = None
    def to_pil(self, img):
        # image comes in as a -1..1 tensor; convert to a PIL RGB image
        img = (img + 1) / 2
        img = img.clamp(0, 1)
        img = img[0].permute(1, 2, 0).cpu().numpy()
        img = (img * 255).astype(np.uint8)
        image = Image.fromarray(img)
        return image
    def run(self):
        with torch.no_grad():
            super().run()
            print("Loading model...")
            self.sd.load_model()
            device = torch.device(self.device)
            if self.model_config.is_xl:
                pipe = StableDiffusionXLImg2ImgPipeline(
                    vae=self.sd.vae,
                    unet=self.sd.unet,
                    text_encoder=self.sd.text_encoder[0],
                    text_encoder_2=self.sd.text_encoder[1],
                    tokenizer=self.sd.tokenizer[0],
                    tokenizer_2=self.sd.tokenizer[1],
                    scheduler=get_sampler(self.generate_config.sampler),
                ).to(device, dtype=self.torch_dtype)
            elif self.model_config.is_pixart:
                pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
            else:
                raise NotImplementedError("Only SDXL and PixArt models are supported")
            pipe.set_progress_bar_config(disable=True)
            # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
            num_batches = len(self.data_loader)
            pbar = tqdm(total=num_batches, desc="Generating images")
            seed = self.generate_config.seed
            # iterate over the dataset one image at a time
            for i, batch in enumerate(self.data_loader):
                batch: DataLoaderBatchDTO = batch
                gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
                generator = torch.manual_seed(gen_seed)
                file_item: FileItemDTO = batch.file_items[0]
                img_path = file_item.path
                img_filename = os.path.basename(img_path)
                img_filename_no_ext = os.path.splitext(img_filename)[0]
                img_filename = img_filename_no_ext + '.' + self.generate_config.ext
                output_path = os.path.join(self.output_folder, img_filename)
                output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
                if self.copy_inputs_to is not None:
                    output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
                    output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
                else:
                    output_inputs_path = None
                    output_inputs_caption_path = None
                caption = batch.get_caption_list()[0]
                if self.generate_config.trigger_word is not None:
                    caption = caption.replace('[trigger]', self.generate_config.trigger_word)
                img: torch.Tensor = batch.tensor.clone()
                image = self.to_pil(img)
                if self.model_config.is_pixart:
                    pipe: PixArtSigmaPipeline = pipe
                    # encode the full input image once
                    encoded_image = pipe.vae.encode(
                        pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
                    if hasattr(encoded_image, "latent_dist"):
                        latents = encoded_image.latent_dist.sample(generator)
                    elif hasattr(encoded_image, "latents"):
                        latents = encoded_image.latents
                    else:
                        raise AttributeError("Could not access latents of provided encoder_output")
                    latents = pipe.vae.config.scaling_factor * latents
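                    # img2img is emulated by hand here: noise the encoded latents up to the
                    # timestep implied by denoise_strength, then denoise only the remaining steps.
                    # For example, sample_steps=20 with denoise_strength=0.5 gives init_timestep=10
                    # and t_start=10, so only the last 10 scheduler timesteps are run.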
                    batch_size = 1
                    num_images_per_prompt = 1
                    shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
                             image.width // pipe.vae_scale_factor)
                    noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
                    num_inference_steps = self.generate_config.sample_steps
                    strength = self.generate_config.denoise_strength
                    # map denoise strength to a starting timestep
                    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
                    t_start = max(num_inference_steps - init_timestep, 0)
                    pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
                    timesteps = pipe.scheduler.timesteps[t_start:]
                    timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
                    latents = pipe.scheduler.add_noise(latents, noise, timestep)
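                    # Passing the pre-noised latents together with the truncated timesteps list
                    # starts the pipeline mid-trajectory, which is what makes this behave like
                    # img2img. use_resolution_binning=False should keep the output at the exact
                    # requested width/height instead of snapping to PixArt's binned resolutions.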
                    gen_images = pipe(
                        prompt=caption,
                        negative_prompt=self.generate_config.neg,
                        latents=latents,
                        timesteps=timesteps,
                        width=image.width,
                        height=image.height,
                        num_inference_steps=num_inference_steps,
                        num_images_per_prompt=num_images_per_prompt,
                        guidance_scale=self.generate_config.guidance_scale,
                        use_resolution_binning=False,
                        output_type="np"
                    ).images[0]
                    gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
                    gen_images = Image.fromarray(gen_images)
                else:
                    pipe: StableDiffusionXLImg2ImgPipeline = pipe
                    gen_images = pipe(
                        prompt=caption,
                        negative_prompt=self.generate_config.neg,
                        image=image,
                        num_inference_steps=self.generate_config.sample_steps,
                        guidance_scale=self.generate_config.guidance_scale,
                        strength=self.generate_config.denoise_strength,
                    ).images[0]
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                gen_images.save(output_path)
                # save the caption alongside the image
                with open(output_caption_path, 'w') as f:
                    f.write(caption)
                if output_inputs_path is not None:
                    os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
                    image.save(output_inputs_path)
                    with open(output_inputs_caption_path, 'w') as f:
                        f.write(caption)
                pbar.update(1)
                batch.cleanup()
            pbar.close()
            print("Done generating images")
            # cleanup
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()