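"""Gradio demo for Puffin (KangLiao/Puffin): camera-controllable image generation
and camera understanding. One tab is exposed for each task in the UI below."""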
import gradio as gr
import torch
import io
from PIL import Image
import numpy as np
import spaces
import math
import re
from einops import rearrange
from mmengine.config import Config
from xtuner.registry import BUILDER

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from scripts.camera.cam_dataset import Cam_Generator
from scripts.camera.visualization.visualize_batch import make_perspective_figures

from huggingface_hub import snapshot_download
import os

# Download the Puffin weights from the Hugging Face Hub into checkpoints/.
local_path = snapshot_download(
    repo_id="KangLiao/Puffin",
    repo_type="model",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)

# Numeric pattern and regex for extracting the predicted camera parameters
# (roll, pitch, fov) from the model's free-form text response.
NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
CAM_PATTERN = re.compile(
    r"(?:camera parameters.*?:|roll.*?:)\s*(" + NUM + r")\s*,\s*(" + NUM + r")\s*,\s*(" + NUM + r")",
    re.IGNORECASE | re.DOTALL,
)

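# Illustrative helper (not called by the app): how CAM_PATTERN could be used to
# recover the (roll, pitch, fov) triplet from a response string.
def parse_camera_params(text):
    match = CAM_PATTERN.search(text)
    if match is None:
        return None
    return tuple(float(g) for g in match.groups())
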
def center_crop(image):
    """Crop a PIL image to its largest centered square."""
    w, h = image.size
    s = min(w, h)
    l = (w - s) // 2
    t = (h - s) // 2
    return image.crop((l, t, l + s, t + s))

# Build the Puffin model from its pipeline config and load the downloaded checkpoint.
config = Config.fromfile("configs/pipelines/stage_2_base.py")
model = BUILDER.build(config.model).cuda().bfloat16().eval()

checkpoint_path = "checkpoints/Puffin-Base.pth"
checkpoint = torch.load(checkpoint_path)
info = model.load_state_dict(checkpoint, strict=False)

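# Optional sanity check (illustrative): with strict=False, load_state_dict returns the
# missing/unexpected keys, which can be printed to verify the checkpoint matched the model.
# print("missing:", info.missing_keys, "unexpected:", info.unexpected_keys)
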
def fig_to_image(fig):
    """Render a matplotlib figure to a PIL RGB image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img = Image.open(buf).convert('RGB')
    buf.close()
    return img

def extract_up_lat_figs(fig_dict):
    """Split a figure dict into the first up-field figure, the first latitude-field figure, and the rest."""
    fig_up, fig_lat = None, None
    others = {}
    for k, fig in fig_dict.items():
        if ("up_field" in k) and (fig_up is None):
            fig_up = fig
        elif ("latitude_field" in k) and (fig_lat is None):
            fig_lat = fig
        else:
            others[k] = fig
    return fig_up, fig_lat, others

@torch.inference_mode()
@spaces.GPU(duration=120)
def camera_understanding(image_src, seed, progress=gr.Progress(track_tqdm=True)):
    torch.cuda.empty_cache()
    print(torch.cuda.is_available())

    prompt = ("Describe the image in detail. Then reason its spatial distribution and estimate its camera parameters (roll, pitch, and field-of-view).")

    # Preprocess: center-crop to a square, resize to 512x512, scale pixels to [-1, 1], HWC -> CHW.
    image = Image.fromarray(image_src).convert('RGB')
    image = center_crop(image)
    image = image.resize((512, 512))
    x = torch.from_numpy(np.array(image)).float()
    x = x / 255.0
    x = 2 * x - 1
    x = rearrange(x, 'h w c -> c h w')

    with torch.no_grad():
        outputs = model.understand(prompt=[prompt], pixel_values=[x], progress_bar=False)

    text = outputs[0]

    # Convert the camera parameters described in the response into up/latitude field maps.
    gen = Cam_Generator(mode="base")
    cam = gen.get_cam(text)

    bgr = np.array(image)[:, :, ::-1].astype(np.float32) / 255.0
    rgb = bgr[:, :, ::-1].copy()
    image_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    single_batch = {}
    single_batch["image"] = image_tensor
    single_batch["up_field"] = cam[:2].unsqueeze(0)
    single_batch["latitude_field"] = cam[2:].unsqueeze(0)

    # Render the up/latitude visualizations; the demo currently returns only the text response.
    figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
    up_img = lat_img = None
    for k, fig in figs.items():
        if "up_field" in k:
            up_img = fig_to_image(fig)
        elif "latitude_field" in k:
            lat_img = fig_to_image(fig)
        plt.close(fig)

    return text

@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt_scene,
                   seed=42,
                   roll=0.1,
                   pitch=0.1,
                   fov=1.0,
                   progress=gr.Progress(track_tqdm=True)):
    torch.cuda.empty_cache()

    # Seed all RNGs so the sampled batch is reproducible for a given seed.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    print(torch.cuda.is_available())

    generator = torch.Generator().manual_seed(seed)

    # Describe the target camera in text and build the matching pixel-wise camera map,
    # scaled by pi/2 before being passed to the model.
    prompt_camera = (
        "The camera parameters (roll, pitch, and field-of-view) are: "
        f"{roll:.4f}, {pitch:.4f}, {fov:.4f}."
    )
    gen = Cam_Generator()
    cam_map = gen.get_cam(prompt_camera).to(model.device)
    cam_map = cam_map / (math.pi / 2)

    prompt = prompt_scene + " " + prompt_camera
    print("prompt:", prompt)

    # Sample a batch of four images for the same prompt with classifier-free guidance.
    bsz = 4
    with torch.no_grad():
        images, output_reasoning = model.generate(
            prompt=[prompt] * bsz,
            cfg_prompt=[""] * bsz,
            pixel_values_init=None,
            cfg_scale=4.5,
            num_steps=50,
            cam_values=[[cam_map]] * bsz,
            progress_bar=False,
            reasoning=False,
            prompt_reasoning=[""] * bsz,
            generator=generator,
            height=512,
            width=512,
        )

    # Map the outputs from the model's [-1, 1] range back to 8-bit RGB images.
    images = rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(127.5 * images + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    ret_images = [Image.fromarray(image) for image in images]
    return ret_images

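# Illustrative standalone call (assumes a CUDA GPU and the downloaded checkpoint):
# samples = generate_image("A grand, historic castle with pointed spires.", seed=42,
#                          roll=0.1, pitch=-0.1, fov=1.5)
# samples[0].save("sample.png")
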
css = '''
.gradio-container {max-width: 960px !important}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Puffin")

    with gr.Tab("Camera-controllable Image Generation"):
        gr.Markdown(value="## Camera-controllable Image Generation")

        prompt_input = gr.Textbox(label="Prompt.")

        # Slider units are radians: roll/pitch cover roughly ±0.7854 (±45°),
        # fov covers about 0.3491–1.8326 (≈20°–105°).
        with gr.Accordion("Camera Parameters", open=True):
            with gr.Row():
                roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1000, step=0.1000, label="roll value")
                pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1000, step=0.1000, label="pitch value")
                fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5000, step=0.1000, label="fov value")
            seed_input = gr.Number(label="Seed (Optional)", precision=0, value=42)

        generation_button = gr.Button("Generate Images")

        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)

        examples_t2i = gr.Examples(
            label="Prompt examples.",
            examples=[
                "A sunny day casts light on two warmly colored buildings—yellow with green accents and deeper orange—framed by a lush green tree, with a blue sign and street lamp adding details in the foreground.",
                "A high-vantage-point view of lush, autumn-colored mountains blanketed in green and gold, set against a clear blue sky with scattered white clouds, offering a tranquil and breathtaking vista of a serene valley below.",
                "A grand, historic castle with pointed spires and elaborate stone structures stands against a clear blue sky, flanked by a circular fountain, vibrant red flowers, and neatly trimmed hedges in a beautifully landscaped garden.",
                "A serene aerial view of a coastal landscape at sunrise/sunset, featuring warm pink and orange skies transitioning to cool blues, with calm waters stretching to rugged, snow-capped mountains in the background, creating a tranquil and picturesque scene.",
                "A worn, light-yellow walls room with herringbone terracotta floors and three large arched windows framed in pink trim and white panes, showcasing signs of age and disrepair, overlooks a residential area through glimpses of greenery and neighboring buildings.",
            ],
            inputs=prompt_input,
        )

    with gr.Tab("Camera Understanding"):
        gr.Markdown(value="## Camera Understanding")
        image_input = gr.Image()

        understanding_button = gr.Button("Chat")
        understanding_output = gr.Textbox(label="Response")

        with gr.Accordion("Advanced options", open=False):
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)

        examples_understanding = gr.Examples(
            label="Camera Understanding examples",
            examples=[
                "assets/1.jpg",
                "assets/2.jpg",
                "assets/3.jpg",
                "assets/4.jpg",
                "assets/5.jpg",
                "assets/6.jpg",
            ],
            inputs=image_input,
        )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, roll, pitch, fov],
        outputs=image_output,
    )

    understanding_button.click(
        fn=camera_understanding,
        inputs=[image_input, und_seed_input],
        outputs=[understanding_output],
    )

demo.launch(share=True)