|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import src.depth_pro as depth_pro |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import subprocess |
|
|
import spaces |
|
|
import torch |
|
|
import tempfile |
|
|
import os |
|
|
|
|
|
|
|
|
# Fetch the pretrained DepthPro checkpoint before the model is constructed.
# NOTE(review): the script's exit status is not checked, so a failed download
# only surfaces later, when the model tries to load its weights.
subprocess.run(["bash", "get_pretrained_models.sh"])

# Use the first CUDA device when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the network together with its matching input transform, place it on
# the chosen device and switch it to evaluation mode for inference.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
|
|
|
|
|
def resize_image(image_path, max_size=1536):
    """Bound an image's longest side to ``max_size`` pixels and save it as a temporary PNG.

    Images whose longest side already fits within ``max_size`` keep their
    original resolution. Larger images are downscaled with Lanczos resampling,
    preserving the aspect ratio.

    Args:
        image_path: Path of the source image.
        max_size: Maximum length, in pixels, of the output's longest side.

    Returns:
        Path of a new ``.png`` file created with ``delete=False``. The caller is
        responsible for removing it.
    """
    # Modes the PNG encoder can write; anything else (e.g. CMYK or YCbCr JPEGs)
    # is converted to RGB so that saving does not fail.
    png_modes = {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}

    with Image.open(image_path) as img:
        longest = max(img.size)
        if longest > max_size:
            ratio = max_size / longest
            # round() rather than int(): float error in ratio * longest must not
            # truncate the longest side to max_size - 1. max(1, ...) keeps an
            # extreme aspect ratio from collapsing the short side to 0 px.
            new_size = tuple(max(1, round(x * ratio)) for x in img.size)
            img = img.resize(new_size, Image.LANCZOS)

        if img.mode not in png_modes:
            img = img.convert("RGB")

        # Save while the source file is still open. Image.open is lazy, so an
        # image that was not resized may not have its pixel data loaded yet.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            with temp_file:
                img.save(temp_file, format="PNG")
        except Exception:
            # The caller never receives the path on failure, so clean up here.
            os.remove(temp_file.name)
            raise

    return temp_file.name
|
|
|
|
|
def _save_inverse_depth_plot(depth):
    """Render ``1 / depth`` (clipped to [0, 10]) with a colorbar into a new PNG; return its path.

    Each call writes a uniquely named file in the system temp directory, so
    concurrent requests cannot overwrite each other's output. The file is not
    deleted here, because Gradio reads it after the request handler returns.
    """
    # Local import: a bare Figure keeps no global pyplot state, so it cannot
    # leak when an exception occurs or be shared between concurrent requests.
    from matplotlib.figure import Figure

    # A zero depth yields inf, which np.clip saturates at 10, exactly as the
    # plain division did. errstate only silences the divide-by-zero warning.
    with np.errstate(divide="ignore"):
        inverse_depth = 1.0 / depth
    inverse_depth_clipped = np.clip(inverse_depth, 0, 10)

    fig = Figure(figsize=(15.36, 15.36), dpi=100)
    ax = fig.subplots()
    image = ax.imshow(inverse_depth_clipped, cmap="viridis")
    fig.colorbar(image, ax=ax, label="Inverse Depth")
    ax.set_title("Predicted Inverse Depth Map")
    ax.axis("off")

    fd, output_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        fig.savefig(output_path, dpi=100, bbox_inches="tight", pad_inches=0)
    except Exception:
        os.remove(output_path)
        raise
    return output_path


@spaces.GPU(duration=20)
def predict_depth(input_image):
    """Run DepthPro on an uploaded image and render its inverse depth map.

    Args:
        input_image: Filesystem path of the uploaded image. The Gradio input
            component is configured with ``type="filepath"``.

    Returns:
        A ``(image_path, message)`` tuple. On success, ``image_path`` points to a
        PNG of the clipped inverse depth map and ``message`` reports the
        predicted focal length in pixels. On any failure, ``image_path`` is
        ``None`` and ``message`` contains the error text.
    """
    temp_file = None
    try:
        # Bound the input resolution. resize_image writes a temporary PNG,
        # which is removed in the ``finally`` clause below.
        temp_file = resize_image(input_image)

        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        # The last element of load_rgb's result is forwarded to the model as
        # f_px (presumably a focal-length hint that may be None; TODO confirm).
        f_px = result[-1]
        image = transform(image).to(device)

        # Inference only: disable autograd so the outputs never require grad.
        # A grad-tracking tensor cannot be converted with .numpy().
        with torch.no_grad():
            prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]
        focallength_px = prediction["focallength_px"]

        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()
        if depth.ndim != 2:
            depth = depth.squeeze()

        output_path = _save_inverse_depth_plot(depth)
        # float() accepts Python/NumPy scalars and single-element tensors alike.
        return output_path, f"Focal length: {float(focallength_px):.2f} pixels"

    except Exception as e:
        # Deliberate UI boundary: surface the failure in the textbox instead of
        # crashing the request.
        return None, f"An error occurred: {str(e)}"

    finally:
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
|
|
|
|
|
|
|
|
# Sample inputs offered under the interface (paths relative to the working directory).
example_images = [
    "examples/lemur.jpg",
]

# Single-image-in, (depth-map image, text) out interface around predict_depth.
# The textbox shows the focal length on success or the error message on failure.
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Inverse Depth Map", height=768, width=768),
        gr.Textbox(label="Focal Length or Error Message"),
    ],
    title="DepthPro Demo",
    # resize_image bounds only the longest side and preserves the aspect
    # ratio; the output is not a 1536x1536 square.
    description="[DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload an image to predict its inverse depth map and focal length. Large images will be automatically downscaled to at most 1536 pixels on their longest side, preserving the aspect ratio.",
    examples=example_images,
)

iface.launch()