# app.py: Puffin Gradio demo (camera-controllable generation and camera understanding)
import gradio as gr
import torch
import io
from PIL import Image
import numpy as np
import spaces # Import spaces for ZeroGPU compatibility
import math
import re
from einops import rearrange
from mmengine.config import Config
from xtuner.registry import BUILDER
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scripts.camera.cam_dataset import Cam_Generator
from scripts.camera.visualization.visualize_batch import make_perspective_figures
from huggingface_hub import snapshot_download

# Fetch the Puffin checkpoint from the Hub into checkpoints/.
local_path = snapshot_download(
    repo_id="KangLiao/Puffin",
    repo_type="model",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)
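
# snapshot_download returns the local snapshot directory; the demo expects the
# weights to land at checkpoints/Puffin-Base.pth (see checkpoint_path below).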

# Signed decimal/scientific-notation number, and a pattern that pulls
# (roll, pitch, fov) out of the model's "camera parameters ...: r, p, f" text.
NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
CAM_PATTERN = re.compile(
    r"(?:camera parameters.*?:|roll.*?:)\s*(" + NUM + r")\s*,\s*(" + NUM + r")\s*,\s*(" + NUM + r")",
    re.IGNORECASE | re.DOTALL,
)
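# For instance, on a hypothetical model output:
#   CAM_PATTERN.search("The camera parameters (roll, pitch, and field-of-view) "
#                      "are: 0.10, -0.20, 1.50.").groups() -> ("0.10", "-0.20", "1.50")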


def center_crop(image):
    """Crop the largest centered square out of a PIL image."""
    w, h = image.size
    s = min(w, h)
    l = (w - s) // 2
    t = (h - s) // 2
    return image.crop((l, t, l + s, t + s))
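# e.g. center_crop(Image.new("RGB", (640, 480))).size == (480, 480)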


##### load model
config = Config.fromfile("configs/pipelines/stage_2_base.py")
model = BUILDER.build(config.model).cuda().bfloat16().eval()

checkpoint_path = "checkpoints/Puffin-Base.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
info = model.load_state_dict(checkpoint, strict=False)
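# strict=False tolerates key mismatches; surface them so config/checkpoint
# drift shows up in the logs instead of failing silently.
print(f"missing keys: {len(info.missing_keys)}, unexpected keys: {len(info.unexpected_keys)}")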


def fig_to_image(fig):
    """Render a matplotlib figure to an in-memory RGB PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img = Image.open(buf).convert('RGB')
    buf.close()
    return img
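# The PNG round-trip through BytesIO (rather than reading the canvas buffer
# directly) lets bbox_inches='tight' trim the figure margins first.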


def extract_up_lat_figs(fig_dict):
    """Split a figure dict into the first up-field figure, the first
    latitude-field figure, and everything else."""
    fig_up, fig_lat = None, None
    others = {}
    for k, fig in fig_dict.items():
        if ("up_field" in k) and (fig_up is None):
            fig_up = fig
        elif ("latitude_field" in k) and (fig_lat is None):
            fig_lat = fig
        else:
            others[k] = fig
    return fig_up, fig_lat, others
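# Currently unused by the demo; hypothetical usage once the camera-map
# galleries below are re-enabled:
#   fig_up, fig_lat, _ = extract_up_lat_figs(figs)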


# Camera understanding: caption the image, reason about its spatial layout,
# and estimate camera parameters (roll, pitch, field-of-view). The unused
# `question` parameter is dropped so the signature matches the two inputs
# wired up in the UI below.
@torch.inference_mode()
@spaces.GPU(duration=120)
def camera_understanding(image_src, seed, progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache before generating.
    torch.cuda.empty_cache()
    # Seed for reproducibility.
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    print(torch.cuda.is_available())  # log GPU visibility (useful on ZeroGPU)
    prompt = ("Describe the image in detail. Then reason its spatial distribution and estimate its camera parameters (roll, pitch, and field-of-view).")

    # Center-crop, resize, and map [0, 255] uint8 HWC to [-1, 1] float CHW.
    image = Image.fromarray(image_src).convert('RGB')
    image = center_crop(image)
    image = image.resize((512, 512))
    x = torch.from_numpy(np.array(image)).float()
    x = x / 255.0
    x = 2 * x - 1
    x = rearrange(x, 'h w c -> c h w')

    # @torch.inference_mode() above already disables autograd tracking.
    outputs = model.understand(prompt=[prompt], pixel_values=[x], progress_bar=False)
    text = outputs[0]
    # Rebuild the camera maps (up-vector + latitude fields) from the response
    # text for visualization.
    gen = Cam_Generator(mode="base")
    cam = gen.get_cam(text)

    # (H, W, 3) RGB in [0, 1] -> (1, 3, H, W) float tensor.
    rgb = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)

    single_batch = {
        "image": image_tensor,
        "up_field": cam[:2].unsqueeze(0),
        "latitude_field": cam[2:].unsqueeze(0),
    }
    figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)

    up_img = lat_img = None
    for k, fig in figs.items():
        if "up_field" in k:
            up_img = fig_to_image(fig)
        elif "latitude_field" in k:
            lat_img = fig_to_image(fig)
        plt.close(fig)  # free every figure once rasterized

    # The camera-map gallery outputs are disabled in the UI for now.
    return text  # , up_img, lat_img


@torch.inference_mode()
@spaces.GPU(duration=120)  # specify a duration to avoid ZeroGPU timeouts
def generate_image(prompt_scene,
                   seed=42,
                   roll=0.1,
                   pitch=0.1,
                   fov=1.0,
                   progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache; @torch.inference_mode() avoids tracking gradients.
    torch.cuda.empty_cache()

    # Seed all RNGs for reproducible sampling.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    print(torch.cuda.is_available())  # log GPU visibility (useful on ZeroGPU)
    generator = torch.Generator().manual_seed(seed)

    # Camera conditioning: a fixed-format sentence appended to the scene
    # prompt, plus per-pixel camera maps derived from it.
    prompt_camera = (
        "The camera parameters (roll, pitch, and field-of-view) are: "
        f"{roll:.4f}, {pitch:.4f}, {fov:.4f}."
    )
    gen = Cam_Generator()
    cam_map = gen.get_cam(prompt_camera).to(model.device)
    cam_map = cam_map / (math.pi / 2)  # scale the radian-valued maps into a normalized range
    prompt = prompt_scene + " " + prompt_camera
    print("prompt:", prompt)

    # Sample a batch of four candidates for the same prompt.
    bsz = 4
    images, output_reasoning = model.generate(
        prompt=[prompt] * bsz,
        cfg_prompt=[""] * bsz,
        pixel_values_init=None,
        cfg_scale=4.5,
        num_steps=50,
        cam_values=[[cam_map]] * bsz,
        progress_bar=False,
        reasoning=False,
        prompt_reasoning=[""] * bsz,
        generator=generator,
        height=512,
        width=512,
    )

    # [-1, 1] floats (B, C, H, W) -> uint8 RGB arrays (B, H, W, C).
    images = rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(127.5 * images + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    ret_images = [Image.fromarray(image) for image in images]
    return ret_images
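

# Local smoke test (hypothetical; bypasses the Gradio UI):
#   imgs = generate_image("A quiet cobblestone street at dusk.", seed=0,
#                         roll=0.0, pitch=0.0, fov=1.2)
#   imgs[0].save("sample.png")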

# Gradio interface
css = '''
.gradio-container {max-width: 960px !important}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Puffin")

    with gr.Tab("Camera-controllable Image Generation"):
        gr.Markdown(value="## Camera-controllable Image Generation")
        prompt_input = gr.Textbox(label="Prompt")
        with gr.Accordion("Camera Parameters", open=True):
            with gr.Row():
                # The ranges suggest radians: ±0.7854 ≈ ±45° for roll/pitch,
                # and 0.3491-1.8326 ≈ 20°-105° for the field-of-view.
                roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1000, step=0.1000, label="roll (rad)")
                pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1000, step=0.1000, label="pitch (rad)")
                fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5000, step=0.1000, label="fov (rad)")
        seed_input = gr.Number(label="Seed (Optional)", precision=0, value=42)
        generation_button = gr.Button("Generate Images")
        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)
        examples_t2i = gr.Examples(
            label="Prompt examples",
            examples=[
                "A sunny day casts light on two warmly colored buildings—yellow with green accents and deeper orange—framed by a lush green tree, with a blue sign and street lamp adding details in the foreground.",
                "A high-vantage-point view of lush, autumn-colored mountains blanketed in green and gold, set against a clear blue sky with scattered white clouds, offering a tranquil and breathtaking vista of a serene valley below.",
                "A grand, historic castle with pointed spires and elaborate stone structures stands against a clear blue sky, flanked by a circular fountain, vibrant red flowers, and neatly trimmed hedges in a beautifully landscaped garden.",
                "A serene aerial view of a coastal landscape at sunrise/sunset, featuring warm pink and orange skies transitioning to cool blues, with calm waters stretching to rugged, snow-capped mountains in the background, creating a tranquil and picturesque scene.",
                "A worn room with light-yellow walls, herringbone terracotta floors, and three large arched windows framed in pink trim and white panes, showing signs of age and disrepair, overlooks a residential area through glimpses of greenery and neighboring buildings.",
            ],
            inputs=prompt_input,
        )

    with gr.Tab("Camera Understanding"):
        gr.Markdown(value="## Camera Understanding")
        image_input = gr.Image()
        understanding_button = gr.Button("Chat")
        understanding_output = gr.Textbox(label="Response")
        # Galleries for the camera-map visualizations, disabled for now:
        # camera1 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
        # camera2 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
        with gr.Accordion("Advanced options", open=False):
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
        examples_understanding = gr.Examples(
            label="Camera Understanding examples",
            examples=[
                "assets/1.jpg",
                "assets/2.jpg",
                "assets/3.jpg",
                "assets/4.jpg",
                "assets/5.jpg",
                "assets/6.jpg",
            ],
            inputs=image_input,
        )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, roll, pitch, fov],
        outputs=image_output,
    )
    understanding_button.click(
        fn=camera_understanding,
        inputs=[image_input, und_seed_input],
        outputs=[understanding_output],  # extend with the camera-map galleries when re-enabled
    )

demo.launch(share=True)