# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import gc
import os
import shutil
import sys
import time
from datetime import datetime
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Tuple

# Must be set before torch is imported so the CUDA allocator picks it up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
from PIL import Image
from pillow_heif import register_heif_opener
from sklearn.cluster import DBSCAN

from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
    get_gradio_theme,
    GRADIO_CSS,
)
from mapanything.utils.hf_utils.hf_helpers import (
    initialize_mapanything_model,
    initialize_mapanything_local,
)
from mapanything.utils.hf_utils.viz import predictions_to_glb
from mapanything.utils.image import load_images, rgb

register_heif_opener()
sys.path.append("mapanything/")
# ============================================================================
# Global configuration
# ============================================================================

# MapAnything configuration
high_level_config = {
    "path": "configs/train.yaml",
    "hf_model_name": "facebook/map-anything",
    "model_str": "mapanything",
    "config_overrides": [
        "machine=aws",
        "model=mapanything",
        "model/task=images_only",
        "model.encoder.uses_torch_hub=false",
    ],
    "checkpoint_name": "model.safetensors",
    "config_name": "config.json",
    "trained_with_amp": True,
    "trained_with_amp_dtype": "bf16",
    "data_norm_type": "dinov2",
    "patch_size": 14,
    "resolution": 518,
}
# GroundingDINO configuration - loaded from HuggingFace
GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
GROUNDING_DINO_BOX_THRESHOLD = 0.25
GROUNDING_DINO_TEXT_THRESHOLD = 0.2

# SAM configuration - uses the HuggingFace SAM models
SAM_MODEL_ID = "facebook/sam-vit-huge"  # or "facebook/sam-vit-base" for a smaller, faster model

DEFAULT_TEXT_PROMPT = "window . table . sofa . tv . book . door"
# Generic object list (GroundingDINO detects whichever of these are present in the image)
COMMON_OBJECTS_PROMPT = (
    "person . face . hand . "
    "chair . sofa . couch . bed . table . desk . cabinet . shelf . drawer . "
    "door . window . wall . floor . ceiling . curtain . "
    "tv . monitor . screen . computer . laptop . keyboard . mouse . "
    "phone . tablet . remote . "
    "lamp . light . chandelier . "
    "book . magazine . paper . pen . pencil . "
    "bottle . cup . glass . mug . plate . bowl . fork . knife . spoon . "
    "vase . plant . flower . pot . "
    "clock . picture . frame . mirror . "
    "pillow . cushion . blanket . towel . "
    "bag . backpack . suitcase . "
    "box . basket . container . "
    "shoe . hat . coat . "
    "toy . ball . "
    "car . bicycle . motorcycle . bus . truck . "
    "tree . grass . sky . cloud . sun . "
    "dog . cat . bird . "
    "building . house . bridge . road . street . "
    "sign . pole . bench"
)
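
# Note on the prompt format: GroundingDINO's zero-shot text interface takes
# lowercase category phrases separated by " . ", which is the convention the
# two prompts above follow. A custom prompt such as
#   "red mug . laptop stand"
# (hypothetical categories, for illustration) can be passed unchanged to
# run_grounding_dino_detection() below.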
# V8: DBSCAN clustering configuration
# Per-class clustering radius (eps), chosen by object type
DBSCAN_EPS_CONFIG = {
    'sofa': 1.5,    # sofas: 1.5 m radius (large objects; detections of one sofa can be far apart)
    'bed': 1.5,
    'couch': 1.5,
    'desk': 0.8,    # desks: 0.8 m radius (medium objects)
    'table': 0.8,
    'chair': 0.6,   # chairs: 0.6 m (smaller)
    'cabinet': 0.8,
    'window': 0.5,  # windows: 0.5 m (fixed positions, strict clustering)
    'door': 0.6,
    'tv': 0.6,
    'default': 1.0  # default: 1 m
}
DBSCAN_MIN_SAMPLES = 1  # minimum samples per cluster (1 means a single detection forms its own cluster)
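
# A minimal sketch of how these settings drive the matching step (illustrative
# values only): three 3D centers labeled "sofa", two of them 0.8 m apart and a
# third 4 m away, clustered with eps=1.5 and min_samples=1, yield cluster
# labels [0, 0, 1] - the first two detections merge into one object and the
# outlier becomes its own object:
#
#   centers = np.array([[0.0, 0.0, 0.0], [0.8, 0.0, 0.0], [4.0, 0.0, 0.0]])
#   DBSCAN(eps=1.5, min_samples=1).fit_predict(centers)  # -> array([0, 0, 1])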

ENABLE_VISUAL_FEATURES = False

# Segmentation quality control
MIN_DETECTION_CONFIDENCE = 0.35  # minimum detection confidence (filters false positives)
MIN_MASK_AREA = 100              # minimum mask area (pixels)

# Match-score configuration (used by the fallback matching algorithm)
MATCH_3D_DISTANCE_THRESHOLD = 2.5  # 3D distance threshold (meters)
# Global model variables
model = None
grounding_dino_model = None
grounding_dino_processor = None
sam_predictor = None

# ============================================================================
# Segmentation model loading
# ============================================================================
def load_grounding_dino_model(device):
    """Load the GroundingDINO model from HuggingFace."""
    global grounding_dino_model, grounding_dino_processor
    if grounding_dino_model is not None:
        print("✅ GroundingDINO already loaded")
        return
    try:
        from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

        print(f"📥 Loading GroundingDINO from HuggingFace: {GROUNDING_DINO_MODEL_ID}")
        grounding_dino_processor = AutoProcessor.from_pretrained(GROUNDING_DINO_MODEL_ID)
        grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            GROUNDING_DINO_MODEL_ID
        ).to(device).eval()
        print("✅ GroundingDINO loaded")
    except Exception as e:
        print(f"❌ Failed to load GroundingDINO: {e}")
        import traceback
        traceback.print_exc()

def load_sam_model(device):
    """Load the SAM model from HuggingFace."""
    global sam_predictor
    if sam_predictor is not None:
        print("✅ SAM already loaded")
        return
    try:
        from transformers import SamModel, SamProcessor

        print(f"📥 Loading SAM from HuggingFace: {SAM_MODEL_ID}")
        sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device).eval()
        sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID)
        # Store the model and processor in the global
        sam_predictor = {'model': sam_model, 'processor': sam_processor}
        print("✅ SAM loaded")
    except Exception as e:
        print(f"❌ Failed to load SAM: {e}")
        print("   SAM will be disabled; bounding boxes will be used as masks")
        import traceback
        traceback.print_exc()

# ============================================================================
# Segmentation
# ============================================================================
def generate_distinct_colors(n):
    """Generate n visually distinct colors (RGB, 0-255) by spacing hues evenly."""
    import colorsys
    if n == 0:
        return []
    colors = []
    for i in range(n):
        hue = i / max(n, 1)
        rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
        rgb_color = tuple(int(c * 255) for c in rgb)
        colors.append(rgb_color)
    return colors

def run_grounding_dino_detection(image_np, text_prompt, device):
    """Run GroundingDINO detection on a single image."""
    if grounding_dino_model is None or grounding_dino_processor is None:
        print("⚠️ GroundingDINO not loaded")
        return []
    try:
        print(f"🔍 GroundingDINO detection: {text_prompt}")
        # Convert to a PIL Image
        if image_np.dtype == np.uint8:
            pil_image = Image.fromarray(image_np)
        else:
            pil_image = Image.fromarray((image_np * 255).astype(np.uint8))
        # Preprocess
        inputs = grounding_dino_processor(images=pil_image, text=text_prompt, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Inference
        with torch.no_grad():
            outputs = grounding_dino_model(**inputs)
        # Post-process
        results = grounding_dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs["input_ids"],
            threshold=GROUNDING_DINO_BOX_THRESHOLD,
            text_threshold=GROUNDING_DINO_TEXT_THRESHOLD,
            target_sizes=[pil_image.size[::-1]]
        )[0]
        # Convert to a uniform format
        detections = []
        boxes = results["boxes"].cpu().numpy()
        scores = results["scores"].cpu().numpy()
        labels = results["labels"]
        print(f"✅ Detected {len(boxes)} objects")
        for box, score, label in zip(boxes, scores, labels):
            detection = {
                'bbox': box.tolist(),  # [x1, y1, x2, y2]
                'label': label,
                'confidence': float(score)
            }
            detections.append(detection)
            print(f"   - {label}: {score:.2f}")
        return detections
    except Exception as e:
        print(f"❌ GroundingDINO detection failed: {e}")
        import traceback
        traceback.print_exc()
        return []
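
# Each entry returned above is a plain dict; run_model() later attaches 3D
# attributes to it. An illustrative (made-up) detection:
#
#   {'bbox': [124.0, 210.5, 388.2, 460.7],   # pixel coords [x1, y1, x2, y2]
#    'label': 'sofa',
#    'confidence': 0.62}
#
# After enrichment it additionally carries 'center_3d', 'bbox_3d', 'mask_2d',
# and (optionally) 'visual_feature'.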

def run_sam_refinement(image_np, boxes):
    """Refine box detections into masks with SAM (HuggingFace Transformers version)."""
    if sam_predictor is None:
        print("⚠️ SAM not loaded; using bboxes as masks")
        # Build simple rectangular masks from the bboxes
        masks = []
        h, w = image_np.shape[:2]
        for box in boxes:
            x1, y1, x2, y2 = map(int, box)
            mask = np.zeros((h, w), dtype=bool)
            mask[y1:y2, x1:x2] = True
            masks.append(mask)
        return masks
    try:
        print(f"🎯 SAM refinement of {len(boxes)} regions...")
        from PIL import Image

        sam_model = sam_predictor['model']
        sam_processor = sam_predictor['processor']
        device = sam_model.device
        # Convert to a PIL Image
        if image_np.dtype == np.uint8:
            pil_image = Image.fromarray(image_np)
        else:
            pil_image = Image.fromarray((image_np * 255).astype(np.uint8))
        masks = []
        for box in boxes:
            x1, y1, x2, y2 = map(int, box)
            input_boxes = [[[x1, y1, x2, y2]]]  # nesting SAM expects: batch x boxes x 4
            # Prepare inputs
            inputs = sam_processor(pil_image, input_boxes=input_boxes, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # Inference
            with torch.no_grad():
                outputs = sam_model(**inputs)
            # Post-process to get the mask
            pred_masks = sam_processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu()
            )[0][0][0]  # first image, first box, first mask
            masks.append(pred_masks.numpy() > 0.5)
        print("✅ SAM refinement done")
        return masks
    except Exception as e:
        print(f"❌ SAM refinement failed: {e}")
        import traceback
        traceback.print_exc()
        # Fall back to bbox masks
        masks = []
        h, w = image_np.shape[:2]
        for box in boxes:
            x1, y1, x2, y2 = map(int, box)
            mask = np.zeros((h, w), dtype=bool)
            mask[y1:y2, x1:x2] = True
            masks.append(mask)
        return masks
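
# Usage sketch (mirrors how run_model() calls this below): the returned list
# is index-aligned with the input boxes, so detections and masks zip directly:
#
#   boxes = [d['bbox'] for d in detections]
#   masks = run_sam_refinement(image_np, boxes)
#   for det, mask in zip(detections, masks):
#       ...  # mask is a bool array of shape (H, W)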

def normalize_label(label):
    """Normalize a label to its main category.

    Examples: 'sofa bed' -> 'sofa', 'desk cabinet' -> 'desk', 'table desk' -> 'table';
    'windows' -> 'window', 'chairs' -> 'chair' (plural -> singular).
    """
    label = label.strip().lower()
    # Priority order (high to low)
    priority_labels = ['sofa', 'bed', 'table', 'desk', 'chair', 'cabinet', 'window', 'door']
    # If the label contains a priority category, return it
    for priority in priority_labels:
        if priority in label:
            return priority
    # Otherwise fall back to the first word
    first_word = label.split()[0] if label else label
    # Convert common plural forms to singular
    if first_word.endswith('s') and len(first_word) > 1:
        singular = first_word[:-1]  # default: drop the trailing 's'
        # Special plural rules
        if first_word.endswith('sses'):  # glasses -> glass
            singular = first_word[:-2]
        elif first_word.endswith('ies'):  # cherries -> cherry
            singular = first_word[:-3] + 'y'
        elif first_word.endswith('ves'):  # shelves -> shelf
            singular = first_word[:-3] + 'f'
        # Return the singular form
        return singular
    return first_word
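
# Illustrative behavior of the normalization above (traced by hand from the
# rules, not from running the model):
#
#   normalize_label('sofa bed')   -> 'sofa'   (priority match wins)
#   normalize_label('windows')    -> 'window' (priority substring match)
#   normalize_label('tv monitor') -> 'tv'     (no priority hit; first word)
#   normalize_label('shelves')    -> 'shelf'  ('ves' plural rule)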

def labels_match(label1, label2):
    """Check whether two labels match (fuzzy matching via normalization).

    Examples: 'sofa' matches 'sofa bed'; 'window' matches 'windows'.
    """
    norm1 = normalize_label(label1)
    norm2 = normalize_label(label2)
    return norm1 == norm2

def compute_object_3d_center(points, mask):
    """Compute an object's 3D center as the per-axis median of its masked points
    (the median is more robust to stray depth outliers than the mean)."""
    masked_points = points[mask]
    if len(masked_points) == 0:
        return None
    return np.median(masked_points, axis=0)

def compute_3d_bbox_iou(center1, size1, center2, size2):
    """Compute the IoU of two axis-aligned 3D bounding boxes."""
    try:
        # Box extents [min, max]
        min1 = center1 - size1 / 2
        max1 = center1 + size1 / 2
        min2 = center2 - size2 / 2
        max2 = center2 + size2 / 2
        # Intersection
        inter_min = np.maximum(min1, min2)
        inter_max = np.minimum(max1, max2)
        inter_size = np.maximum(0, inter_max - inter_min)
        inter_volume = np.prod(inter_size)
        # Union
        volume1 = np.prod(size1)
        volume2 = np.prod(size2)
        union_volume = volume1 + volume2 - inter_volume
        if union_volume == 0:
            return 0.0
        return inter_volume / union_volume
    except Exception:
        return 0.0
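
# Worked example (hand-computed, not executed): two unit cubes whose centers
# are 0.5 m apart along x overlap in a 0.5 x 1 x 1 slab, so
#   intersection = 0.5, union = 1 + 1 - 0.5 = 1.5, IoU = 1/3 ≈ 0.333
#
#   compute_3d_bbox_iou(np.zeros(3), np.ones(3),
#                       np.array([0.5, 0.0, 0.0]), np.ones(3))  # ~0.333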

def compute_2d_mask_iou(mask1, mask2):
    """Compute the IoU of two 2D masks."""
    try:
        intersection = np.logical_and(mask1, mask2).sum()
        union = np.logical_or(mask1, mask2).sum()
        if union == 0:
            return 0.0
        return intersection / union
    except Exception:
        return 0.0

def extract_visual_features(image, mask, encoder):
    """Extract visual features for a mask region (using DINOv2).

    Args:
        image: [H, W, 3] float32 in [0, 1] or uint8 in [0, 255]
        mask: [H, W] bool
        encoder: DINOv2 encoder model
    Returns:
        feature vector (1D numpy array) or None on failure
    """
    try:
        # Crop out the mask region
        coords = np.argwhere(mask)
        if len(coords) == 0:
            return None
        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0)
        # Make sure the crop is valid
        if y_max <= y_min or x_max <= x_min:
            return None
        # Crop and resize to 224x224
        cropped = image[y_min:y_max + 1, x_min:x_max + 1]
        # Ensure uint8
        if cropped.dtype == np.float32 or cropped.dtype == np.float64:
            if cropped.max() <= 1.0:
                cropped = (cropped * 255).astype(np.uint8)
            else:
                cropped = cropped.astype(np.uint8)
        from PIL import Image
        import torchvision.transforms as T

        pil_img = Image.fromarray(cropped)
        pil_img = pil_img.resize((224, 224), Image.BILINEAR)
        # Convert to a tensor
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Find the encoder's device
        try:
            device = next(encoder.parameters()).device
        except Exception:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        img_tensor = transform(pil_img).unsqueeze(0).to(device)  # [1, 3, 224, 224]
        # Extract features via the encoder's forward pass
        with torch.no_grad():
            # Different encoders expose different entry points
            if hasattr(encoder, 'forward_features'):
                # Standard DINOv2 exposes forward_features
                features = encoder.forward_features(img_tensor)
            else:
                # Otherwise call directly (DINOv2Encoder takes just the input tensor)
                features = encoder(img_tensor)
            # If features is not a tensor, try to unwrap it
            if not isinstance(features, torch.Tensor):
                if isinstance(features, dict):
                    # For dict outputs, try 'x' then 'last_hidden_state'
                    features = features.get('x', features.get('last_hidden_state', None))
                    if features is None:
                        return None
                elif hasattr(features, 'data'):
                    # Some wrapper objects expose .data
                    features = features.data
                else:
                    # Cannot handle this output type
                    return None
            # Make sure features is a tensor
            if not isinstance(features, torch.Tensor):
                return None
            # Accept 4D [B, C, H, W], 3D [B, N, C], or 2D [B, C]
            if len(features.shape) == 4:
                # [B, C, H, W] -> global average pooling
                features = features.mean(dim=[2, 3])  # [B, C]
            elif len(features.shape) == 3:
                # [B, N, C] -> mean over tokens (alternative: take the CLS token)
                features = features.mean(dim=1)  # [B, C]
            elif len(features.shape) == 2:
                # [B, C] -> already the shape we need
                pass
            else:
                # Unsupported shape
                return None
            # L2 normalize
            features = features / (features.norm(dim=1, keepdim=True) + 1e-8)
        return features.cpu().numpy()[0]
    except Exception as e:
        import traceback
        print(f"   ⚠️ Feature extraction failed: {type(e).__name__}: {e}")
        print(f"   Traceback:\n{traceback.format_exc()}")  # full stack trace
        return None

def compute_feature_similarity(feat1, feat2):
    """Compute feature similarity (cosine similarity; the features are already
    L2-normalized, so a plain dot product equals the cosine)."""
    if feat1 is None or feat2 is None:
        return 0.0
    try:
        return np.dot(feat1, feat2)
    except Exception:
        return 0.0
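
# Quick check (illustrative): for unit vectors the dot product is the cosine
# of the angle between them, e.g.
#   a = np.array([1.0, 0.0]); b = np.array([0.0, 1.0])
#   compute_feature_similarity(a, a)  # -> 1.0 (identical)
#   compute_feature_similarity(a, b)  # -> 0.0 (orthogonal)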

def compute_match_score(obj1, obj2, weights={'distance': 0.5, 'iou_3d': 0.25, 'iou_2d': 0.15, 'feature': 0.1}):
    """Compute an overall match score in [0, 1].

    Weights are adjusted dynamically: if a criterion is unavailable, its weight
    is redistributed over the remaining criteria.
    """
    scores = {}
    available_criteria = []
    # 1. 3D distance score (closer -> higher)
    if obj1.get('center_3d') is not None and obj2.get('center_3d') is not None:
        distance = np.linalg.norm(obj1['center_3d'] - obj2['center_3d'])
        scores['distance'] = max(0, 1 - distance / MATCH_3D_DISTANCE_THRESHOLD)
        available_criteria.append('distance')
    else:
        scores['distance'] = 0.0
    # 2. 3D IoU score
    if obj1.get('bbox_3d') is not None and obj2.get('bbox_3d') is not None:
        scores['iou_3d'] = compute_3d_bbox_iou(
            obj1['bbox_3d']['center'], obj1['bbox_3d']['size'],
            obj2['bbox_3d']['center'], obj2['bbox_3d']['size']
        )
        available_criteria.append('iou_3d')
    else:
        scores['iou_3d'] = 0.0
    # 3. 2D IoU score
    if obj1.get('mask_2d') is not None and obj2.get('mask_2d') is not None:
        scores['iou_2d'] = compute_2d_mask_iou(obj1['mask_2d'], obj2['mask_2d'])
        available_criteria.append('iou_2d')
    else:
        scores['iou_2d'] = 0.0
    # 4. Visual feature similarity
    if obj1.get('visual_feature') is not None and obj2.get('visual_feature') is not None:
        scores['feature'] = compute_feature_similarity(obj1['visual_feature'], obj2['visual_feature'])
        available_criteria.append('feature')
    else:
        scores['feature'] = 0.0
    # Dynamic weighting: only use the criteria that are available
    if len(available_criteria) == 0:
        return 0.0, scores
    # Renormalize the weights
    total_available_weight = sum(weights[k] for k in available_criteria)
    if total_available_weight == 0:
        return 0.0, scores
    adjusted_weights = {k: weights[k] / total_available_weight for k in available_criteria}
    # Weighted sum
    total_score = sum(scores[k] * adjusted_weights[k] for k in available_criteria)
    return total_score, scores
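
# Worked example of the dynamic re-weighting (hand-computed): suppose only
# 'distance' and 'iou_2d' are available. Their raw weights 0.5 and 0.15 sum
# to 0.65, so they are rescaled to 0.5/0.65 ≈ 0.769 and 0.15/0.65 ≈ 0.231.
# With a 1.0 m separation (distance score = 1 - 1.0/2.5 = 0.6) and a 2D IoU
# of 0.4, the total is 0.6*0.769 + 0.4*0.231 ≈ 0.554.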

def compute_adaptive_eps(centers, base_eps):
    """Adaptively compute the eps value.

    Adjust eps from the spread of the objects' 3D positions:
    - if the detections are spread out, grow eps (avoid over-splitting);
    - if they are tightly packed, keep the default eps.
    """
    if len(centers) <= 1:
        return base_eps
    # Pairwise distances between all centers
    from scipy.spatial.distance import pdist
    distances = pdist(centers)
    if len(distances) == 0:
        return base_eps
    # Use the median distance as the reference
    median_dist = np.median(distances)
    # Adaptive policy: a large median distance means scattered detections, so grow eps;
    # a small median distance means tightly packed detections, so keep eps.
    if median_dist > base_eps * 2:
        # Very scattered - grow eps substantially (likely multi-view detections of one object)
        adaptive_eps = min(median_dist * 0.6, base_eps * 2.5)
    elif median_dist > base_eps:
        # Somewhat scattered - scale eps with the median (note: for a median
        # only slightly above base_eps this can actually end up below base_eps)
        adaptive_eps = median_dist * 0.5
    else:
        # Tightly packed - keep the default eps
        adaptive_eps = base_eps
    return adaptive_eps
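
# Numeric sketch of the policy above (hand-computed with base_eps = 1.0 m):
#   median pairwise distance 0.7 m -> eps stays 1.0                (tight cluster)
#   median pairwise distance 1.6 m -> eps = 1.6 * 0.5 = 0.8        (middle branch)
#   median pairwise distance 3.0 m -> eps = min(3.0*0.6, 2.5) = 1.8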

def match_objects_across_views(all_view_detections):
    """Match the same objects across views (V8 enhanced: adaptive DBSCAN clustering).

    V8 improvements:
    - adaptive eps: the clustering radius adjusts to the object distribution
    - smart merging: after clustering, check for obviously duplicated clusters
    - confidence weighting: cluster centers are confidence-weighted

    Args:
        all_view_detections: List[List[Dict]], detections per view
    Returns:
        object_id_map: Dict[view_idx][det_idx] = global_object_id
        unique_objects: List[Dict] - list of unique objects
    """
    print("\n🔗 V8 enhanced: matching objects with adaptive DBSCAN clustering...")
    # Collect all detections, grouped by label
    objects_by_label = defaultdict(list)
    for view_idx, detections in enumerate(all_view_detections):
        for det_idx, det in enumerate(detections):
            # Only consider objects that have a 3D center
            if det.get('center_3d') is None:
                continue
            norm_label = normalize_label(det['label'])
            objects_by_label[norm_label].append({
                'view_idx': view_idx,
                'det_idx': det_idx,
                'label': det['label'],
                'norm_label': norm_label,
                'center_3d': det['center_3d'],
                'confidence': det['confidence'],
                'bbox_3d': det.get('bbox_3d'),
            })
    if len(objects_by_label) == 0:
        return {}, []
    # V8: run DBSCAN separately for each object category
    object_id_map = defaultdict(dict)
    unique_objects = []
    next_global_id = 0
    for norm_label, objects in objects_by_label.items():
        print(f"\n   📦 Processing {norm_label}: {len(objects)} detections")
        # A single detection directly becomes a single object
        if len(objects) == 1:
            obj = objects[0]
            unique_objects.append({
                'global_id': next_global_id,
                'label': obj['label'],
                'views': [(obj['view_idx'], obj['det_idx'])],
                'center_3d': obj['center_3d'],
            })
            object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
            next_global_id += 1
            print("      → 1 cluster (single detection)")
            continue
        # Extract the 3D center coordinates
        centers = np.array([obj['center_3d'] for obj in objects])
        # Base clustering radius for this category
        base_eps = DBSCAN_EPS_CONFIG.get(norm_label, DBSCAN_EPS_CONFIG.get('default', 1.0))
        # 🔥 V8 enhancement: adaptive eps
        eps = compute_adaptive_eps(centers, base_eps)
        # DBSCAN clustering
        clustering = DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES, metric='euclidean')
        cluster_labels = clustering.fit_predict(centers)
        # Cluster statistics
        n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
        n_noise = list(cluster_labels).count(-1)
        if eps != base_eps:
            print(f"      → {n_clusters} clusters (base eps={base_eps}m → adaptive eps={eps:.2f}m)")
        else:
            print(f"      → {n_clusters} clusters (eps={eps}m)")
        if n_noise > 0:
            print(f"      ⚠️ {n_noise} noise points (isolated detections)")
        # Debug: show the details of every cluster
        for cluster_id in sorted(set(cluster_labels)):
            if cluster_id == -1:
                continue
            cluster_objs = [objects[i] for i, label in enumerate(cluster_labels) if label == cluster_id]
            cluster_centers = [obj['center_3d'] for obj in cluster_objs]
            cluster_views = [f"V{obj['view_idx'] + 1}" for obj in cluster_objs]
            # Maximum intra-cluster distance
            max_dist = 0
            if len(cluster_centers) > 1:
                from scipy.spatial.distance import pdist
                distances = pdist(np.array(cluster_centers))
                max_dist = distances.max() if len(distances) > 0 else 0
            print(f"      Cluster {cluster_id}: {len(cluster_objs)} detections (views: {', '.join(cluster_views)}, max intra-cluster distance: {max_dist:.2f}m)")
        # Create one global object per cluster
        cluster_to_global_id = {}
        for cluster_id in set(cluster_labels):
            if cluster_id == -1:
                # Noise points each become their own object
                for i, label in enumerate(cluster_labels):
                    if label == -1:
                        obj = objects[i]
                        unique_objects.append({
                            'global_id': next_global_id,
                            'label': obj['label'],
                            'views': [(obj['view_idx'], obj['det_idx'])],
                            'center_3d': obj['center_3d'],
                        })
                        object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
                        next_global_id += 1
            else:
                # Regular cluster
                cluster_objects = [objects[i] for i, label in enumerate(cluster_labels) if label == cluster_id]
                # Cluster center as the confidence-weighted average
                total_conf = sum(o['confidence'] for o in cluster_objects)
                weighted_center = sum(o['center_3d'] * o['confidence'] for o in cluster_objects) / total_conf
                # Create the global object
                unique_objects.append({
                    'global_id': next_global_id,
                    'label': cluster_objects[0]['label'],
                    'views': [(o['view_idx'], o['det_idx']) for o in cluster_objects],
                    'center_3d': weighted_center,
                })
                # Map every detection in the cluster to this global ID
                for obj in cluster_objects:
                    object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
                print(f"      Cluster {cluster_id}: merged {len(cluster_objects)} detections")
                next_global_id += 1
    print("\n   📊 Summary:")
    print(f"      Total detections: {sum(len(objs) for objs in objects_by_label.values())}")
    print(f"      Unique objects: {len(unique_objects)}")
    # Print the matching result (grouped by normalized label)
    label_counts = defaultdict(int)
    original_labels = defaultdict(set)
    for obj in unique_objects:
        norm_label = normalize_label(obj['label'])
        label_counts[norm_label] += 1
        original_labels[norm_label].add(obj['label'])
    print("\n   📊 Object category statistics (normalized):")
    for norm_label, count in sorted(label_counts.items()):
        orig_labels = original_labels[norm_label]
        if len(orig_labels) > 1:
            print(f"      {norm_label} (original labels: {', '.join(sorted(orig_labels))}): {count}")
        else:
            print(f"      {norm_label}: {count}")
    return object_id_map, unique_objects
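
# Shape of the results above (illustrative, assuming two views that both saw
# one sofa, plus a chair seen only in view 2):
#
#   object_id_map  -> {0: {0: 0}, 1: {0: 0, 1: 1}}
#                      view 0 det 0 and view 1 det 0 share global id 0 (the sofa)
#   unique_objects -> [{'global_id': 0, 'label': 'sofa',
#                       'views': [(0, 0), (1, 0)], 'center_3d': array([...])},
#                      {'global_id': 1, 'label': 'chair',
#                       'views': [(1, 1)], 'center_3d': array([...])}]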

def create_multi_view_segmented_mesh(processed_data, all_view_detections, all_view_masks,
                                     object_id_map, unique_objects, target_dir, use_sam=True):
    """Create the fused multi-view segmentation mesh (via utils3d.image_mesh)."""
    try:
        print("\n🎨 Generating the multi-view segmentation mesh...")
        # Assign colors per object category (label), using normalized labels to
        # avoid the compound-label problem.
        # All distinct normalized categories:
        unique_normalized_labels = sorted(set(normalize_label(obj['label']) for obj in unique_objects))
        label_colors = {}
        colors = generate_distinct_colors(len(unique_normalized_labels))
        # One color per normalized label
        for i, norm_label in enumerate(unique_normalized_labels):
            label_colors[norm_label] = colors[i]
        # Give every unique object the color of its normalized category
        for obj in unique_objects:
            norm_label = normalize_label(obj['label'])
            obj['color'] = label_colors[norm_label]
            obj['normalized_label'] = norm_label  # keep the normalized label
        # Print the category-color mapping (by normalized label)
        print("   Object category color mapping (normalized labels):")
        for norm_label, color in sorted(label_colors.items()):
            count = sum(1 for obj in unique_objects if normalize_label(obj['label']) == norm_label)
            # Show all original labels
            original_labels = set(obj['label'] for obj in unique_objects if normalize_label(obj['label']) == norm_label)
            if len(original_labels) > 1:
                print(f"      {norm_label} (includes: {', '.join(sorted(original_labels))}) × {count} → RGB{color}")
            else:
                print(f"      {norm_label} × {count} → RGB{color}")
        # Import utils3d
        import utils3d

        all_meshes = []
        # Generate a mesh for each view
        for view_idx in range(len(processed_data)):
            view_data = processed_data[view_idx]
            image = view_data["image"]
            points3d = view_data["points3d"]
            mask = view_data.get("mask")
            normal = view_data.get("normal")
            detections = all_view_detections[view_idx]
            masks = all_view_masks[view_idx]
            if len(detections) == 0:
                continue
            # Make sure the image is in [0, 255]
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    image = (image * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)
            # Build the colored image (confidence-priority strategy avoids color clutter)
            colored_image = image.copy()
            confidence_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)  # per-pixel confidence
            # Collect every detection with its info (applying quality filtering)
            detections_info = []
            filtered_count = 0
            for det_idx, (det, seg_mask) in enumerate(zip(detections, masks)):
                # Filter low-confidence detections
                if det['confidence'] < MIN_DETECTION_CONFIDENCE:
                    filtered_count += 1
                    continue
                # Filter masks that are too small
                mask_area = seg_mask.sum()
                if mask_area < MIN_MASK_AREA:
                    filtered_count += 1
                    continue
                global_id = object_id_map[view_idx].get(det_idx)
                if global_id is None:
                    continue
                unique_obj = next((obj for obj in unique_objects if obj['global_id'] == global_id), None)
                if unique_obj is None:
                    continue
                detections_info.append({
                    'mask': seg_mask,
                    'color': unique_obj['color'],
                    'confidence': det['confidence'],
                    'label': det['label'],
                    'area': mask_area
                })
            if filtered_count > 0:
                print(f"   View {view_idx + 1}: filtered {filtered_count} low-quality detections")
            # Sort by confidence (low to high) so high-confidence colors are written last
            detections_info.sort(key=lambda x: x['confidence'])
            # Apply the colors (higher confidence wins)
            for info in detections_info:
                seg_mask = info['mask']
                color = info['color']
                conf = info['confidence']
                # Overwrite only where the current confidence is higher
                update_mask = seg_mask & (conf > confidence_map)
                colored_image[update_mask] = color
                confidence_map[update_mask] = conf
            # Build the mesh with utils3d.image_mesh
            height, width = image.shape[:2]
            if normal is None:
                faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
                    points3d,
                    colored_image.astype(np.float32) / 255,
                    utils3d.numpy.image_uv(width=width, height=height),
                    mask=mask if mask is not None else np.ones((height, width), dtype=bool),
                    tri=True
                )
                vertex_normals = None
            else:
                faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
                    points3d,
                    colored_image.astype(np.float32) / 255,
                    utils3d.numpy.image_uv(width=width, height=height),
                    normal,
                    mask=mask if mask is not None else np.ones((height, width), dtype=bool),
                    tri=True
                )
            # Coordinate transform: flip the y and z axes
            vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
            if vertex_normals is not None:
                vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
            # Build the mesh
            view_mesh = trimesh.Trimesh(
                vertices=vertices,
                faces=faces,
                vertex_normals=vertex_normals,
                vertex_colors=(vertex_colors * 255).astype(np.uint8),
                process=False
            )
            all_meshes.append(view_mesh)
            print(f"   View {view_idx + 1}: {len(vertices):,} vertices, {len(faces):,} faces")
        if len(all_meshes) == 0:
            print("⚠️ No mesh was generated")
            return None
        # Fuse all meshes
        print("   Fusing all views...")
        combined_mesh = trimesh.util.concatenate(all_meshes)
        # Save
        glb_path = os.path.join(target_dir, 'multi_view_segmented_mesh.glb')
        combined_mesh.export(glb_path)
        print(f"✅ Multi-view segmentation mesh saved: {glb_path}")
        print(f"   Total: {len(combined_mesh.vertices):,} vertices, {len(combined_mesh.faces):,} faces")
        print(f"   {len(unique_objects)} unique objects")
        return glb_path
    except Exception as e:
        print(f"❌ Failed to generate the multi-view mesh: {e}")
        import traceback
        traceback.print_exc()
        return None

def create_segmented_pointcloud(processed_data, detections, masks, target_dir, use_sam=True):
    """Create a segmented point cloud (single view; kept for compatibility only)."""
    if len(detections) == 0:
        return None
    try:
        print("🎨 Generating the segmented point cloud...")
        # Use the first view
        first_view = processed_data[0]
        image = first_view["image"]
        points3d = first_view["points3d"]
        normal = first_view.get("normal")
        mask = first_view.get("mask")
        # Make sure the image is in [0, 255]
        if image.dtype != np.uint8:
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)
        # Generate the colors
        distinct_colors = generate_distinct_colors(len(detections))
        # Build the colored image
        colored_image = image.copy()
        for i, (det, seg_mask) in enumerate(zip(detections, masks)):
            color = distinct_colors[i]
            colored_image[seg_mask] = color
            print(f"   {det['label']} → RGB{color}")
        # Build the point cloud (following MapAnything's approach)
        height, width = image.shape[:2]
        # Simple approach: derive vertex colors directly from points3d
        vertices = points3d.reshape(-1, 3)
        colors = (colored_image.astype(np.float32) / 255.0).reshape(-1, 3)
        if mask is not None:
            valid_mask = mask.reshape(-1)
            vertices = vertices[valid_mask]
            colors = colors[valid_mask]
        # Coordinate transform: flip the y and z axes
        vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
        # Build the point cloud
        pointcloud = trimesh.PointCloud(
            vertices=vertices,
            colors=(colors * 255).astype(np.uint8)
        )
        # Save
        seg_glb_path = os.path.join(target_dir, 'segmented_pointcloud.glb')
        pointcloud.export(seg_glb_path)
        print(f"✅ Segmented point cloud saved: {seg_glb_path}")
        return seg_glb_path
    except Exception as e:
        print(f"❌ Failed to generate the segmented point cloud: {e}")
        import traceback
        traceback.print_exc()
        return None

# ============================================================================
# Core model inference
# ============================================================================
def run_model(
    target_dir,
    apply_mask=True,
    mask_edges=True,
    filter_black_bg=False,
    filter_white_bg=False,
    enable_segmentation=False,
    text_prompt=DEFAULT_TEXT_PROMPT,
    use_sam=True,
):
    """
    Run the MapAnything model plus GroundingDINO + SAM segmentation.
    Returns (predictions, processed_data, segmented_glb).
    """
    global model, grounding_dino_model, sam_predictor
    import torch

    print(f"Processing images: {target_dir}")
    # Device check
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    # Initialize the MapAnything model from HuggingFace
    if model is None:
        print("📥 Loading MapAnything from HuggingFace...")
        model = initialize_mapanything_model(high_level_config, device)
        print("✅ MapAnything loaded")
    else:
        model = model.to(device)
    model.eval()
    # Load the segmentation models
    if enable_segmentation:
        load_grounding_dino_model(device)
        if use_sam:
            load_sam_model(device)
    # Load the images
    print("Loading images...")
    image_folder_path = os.path.join(target_dir, "images")
    views = load_images(image_folder_path)
    print(f"Loaded {len(views)} images")
    if len(views) == 0:
        raise ValueError("No images found")
    # Run MapAnything inference
    print("Running 3D reconstruction...")
    outputs = model.infer(
        views, apply_mask=apply_mask, mask_edges=mask_edges, memory_efficient_inference=False
    )
    # Convert the predictions
    predictions = {}
    extrinsic_list = []
    intrinsic_list = []
    world_points_list = []
    depth_maps_list = []
    images_list = []
    final_mask_list = []
    confidences = []
    for pred in outputs:
        depthmap_torch = pred["depth_z"][0].squeeze(-1)
        intrinsics_torch = pred["intrinsics"][0]
        camera_pose_torch = pred["camera_poses"][0]
        conf = pred["conf"][0].squeeze(-1)
        pts3d_computed, valid_mask = depthmap_to_world_frame(
            depthmap_torch, intrinsics_torch, camera_pose_torch
        )
        if "mask" in pred:
            mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
        else:
            mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
        mask = mask & valid_mask.cpu().numpy()
        image = pred["img_no_norm"][0].cpu().numpy()
        extrinsic_list.append(camera_pose_torch.cpu().numpy())
        intrinsic_list.append(intrinsics_torch.cpu().numpy())
        world_points_list.append(pts3d_computed.cpu().numpy())
        depth_maps_list.append(depthmap_torch.cpu().numpy())
        images_list.append(image)
        final_mask_list.append(mask)
        confidences.append(conf.cpu().numpy())
    predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
    predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
    predictions["world_points"] = np.stack(world_points_list, axis=0)
    predictions["conf"] = np.stack(confidences, axis=0)
    depth_maps = np.stack(depth_maps_list, axis=0)
    if len(depth_maps.shape) == 3:
        depth_maps = depth_maps[..., np.newaxis]
    predictions["depth"] = depth_maps
    predictions["images"] = np.stack(images_list, axis=0)
    predictions["final_mask"] = np.stack(final_mask_list, axis=0)
    # Prepare the visualization data
    processed_data = process_predictions_for_visualization(
        predictions, views, high_level_config, filter_black_bg, filter_white_bg
    )
    # Multi-view segmentation
    segmented_glb = None
    if enable_segmentation and grounding_dino_model is not None:
        print("\n🎯 Starting multi-view segmentation...")
        print(f"🔍 Detection prompt: {text_prompt[:100]}...")
        all_view_detections = []
        all_view_masks = []
        # Segment every view
        for view_idx, ref_image in enumerate(images_list):
            print(f"\n📸 Processing view {view_idx + 1}/{len(images_list)}...")
            if ref_image.dtype != np.uint8:
                ref_image_np = (ref_image * 255).astype(np.uint8)
            else:
                ref_image_np = ref_image
            # GroundingDINO detection
            detections = run_grounding_dino_detection(ref_image_np, text_prompt, device)
            if len(detections) > 0:
                # SAM refinement
                boxes = [d['bbox'] for d in detections]
                masks = run_sam_refinement(ref_image_np, boxes) if use_sam else []
                # Get the 3D points and the encoder (for feature extraction)
                points3d = world_points_list[view_idx]
                encoder = model.encoder if hasattr(model, 'encoder') else None
                # V5: extract several kinds of features per detected object
                for det_idx, (det, mask) in enumerate(zip(detections, masks)):
                    # 1. Compute the 3D center
                    center_3d = compute_object_3d_center(points3d, mask)
                    det['center_3d'] = center_3d
                    # 2. Compute the 3D bounding box
                    if center_3d is not None:
                        masked_points = points3d[mask]
                        if len(masked_points) > 0:
                            bbox_min = masked_points.min(axis=0)
                            bbox_max = masked_points.max(axis=0)
                            bbox_size = bbox_max - bbox_min
                            det['bbox_3d'] = {
                                'center': center_3d,
                                'size': bbox_size,
                                'min': bbox_min,
                                'max': bbox_max
                            }
                    # 3. Store the 2D mask (for IoU computation)
                    det['mask_2d'] = mask
                    # 4. Extract visual features (DINOv2) - optional
                    if ENABLE_VISUAL_FEATURES and encoder is not None:
                        visual_feat = extract_visual_features(ref_image, mask, encoder)
                        det['visual_feature'] = visual_feat
                    else:
                        det['visual_feature'] = None
                all_view_detections.append(detections)
                all_view_masks.append(masks)
            else:
                all_view_detections.append([])
                all_view_masks.append([])
        # Match objects across views
        if any(len(dets) > 0 for dets in all_view_detections):
            object_id_map, unique_objects = match_objects_across_views(all_view_detections)
            # Generate the multi-view segmentation mesh
            segmented_glb = create_multi_view_segmented_mesh(
                processed_data, all_view_detections, all_view_masks,
                object_id_map, unique_objects, target_dir, use_sam
            )
    # Cleanup
    torch.cuda.empty_cache()
    return predictions, processed_data, segmented_glb

# ============================================================================
# Other functions copied from gradio_app.py
# ============================================================================
def update_view_selectors(processed_data):
    """Update view selector dropdowns based on available views"""
    if processed_data is None or len(processed_data) == 0:
        choices = ["View 1"]
    else:
        num_views = len(processed_data)
        choices = [f"View {i + 1}" for i in range(num_views)]
    return (
        gr.Dropdown(choices=choices, value=choices[0]),
        gr.Dropdown(choices=choices, value=choices[0]),
        gr.Dropdown(choices=choices, value=choices[0]),
    )

def get_view_data_by_index(processed_data, view_index):
    """Get view data by index, handling bounds"""
    if processed_data is None or len(processed_data) == 0:
        return None
    view_keys = list(processed_data.keys())
    if view_index < 0 or view_index >= len(view_keys):
        view_index = 0
    return processed_data[view_keys[view_index]]

def update_depth_view(processed_data, view_index):
    """Update depth view for a specific view index"""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["depth"] is None:
        return None
    return colorize_depth(view_data["depth"], mask=view_data.get("mask"))

def update_normal_view(processed_data, view_index):
    """Update normal view for a specific view index"""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["normal"] is None:
        return None
    return colorize_normal(view_data["normal"], mask=view_data.get("mask"))

def update_measure_view(processed_data, view_index):
    """Update measure view for a specific view index with mask overlay"""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None:
        return None, []
    image = view_data["image"].copy()
    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)
    if view_data["mask"] is not None:
        mask = view_data["mask"]
        invalid_mask = ~mask
        if invalid_mask.any():
            overlay_color = np.array([255, 220, 220], dtype=np.uint8)
            alpha = 0.5
            for c in range(3):
                image[:, :, c] = np.where(
                    invalid_mask,
                    (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
                    image[:, :, c],
                ).astype(np.uint8)
    return image, []

def navigate_depth_view(processed_data, current_selector_value, direction):
    """Navigate depth view"""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        current_view = 0
    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"
    depth_vis = update_depth_view(processed_data, new_view)
    return new_selector_value, depth_vis

def navigate_normal_view(processed_data, current_selector_value, direction):
    """Navigate normal view"""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        current_view = 0
    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"
    normal_vis = update_normal_view(processed_data, new_view)
    return new_selector_value, normal_vis

def navigate_measure_view(processed_data, current_selector_value, direction):
    """Navigate measure view"""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None, []
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        current_view = 0
    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"
    measure_image, measure_points = update_measure_view(processed_data, new_view)
    return new_selector_value, measure_image, measure_points

def populate_visualization_tabs(processed_data):
    """Populate the depth, normal, and measure tabs with processed data"""
    if processed_data is None or len(processed_data) == 0:
        return None, None, None, []
    depth_vis = update_depth_view(processed_data, 0)
    normal_vis = update_normal_view(processed_data, 0)
    measure_img, _ = update_measure_view(processed_data, 0)
    return depth_vis, normal_vis, measure_img, []

def handle_uploads(input_video, input_images, s_time_interval=1.0):
    """Handle uploaded video/images"""
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)
    image_paths = []
    # Handle images
    if input_images is not None:
        for file_data in input_images:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext in [".heic", ".heif"]:
                try:
                    with Image.open(file_path) as img:
                        if img.mode not in ("RGB", "L"):
                            img = img.convert("RGB")
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
                        img.save(dst_path, "JPEG", quality=95)
                        image_paths.append(dst_path)
                except Exception as e:
                    print(f"Error converting HEIC: {e}")
                    dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                    shutil.copy(file_path, dst_path)
                    image_paths.append(dst_path)
            else:
                dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                shutil.copy(file_path, dst_path)
                image_paths.append(dst_path)
    # Handle video
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video
        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        # Guard against fps * interval < 1, which would make the modulo below
        # divide by zero
        frame_interval = max(1, int(fps * s_time_interval))
        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1
    image_paths = sorted(image_paths)
    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths

def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
    """Update gallery on upload"""
    if not input_video and not input_images:
        return None, None, None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
    return (
        None,
        None,
        target_dir,
        image_paths,
        "Upload complete. Click \"Reconstruct\" to start 3D processing",
    )

def gradio_demo(
    target_dir,
    frame_filter="All",
    show_cam=True,
    filter_black_bg=False,
    filter_white_bg=False,
    conf_thres=3.0,
    apply_mask=True,
    show_mesh=True,
    enable_segmentation=False,
    text_prompt=DEFAULT_TEXT_PROMPT,
    use_sam=True,
):
    """Perform reconstruction"""
    if not os.path.isdir(target_dir) or target_dir == "None":
        # 12 outputs, matching the success path below
        return None, None, "Please upload files first", None, None, None, None, None, None, None, None, None
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    all_files_display = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files_display
    print("Running the MapAnything model...")
    with torch.no_grad():
        predictions, processed_data, segmented_glb = run_model(
            target_dir, apply_mask, True, filter_black_bg, filter_white_bg,
            enable_segmentation, text_prompt, use_sam
        )
    # Save the predictions
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)
    if frame_filter is None:
        frame_filter = "All"
    # Generate the raw GLB
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}.glb",
    )
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,
        conf_percentile=conf_thres,
    )
    glbscene.export(file_obj=glbfile)
    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()
    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f}s")
    log_msg = f"✅ Reconstruction succeeded ({len(all_files)} frames)"
    # Populate the visualization tabs
    depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data)
    # Update the view selectors
    depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data)
    return (
        glbfile,
        segmented_glb,
        log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
        processed_data,
        depth_vis,
        normal_vis,
        measure_img,
        "",
        depth_selector,
        normal_selector,
        measure_selector,
    )

def colorize_depth(depth_map, mask=None):
    """Convert depth map to colorized visualization"""
    if depth_map is None:
        return None
    depth_normalized = depth_map.copy()
    valid_mask = depth_normalized > 0
    if mask is not None:
        valid_mask = valid_mask & mask
    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5 = np.percentile(valid_depths, 5)
        p95 = np.percentile(valid_depths, 95)
        # Robust contrast stretch to [0, 1]; the epsilon guards against a
        # constant depth map where p95 == p5
        depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / max(p95 - p5, 1e-8)
    import matplotlib.pyplot as plt
    colormap = plt.cm.turbo_r
    colored = colormap(depth_normalized)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)
    colored[~valid_mask] = [255, 255, 255]
    return colored
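
# Illustrative numbers for the stretch above: if the valid depths give
# p5 = 1.0 m and p95 = 5.0 m, a pixel at 3.0 m maps to
# (3.0 - 1.0) / (5.0 - 1.0) = 0.5, the middle of the turbo_r colormap;
# values outside the 5th-95th percentile band fall outside [0, 1] and are
# clipped by the colormap.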

def colorize_normal(normal_map, mask=None):
    """Convert normal map to colorized visualization"""
    if normal_map is None:
        return None
    normal_vis = normal_map.copy()
    if mask is not None:
        invalid_mask = ~mask
        normal_vis[invalid_mask] = [0, 0, 0]
    # Map components from [-1, 1] to [0, 255]
    normal_vis = (normal_vis + 1.0) / 2.0
    normal_vis = (normal_vis * 255).astype(np.uint8)
    return normal_vis

def process_predictions_for_visualization(
    predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
):
    """Extract depth, normal, and 3D points from predictions for visualization"""
    processed_data = {}
    for view_idx, view in enumerate(views):
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
        pred_pts3d = predictions["world_points"][view_idx]
        view_data = {
            "image": image[0],
            "points3d": pred_pts3d,
            "depth": None,
            "normal": None,
            "mask": None,
        }
        mask = predictions["final_mask"][view_idx].copy()
        if filter_black_bg:
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            black_bg_mask = view_colors.sum(axis=2) >= 16
            mask = mask & black_bg_mask
        if filter_white_bg:
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            white_bg_mask = ~(
                (view_colors[:, :, 0] > 240)
                & (view_colors[:, :, 1] > 240)
                & (view_colors[:, :, 2] > 240)
            )
            mask = mask & white_bg_mask
        view_data["mask"] = mask
        view_data["depth"] = predictions["depth"][view_idx].squeeze()
        normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
        view_data["normal"] = normals
        processed_data[view_idx] = view_data
    return processed_data

def reset_measure(processed_data):
    """Reset measure points"""
    if processed_data is None or len(processed_data) == 0:
        return None, [], ""
    first_view = list(processed_data.values())[0]
    return first_view["image"], [], ""

def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData):
    """Handle measurement on images"""
    try:
        if processed_data is None or len(processed_data) == 0:
            return None, [], "No data available"
        try:
            current_view_index = int(current_view_selector.split()[1]) - 1
        except Exception:
            current_view_index = 0
        if current_view_index < 0 or current_view_index >= len(processed_data):
            current_view_index = 0
        view_keys = list(processed_data.keys())
        current_view = processed_data[view_keys[current_view_index]]
        if current_view is None:
            return None, [], "No view data"
        point2d = event.index[0], event.index[1]
        if (
            current_view["mask"] is not None
            and 0 <= point2d[1] < current_view["mask"].shape[0]
            and 0 <= point2d[0] < current_view["mask"].shape[1]
        ):
            if not current_view["mask"][point2d[1], point2d[0]]:
                masked_image, _ = update_measure_view(processed_data, current_view_index)
                return (
                    masked_image,
                    measure_points,
                    '<span style="color: red; font-weight: bold;">Cannot measure in the masked region (shown grayed out)</span>',
                )
        measure_points.append(point2d)
        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"
        image = image.copy()
        points3d = current_view["points3d"]
        if image.dtype != np.uint8:
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)
        for p in measure_points:
            if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
                image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
        depth_text = ""
        for i, p in enumerate(measure_points):
            if (
                current_view["depth"] is not None
                and 0 <= p[1] < current_view["depth"].shape[0]
                and 0 <= p[0] < current_view["depth"].shape[1]
            ):
                d = current_view["depth"][p[1], p[0]]
                depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
            else:
                if (
                    points3d is not None
                    and 0 <= p[1] < points3d.shape[0]
                    and 0 <= p[0] < points3d.shape[1]
                ):
                    z = points3d[p[1], p[0], 2]
                    depth_text += f"- **P{i + 1} Z coordinate: {z:.2f}m.**\n"
        if len(measure_points) == 2:
            point1, point2 = measure_points
            if (
                0 <= point1[0] < image.shape[1]
                and 0 <= point1[1] < image.shape[0]
                and 0 <= point2[0] < image.shape[1]
                and 0 <= point2[1] < image.shape[0]
            ):
                image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
            distance_text = "- **Distance: cannot compute**"
            if (
                points3d is not None
                and 0 <= point1[1] < points3d.shape[0]
                and 0 <= point1[0] < points3d.shape[1]
                and 0 <= point2[1] < points3d.shape[0]
                and 0 <= point2[0] < points3d.shape[1]
            ):
                try:
                    p1_3d = points3d[point1[1], point1[0]]
                    p2_3d = points3d[point2[1], point2[0]]
                    distance = np.linalg.norm(p1_3d - p2_3d)
                    distance_text = f"- **Distance: {distance:.2f}m**"
                except Exception as e:
                    distance_text = f"- **Distance computation error: {e}**"
            measure_points = []
            text = depth_text + distance_text
            return [image, measure_points, text]
        else:
            return [image, measure_points, depth_text]
    except Exception as e:
        print(f"Measurement error: {e}")
        return None, [], f"Measurement error: {e}"

def clear_fields():
    """Clear 3D viewer"""
    return None, None

def update_log():
    """Display log message"""
    return "Loading and reconstructing..."

def update_visualization(
    target_dir,
    frame_filter,
    show_cam,
    is_example,
    conf_thres=None,
    filter_black_bg=False,
    filter_white_bg=False,
    show_mesh=True,
):
    """Update visualization"""
    if is_example == "True":
        return gr.update(), "No reconstruction available. Click the Reconstruct button first."
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return gr.update(), "No reconstruction available. Click the Reconstruct button first."
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return gr.update(), "No reconstruction available. Run \"Reconstruct\" first."
    loaded = np.load(predictions_path, allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,
        conf_percentile=conf_thres,
    )
    glbscene.export(file_obj=glbfile)
    return glbfile, "Visualization updated."

def update_all_views_on_filter_change(
    target_dir,
    filter_black_bg,
    filter_white_bg,
    processed_data,
    depth_view_selector,
    normal_view_selector,
    measure_view_selector,
):
    """Update all views on filter change"""
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return processed_data, None, None, None, []
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return processed_data, None, None, None, []
    try:
        loaded = np.load(predictions_path, allow_pickle=True)
        predictions = {key: loaded[key] for key in loaded.keys()}
        image_folder_path = os.path.join(target_dir, "images")
        views = load_images(image_folder_path)
        new_processed_data = process_predictions_for_visualization(
            predictions, views, high_level_config, filter_black_bg, filter_white_bg
        )
        try:
            depth_view_idx = int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
        except Exception:
            depth_view_idx = 0
        try:
            normal_view_idx = int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
        except Exception:
            normal_view_idx = 0
        try:
            measure_view_idx = int(measure_view_selector.split()[1]) - 1 if measure_view_selector else 0
        except Exception:
            measure_view_idx = 0
        depth_vis = update_depth_view(new_processed_data, depth_view_idx)
        normal_vis = update_normal_view(new_processed_data, normal_view_idx)
        measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
        return new_processed_data, depth_vis, normal_vis, measure_img, []
    except Exception as e:
        print(f"Failed to update views: {e}")
        return processed_data, None, None, None, []
| # ============================================================================ | |
| # Example scenes | |
| # ============================================================================ | |
| def get_scene_info(examples_dir): | |
| """Get information about scenes in the examples directory""" | |
| import glob | |
| scenes = [] | |
| if not os.path.exists(examples_dir): | |
| return scenes | |
| for scene_folder in sorted(os.listdir(examples_dir)): | |
| scene_path = os.path.join(examples_dir, scene_folder) | |
| if os.path.isdir(scene_path): | |
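| # Glob both lower- and upper-case extensions so matching is effectively case-insensitive. | |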
| image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"] | |
| image_files = [] | |
| for ext in image_extensions: | |
| image_files.extend(glob.glob(os.path.join(scene_path, ext))) | |
| image_files.extend(glob.glob(os.path.join(scene_path, ext.upper()))) | |
| if image_files: | |
| image_files = sorted(image_files) | |
| first_image = image_files[0] | |
| num_images = len(image_files) | |
| scenes.append( | |
| { | |
| "name": scene_folder, | |
| "path": scene_path, | |
| "thumbnail": first_image, | |
| "num_images": num_images, | |
| "image_files": image_files, | |
| } | |
| ) | |
| return scenes | |
| def load_example_scene(scene_name, examples_dir="examples"): | |
| """Load a scene from examples directory""" | |
| scenes = get_scene_info(examples_dir) | |
| selected_scene = None | |
| for scene in scenes: | |
| if scene["name"] == scene_name: | |
| selected_scene = scene | |
| break | |
| if selected_scene is None: | |
| return None, None, None, None, "场景未找到" | |
| target_dir, image_paths = handle_uploads(None, selected_scene["image_files"]) | |
| return ( | |
| None, | |
| None, | |
| target_dir, | |
| image_paths, | |
| f"已加载场景 '{scene_name}' ({selected_scene['num_images']} 张图像)。点击「重建」开始 3D 处理。", | |
| ) | |
| # ============================================================================ | |
| # Gradio UI | |
| # ============================================================================ | |
| theme = get_gradio_theme() | |
| # Custom CSS to prevent UI jitter | |
| CUSTOM_CSS = GRADIO_CSS + """ | |
| /* Prevent components from stretching the layout */ | |
| .gradio-container { | |
| max-width: 100% !important; | |
| } | |
| /* Fix the Gallery height */ | |
| .gallery-container { | |
| max-height: 350px !important; | |
| overflow-y: auto !important; | |
| } | |
| /* Fix the File component height */ | |
| .file-preview { | |
| max-height: 200px !important; | |
| overflow-y: auto !important; | |
| } | |
| /* Fix the Video component height */ | |
| .video-container { | |
| max-height: 300px !important; | |
| } | |
| /* Keep the Textbox from growing without bound */ | |
| .textbox-container { | |
| max-height: 100px !important; | |
| } | |
| /* Keep the Tabs content area stable */ | |
| .tab-content { | |
| min-height: 550px !important; | |
| } | |
| """ | |
| with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V8 - 3D Reconstruction & Object Segmentation") as demo: | |
| is_example = gr.Textbox(label="is_example", visible=False, value="None") | |
| processed_data_state = gr.State(value=None) | |
| measure_points_state = gr.State(value=[]) | |
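| # Hidden helpers: is_example flags example loads so visualization callbacks can skip them; | |
| # the States carry per-session processed views and the clicked measurement points. | |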
| # Page header | |
| gr.HTML(""" | |
| <div style="text-align: center; margin: 20px 0;"> | |
| <h2 style="color: #1976D2; margin-bottom: 10px;">MapAnything V8 - 3D重建与物体分割</h2> | |
| <p style="color: #666; font-size: 16px;">基于DBSCAN聚类的智能物体识别 | 多视图融合 | 自适应参数调整</p> | |
| </div> | |
| """) | |
| target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") | |
| with gr.Row(equal_height=False): | |
| # Left: input area | |
| with gr.Column(scale=1, min_width=300): | |
| gr.Markdown("### 📤 Input") | |
| with gr.Tabs(): | |
| with gr.Tab("📷 图片"): | |
| input_images = gr.File( | |
| file_count="multiple", | |
| label="上传多张图片(推荐3-10张)", | |
| interactive=True, | |
| height=200 | |
| ) | |
| with gr.Tab("🎥 视频"): | |
| input_video = gr.Video( | |
| label="上传视频", | |
| interactive=True, | |
| height=300 | |
| ) | |
| s_time_interval = gr.Slider( | |
| minimum=0.1, maximum=5.0, value=1.0, step=0.1, | |
| label="帧采样间隔(秒)", interactive=True | |
| ) | |
| image_gallery = gr.Gallery( | |
| label="图片预览", columns=3, height=350, | |
| show_download_button=True, object_fit="contain", preview=True | |
| ) | |
| with gr.Row(): | |
| submit_btn = gr.Button("🚀 Reconstruct", variant="primary", scale=2) | |
| clear_btn = gr.ClearButton( | |
| [input_video, input_images, target_dir_output, image_gallery], | |
| value="🗑️ 清空", scale=1 | |
| ) | |
| # Right: output area | |
| with gr.Column(scale=2, min_width=600): | |
| gr.Markdown("### 🎯 Output") | |
| with gr.Tabs(): | |
| with gr.Tab("🏗️ 原始3D"): | |
| reconstruction_output = gr.Model3D( | |
| height=550, zoom_speed=0.5, pan_speed=0.5, | |
| clear_color=[0.0, 0.0, 0.0, 0.0] | |
| ) | |
| with gr.Tab("🎨 分割3D"): | |
| segmented_output = gr.Model3D( | |
| height=550, zoom_speed=0.5, pan_speed=0.5, | |
| clear_color=[0.0, 0.0, 0.0, 0.0] | |
| ) | |
| with gr.Tab("📊 深度图"): | |
| with gr.Row(elem_classes=["navigation-row"]): | |
| prev_depth_btn = gr.Button("◀", size="sm", scale=1) | |
| depth_view_selector = gr.Dropdown( | |
| choices=["View 1"], value="View 1", | |
| label="视图", scale=3, interactive=True | |
| ) | |
| next_depth_btn = gr.Button("▶", size="sm", scale=1) | |
| depth_map = gr.Image( | |
| type="numpy", label="", format="png", interactive=False, | |
| height=500 | |
| ) | |
| with gr.Tab("🧭 法线图"): | |
| with gr.Row(elem_classes=["navigation-row"]): | |
| prev_normal_btn = gr.Button("◀", size="sm", scale=1) | |
| normal_view_selector = gr.Dropdown( | |
| choices=["View 1"], value="View 1", | |
| label="视图", scale=3, interactive=True | |
| ) | |
| next_normal_btn = gr.Button("▶", size="sm", scale=1) | |
| normal_map = gr.Image( | |
| type="numpy", label="", format="png", interactive=False, | |
| height=500 | |
| ) | |
| with gr.Tab("📏 测量"): | |
| gr.Markdown("**点击图片两次进行距离测量**") | |
| with gr.Row(elem_classes=["navigation-row"]): | |
| prev_measure_btn = gr.Button("◀", size="sm", scale=1) | |
| measure_view_selector = gr.Dropdown( | |
| choices=["View 1"], value="View 1", | |
| label="视图", scale=3, interactive=True | |
| ) | |
| next_measure_btn = gr.Button("▶", size="sm", scale=1) | |
| measure_image = gr.Image( | |
| type="numpy", show_label=False, | |
| format="webp", interactive=False, sources=[], | |
| height=500 | |
| ) | |
| measure_text = gr.Markdown("") | |
| log_output = gr.Textbox( | |
| value="📌 请上传图片或视频,然后点击「开始重建」", | |
| label="状态信息", | |
| interactive=False, | |
| lines=1, | |
| max_lines=1 | |
| ) | |
| # Advanced options (collapsible) | |
| with gr.Accordion("⚙️ Advanced Options", open=False): | |
| with gr.Row(equal_height=False): | |
| with gr.Column(scale=1, min_width=300): | |
| gr.Markdown("#### 可视化参数") | |
| frame_filter = gr.Dropdown( | |
| choices=["All"], value="All", label="显示帧" | |
| ) | |
| conf_thres = gr.Slider( | |
| minimum=0, maximum=100, value=0, step=0.1, | |
| label="置信度阈值(百分位)" | |
| ) | |
| show_cam = gr.Checkbox(label="Show cameras", value=True) | |
| show_mesh = gr.Checkbox(label="Show mesh", value=True) | |
| filter_black_bg = gr.Checkbox(label="Filter black background", value=False) | |
| filter_white_bg = gr.Checkbox(label="Filter white background", value=False) | |
| with gr.Column(scale=1, min_width=300): | |
| gr.Markdown("#### 重建参数") | |
| apply_mask_checkbox = gr.Checkbox( | |
| label="应用深度掩码", value=True | |
| ) | |
| gr.Markdown("#### 分割参数") | |
| enable_segmentation = gr.Checkbox( | |
| label="启用语义分割", value=False | |
| ) | |
| use_sam_checkbox = gr.Checkbox( | |
| label="使用SAM精确分割", value=True | |
| ) | |
| text_prompt = gr.Textbox( | |
| value=DEFAULT_TEXT_PROMPT, | |
| label="检测物体(用 . 分隔)", | |
| placeholder="例如: chair . table . sofa", | |
| lines=2, | |
| max_lines=2 | |
| ) | |
| with gr.Row(): | |
| detect_all_btn = gr.Button("🔍 Detect All", size="sm") | |
| restore_default_btn = gr.Button("↻ Default", size="sm") | |
| # Example scenes (collapsible) | |
| with gr.Accordion("🖼️ Example Scenes", open=False): | |
| scenes = get_scene_info("examples") | |
| if scenes: | |
| for i in range(0, len(scenes), 4): | |
| with gr.Row(equal_height=True): | |
| for j in range(4): | |
| scene_idx = i + j | |
| if scene_idx < len(scenes): | |
| scene = scenes[scene_idx] | |
| with gr.Column(scale=1, min_width=150): | |
| scene_img = gr.Image( | |
| value=scene["thumbnail"], | |
| height=150, | |
| interactive=False, | |
| show_label=False, | |
| sources=[], | |
| container=False | |
| ) | |
| gr.Markdown( | |
| f"**{scene['name']}** ({scene['num_images']}张)", | |
| elem_classes=["text-center"] | |
| ) | |
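| # Bind the scene name as a default argument so each lambda captures its own scene rather than the loop variable. | |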
| scene_img.select( | |
| fn=lambda name=scene["name"]: load_example_scene(name), | |
| outputs=[ | |
| reconstruction_output, segmented_output, | |
| target_dir_output, image_gallery, log_output | |
| ] | |
| ) | |
| # === Event bindings === | |
| # Segmentation prompt buttons | |
| detect_all_btn.click( | |
| fn=lambda: COMMON_OBJECTS_PROMPT, | |
| outputs=[text_prompt] | |
| ) | |
| restore_default_btn.click( | |
| fn=lambda: DEFAULT_TEXT_PROMPT, | |
| outputs=[text_prompt] | |
| ) | |
| # Auto-update the gallery when files are uploaded | |
| input_video.change( | |
| fn=update_gallery_on_upload, | |
| inputs=[input_video, input_images, s_time_interval], | |
| outputs=[reconstruction_output, segmented_output, target_dir_output, image_gallery, log_output] | |
| ) | |
| input_images.change( | |
| fn=update_gallery_on_upload, | |
| inputs=[input_video, input_images, s_time_interval], | |
| outputs=[reconstruction_output, segmented_output, target_dir_output, image_gallery, log_output] | |
| ) | |
| # Reconstruct button | |
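| # The chained .then() calls run sequentially: clear the viewers, show a status message, | |
| # run the reconstruction, then mark the session as a non-example run. | |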
| submit_btn.click( | |
| fn=clear_fields, | |
| outputs=[reconstruction_output, segmented_output] | |
| ).then( | |
| fn=update_log, | |
| outputs=[log_output] | |
| ).then( | |
| fn=gradio_demo, | |
| inputs=[ | |
| target_dir_output, frame_filter, show_cam, | |
| filter_black_bg, filter_white_bg, conf_thres, | |
| apply_mask_checkbox, show_mesh, | |
| enable_segmentation, text_prompt, use_sam_checkbox | |
| ], | |
| outputs=[ | |
| reconstruction_output, segmented_output, log_output, frame_filter, | |
| processed_data_state, depth_map, normal_map, measure_image, | |
| measure_text, depth_view_selector, normal_view_selector, measure_view_selector | |
| ] | |
| ).then( | |
| fn=lambda: "False", | |
| outputs=[is_example] | |
| ) | |
| # Clear button | |
| clear_btn.add([reconstruction_output, segmented_output, log_output]) | |
| # Re-render the scene whenever a visualization parameter changes | |
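| # Binding in a loop attaches the same handler to each control; Gradio reads the current | |
| # values of all listed inputs at call time, so one callback serves every trigger. | |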
| for component in [frame_filter, show_cam, conf_thres, show_mesh]: | |
| component.change( | |
| fn=update_visualization, | |
| inputs=[ | |
| target_dir_output, frame_filter, show_cam, is_example, | |
| conf_thres, filter_black_bg, filter_white_bg, show_mesh | |
| ], | |
| outputs=[reconstruction_output, log_output] | |
| ) | |
| # Background filters refresh every view | |
| for bg_filter in [filter_black_bg, filter_white_bg]: | |
| bg_filter.change( | |
| fn=update_all_views_on_filter_change, | |
| inputs=[ | |
| target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, | |
| depth_view_selector, normal_view_selector, measure_view_selector | |
| ], | |
| outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state] | |
| ) | |
| # Depth map navigation | |
| prev_depth_btn.click( | |
| fn=lambda pd, cs: navigate_depth_view(pd, cs, -1), | |
| inputs=[processed_data_state, depth_view_selector], | |
| outputs=[depth_view_selector, depth_map] | |
| ) | |
| next_depth_btn.click( | |
| fn=lambda pd, cs: navigate_depth_view(pd, cs, 1), | |
| inputs=[processed_data_state, depth_view_selector], | |
| outputs=[depth_view_selector, depth_map] | |
| ) | |
| depth_view_selector.change( | |
| fn=lambda pd, sv: update_depth_view(pd, int(sv.split()[1]) - 1) if sv else None, | |
| inputs=[processed_data_state, depth_view_selector], | |
| outputs=[depth_map] | |
| ) | |
| # Normal map navigation | |
| prev_normal_btn.click( | |
| fn=lambda pd, cs: navigate_normal_view(pd, cs, -1), | |
| inputs=[processed_data_state, normal_view_selector], | |
| outputs=[normal_view_selector, normal_map] | |
| ) | |
| next_normal_btn.click( | |
| fn=lambda pd, cs: navigate_normal_view(pd, cs, 1), | |
| inputs=[processed_data_state, normal_view_selector], | |
| outputs=[normal_view_selector, normal_map] | |
| ) | |
| normal_view_selector.change( | |
| fn=lambda pd, sv: update_normal_view(pd, int(sv.split()[1]) - 1) if sv else None, | |
| inputs=[processed_data_state, normal_view_selector], | |
| outputs=[normal_map] | |
| ) | |
| # Measurement | |
| measure_image.select( | |
| fn=measure, | |
| inputs=[processed_data_state, measure_points_state, measure_view_selector], | |
| outputs=[measure_image, measure_points_state, measure_text] | |
| ) | |
| prev_measure_btn.click( | |
| fn=lambda pd, cs: navigate_measure_view(pd, cs, -1), | |
| inputs=[processed_data_state, measure_view_selector], | |
| outputs=[measure_view_selector, measure_image, measure_points_state] | |
| ) | |
| next_measure_btn.click( | |
| fn=lambda pd, cs: navigate_measure_view(pd, cs, 1), | |
| inputs=[processed_data_state, measure_view_selector], | |
| outputs=[measure_view_selector, measure_image, measure_points_state] | |
| ) | |
| measure_view_selector.change( | |
| fn=lambda pd, sv: update_measure_view(pd, int(sv.split()[1]) - 1) if sv else (None, []), | |
| inputs=[processed_data_state, measure_view_selector], | |
| outputs=[measure_image, measure_points_state] | |
| ) | |
| # Startup info | |
| print("\n" + "="*60) | |
| print("🚀 MapAnything V8 - 3D Reconstruction & Object Segmentation") | |
| print("="*60) | |
| print("📊 Core techniques: adaptive DBSCAN clustering + multi-view fusion") | |
| print(f"🔧 Quality control: confidence ≥ {MIN_DETECTION_CONFIDENCE} | area ≥ {MIN_MASK_AREA}px") | |
| print(f"🎯 Clustering radii: sofa {DBSCAN_EPS_CONFIG['sofa']}m | table {DBSCAN_EPS_CONFIG['table']}m | window {DBSCAN_EPS_CONFIG['window']}m | default {DBSCAN_EPS_CONFIG['default']}m") | |
| print("="*60 + "\n") | |
| demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False) | |