# Wan2GP_Demo / models / flux / flux_main.py
# Committed by USF00 — 2b67076 ("Initial commit")
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
from mmgp import offload as offload
import torch
from shared.utils.utils import calculate_new_dimensions
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from transformers import SiglipVisionModel, SiglipImageProcessor
import torchvision.transforms.functional as TVF
import math
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
from shared.utils import files_locator as fl
from .util import (
aspect_ratio_to_height_width,
load_ae,
load_clip,
load_flow_model,
load_t5,
save_image,
)
from PIL import Image
def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
    """Resize a reference image so its long side is ``long_size``, then center-crop to multiples of 16.

    The aspect ratio is preserved by the resize: the shorter side is scaled
    proportionally and truncated to an int, with a floor of 16 px so that
    extremely elongated images never produce an empty crop. Both dimensions
    are then floored to a multiple of 16 by cropping equally from each side.

    Args:
        raw_image: Input PIL image, in any mode.
        long_size: Target length, in pixels, of the image's longer side.

    Returns:
        A new RGB PIL image whose width and height are multiples of 16.
    """
    # Convert to RGB *before* resizing: Pillow silently falls back to NEAREST
    # resampling for "P" (palette) and "1" (bilevel) images, which would
    # ignore the requested LANCZOS filter.
    raw_image = raw_image.convert("RGB")

    image_w, image_h = raw_image.size
    # Scale the longer side to long_size and the shorter side proportionally.
    # The shorter side is kept at >= 16 px: a 0 size makes resize() raise and
    # 1..15 px would be cropped down to nothing below.
    if image_w >= image_h:
        new_w = long_size
        new_h = max(16, int((long_size / image_w) * image_h))
    else:
        new_h = long_size
        new_w = max(16, int((long_size / image_h) * image_w))
    raw_image = raw_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)

    # Center-crop so that both dimensions are multiples of 16.
    target_w = new_w // 16 * 16
    target_h = new_h // 16 * 16
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    return raw_image.crop((left, top, left + target_w, top + target_h))
def stitch_images(img1, img2):
    """Concatenate two PIL images horizontally, ``img1`` on the left.

    ``img2`` is first rescaled with LANCZOS so that its height equals
    ``img1``'s height, keeping its aspect ratio (the new width is truncated
    to an int). The result is a new RGB canvas of size
    ``(img1.width + scaled_img2.width, img1.height)``.
    """
    left_w, target_h = img1.size
    src_w, src_h = img2.size
    scaled_w = int(src_w * target_h / src_h)
    right_tile = img2.resize((scaled_w, target_h), Image.Resampling.LANCZOS)

    canvas = Image.new('RGB', (left_w + scaled_w, target_h))
    for tile, x_offset in ((img1, 0), (right_tile, left_w)):
        canvas.paste(tile, (x_offset, 0))
    return canvas
