# Chain-of-Zoom / inference_coz_single.py
# Author: alexnasa
# Commit 584caad — "VLM lora added" (verified)
import os
import tempfile
import uuid
import torch
from PIL import Image
from torchvision import transforms
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler
from peft import PeftModel
# -------------------------------------------------------------------
# Helper: Resize & center-crop to a fixed square
# -------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Scale *img* so its shorter side equals *size*, then center-crop a square.

    Args:
        img: Source image of any size and aspect ratio.
        size: Edge length, in pixels, of the square output.

    Returns:
        A new ``size`` x ``size`` image cut from the center of the rescaled input.

    Raises:
        ValueError: if *size* is not positive or *img* has a zero dimension.
    """
    w, h = img.size
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    if w <= 0 or h <= 0:
        raise ValueError(f"cannot resize an empty image of size {img.size}")
    scale = size / min(w, h)
    # Round instead of truncating, and never go below `size`. In floating point
    # 49 * (512 / 49) == 511.99999999999994, so int() would give 511. The square
    # crop below would then extend one pixel past the resized image and be padded.
    new_w = max(size, round(w * scale))
    new_h = max(size, round(h * scale))
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
# -------------------------------------------------------------------
# Helper: Generate a single VLM prompt for recursive_multiscale
# -------------------------------------------------------------------
def _generate_vlm_prompt(
    vlm_model: Qwen2_5_VLForConditionalGeneration,
    vlm_processor: AutoProcessor,
    process_vision_info,
    prev_pil: Image.Image,
    zoomed_pil: Image.Image,
    device: str = "cuda",
    max_new_tokens: int = 128,
) -> str:
    """Ask the VLM to describe the zoomed view, conditioned on the full view.

    Args:
        vlm_model: Qwen2.5-VL model used for generation.
        vlm_processor: Matching processor (chat template, tokenizer, image preprocessing).
        process_vision_info: Callable that takes the chat ``messages`` list and returns
            ``(image_inputs, video_inputs)`` for the processor; this file passes
            ``qwen_vl_utils.process_vision_info``.
        prev_pil: The full image from the previous recursion step.
        zoomed_pil: The cropped-and-resized region of ``prev_pil`` for this step.
        device: Device the processor outputs are moved to before generation.
        max_new_tokens: Upper bound on the number of tokens generated (default 128,
            the value previously hard-coded).

    Returns:
        The decoded reply for the single prompt, with surrounding whitespace
        stripped. It is used as the SR text prompt.
    """
    # (1) Instruction, sent as the system turn.
    message_text = (
        "The second image is a zoom-in of the first image. "
        "Based on this knowledge, what is in the second image? "
        "Give me a set of words."
    )
    # (2) User turn: the two images in order (full view first, then zoom).
    # The PIL images are placed directly in the message instead of file paths.
    messages = [
        {"role": "system", "content": message_text},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": prev_pil},
                {"type": "image", "image": zoomed_pil},
            ],
        },
    ]
    # (3) Build the prompt text from the chat template (untokenized, with the
    # assistant generation prefix). Extract the vision inputs from the same
    # messages, then let the processor tokenize and batch everything together.
    text = vlm_processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)
    # (4) Generate, then drop the echoed prompt tokens from each sequence so only
    # the newly generated answer is decoded.
    generated = vlm_model.generate(**inputs, max_new_tokens=max_new_tokens)
    trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    out_text = vlm_processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return out_text.strip()
# -------------------------------------------------------------------
# Vision-language model (VLM) that captions each zoomed view
# -------------------------------------------------------------------
VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    VLM_NAME,
    torch_dtype="auto",
    device_map="auto"  # immediately dispatches layers onto available GPUs
)
vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)
# Load the VLM LoRA adapter from the local checkpoint and merge it into the
# base weights (`merge_and_unload`); later code uses the merged model.
vlm_model = PeftModel.from_pretrained(vlm_model, "ckpt/VLM_LoRA/checkpoint-10000")
vlm_model = vlm_model.merge_and_unload()
vlm_model.eval()  # eval mode: inference only
# -------------------------------------------------------------------
# Shared settings for the super-resolution (SR) pipeline
# -------------------------------------------------------------------
# Target device for the SR input tensors and for the VLM processor outputs
# (both in recursive_multiscale_sr).
device = "cuda"
# Edge length of the square working image produced by resize_and_center_crop.
process_size = 512
# Checkpoint paths / base model id, forwarded to OSEDiff_SD3_TEST via `args`.
LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
SD3_MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"
class _Args:
pass
# Settings object passed to OSEDiff_SD3_TEST(args, sd3). Every option is set
# in a single update of the instance's attribute dict.
# NOTE(review): `upscale` here is independent of the `upscale` argument of
# recursive_multiscale_sr — confirm how OSEDiff_SD3_TEST uses it.
args = _Args()
vars(args).update(
    upscale=4,
    lora_path=LORA_PATH,
    vae_path=VAE_PATH,
    pretrained_model_name_or_path=SD3_MODEL,
    merge_and_unload_lora=False,
    lora_rank=4,
    # Tile sizes/overlap — presumably for tiled VAE / latent processing; verify
    # against osediff_sd3.
    vae_decoder_tiled_size=224,
    vae_encoder_tiled_size=1024,
    latent_tiled_size=96,
    latent_tiled_overlap=32,
    mixed_precision="fp16",
    efficient_memory=False,
)
# SD3 backbone for the one-step SR model.
sd3 = SD3Euler()
# Text encoders move to `device` with their dtype unchanged.
for p in (sd3.text_enc_1, sd3.text_enc_2, sd3.text_enc_3):
    p.to(device)
# The transformer and VAE are additionally cast to float32.
# NOTE(review): args.mixed_precision is "fp16" — confirm that OSEDiff_SD3_TEST
# expects these modules in fp32.
for p in (sd3.transformer, sd3.vae):
    p.to(device, dtype=torch.float32)
# Inference only: freeze every sub-module.
for p in (sd3.text_enc_1, sd3.text_enc_2, sd3.text_enc_3, sd3.transformer, sd3.vae):
    p.requires_grad_(False)
model_test = OSEDiff_SD3_TEST(args, sd3)
# -------------------------------------------------------------------
# Main Function: recursive_multiscale_sr (with multiple centers)
# -------------------------------------------------------------------
def _zoom_crop(
    img: Image.Image, center: tuple[float, float], upscale: int
) -> Image.Image:
    """Crop a window 1/upscale the size of *img* around *center*, then resize it back to *img*'s size."""
    w, h = img.size
    # At least one pixel, so an oversized `upscale` never yields an empty crop.
    crop_w, crop_h = max(1, w // upscale), max(1, h // upscale)
    cx_norm, cy_norm = center
    cx, cy = int(cx_norm * w), int(cy_norm * h)
    # Clamp so the window stays fully inside the image, even for centers near
    # (or beyond) the borders.
    left = max(0, min(cx - crop_w // 2, w - crop_w))
    top = max(0, min(cy - crop_h // 2, h - crop_h))
    cropped = img.crop((left, top, left + crop_w, top + crop_h))
    return cropped.resize((w, h), Image.BICUBIC)


def recursive_multiscale_sr(
    input_png_path: str,
    upscale: int,
    rec_num: int = 4,
    centers: list[tuple[float, float]] = None,
) -> tuple[list[Image.Image], list[str]]:
    """Run `rec_num` recursive zoom + super-resolution steps on one image.

    The input is converted to RGB and resized/center-cropped to
    ``process_size``. Each step then:
      1. crops a 1/upscale window around that step's center and resizes it back
         to full size (the "zoom"),
      2. asks the VLM to describe the zoom given the previous full image,
      3. super-resolves the zoom with ``model_test`` using that description as
         the prompt.
    The SR output becomes the input of the next step.

    Args:
        input_png_path: Path to the input image file (opened with Pillow).
        upscale: Integer zoom factor per recursion step (e.g. 4); must be >= 1.
        rec_num: Number of recursion steps.
        centers: One normalized (x, y) center per step, in [0, 1]. Its length
            must equal ``rec_num``. ``None`` means (0.5, 0.5) for every step.

    Returns:
        ``(sr_pil_list, prompt_list)``: the SR outputs [SR1, ..., SR_rec_num]
        in order, and the VLM prompt used at each step.

    Raises:
        ValueError: if ``upscale`` < 1, or ``centers`` is not a list/tuple of
            length ``rec_num``.
    """
    if upscale < 1:
        raise ValueError(f"`upscale` must be a positive integer, but got {upscale}.")
    if centers is None:
        centers = [(0.5, 0.5)] * rec_num
    elif not isinstance(centers, (list, tuple)):
        # Do not call len() here: arbitrary objects (e.g. generators) have none.
        raise ValueError(
            f"`centers` must be a list of {rec_num} (x,y) tuples, "
            f"but got {type(centers).__name__}."
        )
    elif len(centers) != rec_num:
        raise ValueError(
            f"`centers` must be a list of {rec_num} (x,y) tuples, but got length {len(centers)}."
        )

    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()

    # Close the source file promptly; convert() returns an independent image.
    with Image.open(input_png_path) as src:
        prev_pil = resize_and_center_crop(src.convert("RGB"), process_size)

    sr_pil_list: list[Image.Image] = []
    prompt_list: list[str] = []
    for center in centers:
        zoomed_pil = _zoom_crop(prev_pil, center, upscale)
        prompt_tag = _generate_vlm_prompt(
            vlm_model=vlm_model,
            vlm_processor=vlm_processor,
            process_vision_info=process_vision_info,
            prev_pil=prev_pil,
            zoomed_pil=zoomed_pil,
            device=device,
        )
        # ToTensor gives values in [0, 1]; the SR model is fed values in [-1, 1].
        lq = to_tensor(zoomed_pil).unsqueeze(0).to(device)
        lq = (lq * 2.0) - 1.0
        with torch.no_grad():
            out_tensor = model_test(lq, prompt=prompt_tag)[0]
        # Map the output from [-1, 1] back to [0, 1] for PIL conversion.
        out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
        out_pil = to_pil((out_tensor * 0.5) + 0.5)

        sr_pil_list.append(out_pil)
        prompt_list.append(prompt_tag)
        prev_pil = out_pil  # the next zoom is taken from this step's SR output
    return sr_pil_list, prompt_list