class model_factory:
def __init__(
self,
checkpoint_dir,
model_filename = None,
model_type = None,
model_def = None,
base_model_type = None,
text_encoder_filename = None,
quantizeTransformer = False,
save_quantized = False,
dtype = torch.bfloat16,
VAE_dtype = torch.float32,
mixed_precision_transformer = False
):
self.device = torch.device(f"cuda")
self.VAE_dtype = VAE_dtype
self.dtype = dtype
torch_device = "cpu"
self.guidance_max_phases = model_def.get("guidance_max_phases", 0)
# model_filename = ["c:/temp/flux1-schnell.safetensors"]
self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
self.clip = load_clip(torch_device)
self.name = model_def.get("flux-model", "flux-dev")
# self.name= "flux-dev-kontext"
# self.name= "flux-dev"
# self.name= "flux-schnell"
source = model_def.get("source", None)
self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
self.model_def = model_def
self.vae = load_ae(self.name, device=torch_device)
siglip_processor = siglip_model = feature_embedder = None
if self.name == 'flux-dev-uso':
siglip_path = fl.locate_folder("siglip-so400m-patch14-384")
siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
siglip_model.eval().to("cpu")
if len(model_filename) > 1:
from .modules.layers import SigLIPMultiFeatProjModel
feature_embedder = SigLIPMultiFeatProjModel(
siglip_token_nums=729,
style_token_nums=64,
siglip_token_dims=1152,
hidden_size=3072, #self.hidden_size,
context_layer_norm=True,
)
offload.load_model_data(feature_embedder, model_filename[1])
self.vision_encoder = siglip_model
self.vision_encoder_processor = siglip_processor
self.feature_embedder = feature_embedder
# offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "flux-dev.safetensors")
if not source is None:
from wgp import save_model
save_model(self.model, model_type, dtype, None)
if save_quantized:
from wgp import save_quantized_model
save_quantized_model(self.model, model_type, model_filename[0], dtype, None)
split_linear_modules_map = get_linear_split_map()
self.model.split_linear_modules_map = split_linear_modules_map
offload.split_linear_modules(self.model, split_linear_modules_map )
def generate(
self,
seed: int | None = None,
input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
n_prompt: str = None,
sampling_steps: int = 20,
input_ref_images = None,
input_frames= None,
input_masks= None,
width= 832,
height=480,
embedded_guidance_scale: float = 2.5,
guide_scale = 2.5,
fit_into_canvas = None,
callback = None,
loras_slists = None,
batch_size = 1,
video_prompt_type = "",
joint_pass = False,
image_refs_relative_size = 100,
denoising_strength = 1.,
**bbargs
):
if self._interrupt:
return None
if self.guidance_max_phases < 1: guide_scale = 1
if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
device="cuda"
flux_dev_uso = self.name in ['flux-dev-uso']
flux_dev_umo = self.name in ['flux-dev-umo']
latent_stiching = self.name in ['flux-dev-uso', 'flux-dev-umo']
lock_dimensions= False
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
if flux_dev_umo:
ref_long_side = 512 if len(input_ref_images) <= 1 else 320
input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images]
lock_dimensions = True
ref_style_imgs = []
if "I" in video_prompt_type and len(input_ref_images) > 0:
if flux_dev_uso :
if "J" in video_prompt_type:
ref_style_imgs = input_ref_images
input_ref_images = []
elif len(input_ref_images) > 1 :
ref_style_imgs = input_ref_images[-1:]
input_ref_images = input_ref_images[:-1]
if latent_stiching:
# latents stiching with resize
if not lock_dimensions :
for i in range(len(input_ref_images)):
w, h = input_ref_images[i].size
image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, 0)
input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
else:
# image stiching method
stiched = input_ref_images[0]
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
elif input_frames is not None:
input_ref_images = [convert_tensor_to_image(input_frames) ]
else:
input_ref_images = None
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
if self.name in ['flux-dev-uso', 'flux-dev-umo'] :
inp, height, width = prepare_multi_ip(
ae=self.vae,
img_cond_list=input_ref_images,
target_width=width,
target_height=height,
bs=batch_size,
seed=seed,
device=device,
)
else:
inp, height, width = prepare_kontext(
ae=self.vae,
img_cond_list=input_ref_images,
target_width=width,
target_height=height,
bs=batch_size,
seed=seed,
device=device,
img_mask=image_mask,
)
inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
if guide_scale != 1:
inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg = True, device=device))
timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
if self.feature_embedder is not None and ref_style_imgs is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None:
# processing style feat into textural hidden space
siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs]
siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
siglip_embedding_ids = torch.zeros( siglip_embedding.shape[0], siglip_embedding.shape[1], 3 ).to(device)
inp["siglip_embedding"] = siglip_embedding
inp["siglip_embedding_ids"] = siglip_embedding_ids
def unpack_latent(x):
return unpack(x.float(), height, width)
# denoise initial noise
x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength)
if x==None: return None
# decode latents to pixel space
x = unpack_latent(x)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
x = self.vae.decode(x)
if image_mask is not None:
from shared.utils.utils import convert_image_to_tensor
img_msk_rebuilt = inp["img_msk_rebuilt"]
img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide)
x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
x = x.clamp(-1, 1)
x = x.transpose(0, 1)
return x