Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Upload 169 files
Browse files. This view is limited to 50 files because it contains too many changes.  
							See raw diff
- .gitignore +57 -0
- app.py +19 -0
- mapanything/__init__.py +0 -0
- mapanything/__pycache__/__init__.cpython-312.pyc +0 -0
- mapanything/datasets/__init__.py +177 -0
- mapanything/datasets/base/__init__.py +0 -0
- mapanything/datasets/base/base_dataset.py +697 -0
- mapanything/datasets/base/batched_sampler.py +431 -0
- mapanything/datasets/base/easy_dataset.py +478 -0
- mapanything/datasets/utils/__init__.py +0 -0
- mapanything/datasets/utils/data_splits.py +1734 -0
- mapanything/datasets/wai/__init__.py +0 -0
- mapanything/datasets/wai/ase.py +294 -0
- mapanything/datasets/wai/blendedmvs.py +313 -0
- mapanything/datasets/wai/dl3dv.py +356 -0
- mapanything/datasets/wai/dynamicreplica.py +297 -0
- mapanything/datasets/wai/eth3d.py +277 -0
- mapanything/datasets/wai/megadepth.py +314 -0
- mapanything/datasets/wai/mpsd.py +311 -0
- mapanything/datasets/wai/mvs_synth.py +308 -0
- mapanything/datasets/wai/paralleldomain4d.py +309 -0
- mapanything/datasets/wai/sailvos3d.py +308 -0
- mapanything/datasets/wai/scannetpp.py +307 -0
- mapanything/datasets/wai/spring.py +316 -0
- mapanything/datasets/wai/tav2_wb.py +328 -0
- mapanything/datasets/wai/unrealstereo4k.py +309 -0
- mapanything/models/__init__.py +190 -0
- mapanything/models/__pycache__/__init__.cpython-312.pyc +0 -0
- mapanything/models/external/README.md +5 -0
- mapanything/models/external/__init__.py +0 -0
- mapanything/models/external/anycalib/__init__.py +100 -0
- mapanything/models/external/dinov2/__init__.py +6 -0
- mapanything/models/external/dinov2/hub/__init__.py +4 -0
- mapanything/models/external/dinov2/hub/backbones.py +183 -0
- mapanything/models/external/dinov2/hub/utils.py +42 -0
- mapanything/models/external/dinov2/layers/__init__.py +14 -0
- mapanything/models/external/dinov2/layers/attention.py +90 -0
- mapanything/models/external/dinov2/layers/block.py +290 -0
- mapanything/models/external/dinov2/layers/dino_head.py +67 -0
- mapanything/models/external/dinov2/layers/drop_path.py +36 -0
- mapanything/models/external/dinov2/layers/layer_scale.py +26 -0
- mapanything/models/external/dinov2/layers/mlp.py +40 -0
- mapanything/models/external/dinov2/layers/patch_embed.py +100 -0
- mapanything/models/external/dinov2/layers/swiglu_ffn.py +71 -0
- mapanything/models/external/dinov2/models/__init__.py +44 -0
- mapanything/models/external/dinov2/models/vision_transformer.py +448 -0
- mapanything/models/external/dinov2/utils/__init__.py +4 -0
- mapanything/models/external/dinov2/utils/cluster.py +102 -0
- mapanything/models/external/dinov2/utils/config.py +74 -0
- mapanything/models/external/dinov2/utils/dtype.py +38 -0
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1,57 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Python
         | 
| 2 | 
            +
            __pycache__/
         | 
| 3 | 
            +
            *.py[cod]
         | 
| 4 | 
            +
            *$py.class
         | 
| 5 | 
            +
            *.so
         | 
| 6 | 
            +
            .Python
         | 
| 7 | 
            +
            build/
         | 
| 8 | 
            +
            develop-eggs/
         | 
| 9 | 
            +
            dist/
         | 
| 10 | 
            +
            downloads/
         | 
| 11 | 
            +
            eggs/
         | 
| 12 | 
            +
            .eggs/
         | 
| 13 | 
            +
            lib/
         | 
| 14 | 
            +
            lib64/
         | 
| 15 | 
            +
            parts/
         | 
| 16 | 
            +
            sdist/
         | 
| 17 | 
            +
            var/
         | 
| 18 | 
            +
            wheels/
         | 
| 19 | 
            +
            *.egg-info/
         | 
| 20 | 
            +
            .installed.cfg
         | 
| 21 | 
            +
            *.egg
         | 
| 22 | 
            +
            MANIFEST
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            # Virtual Environment
         | 
| 25 | 
            +
            venv/
         | 
| 26 | 
            +
            ENV/
         | 
| 27 | 
            +
            env/
         | 
| 28 | 
            +
            .venv
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # IDE
         | 
| 31 | 
            +
            .vscode/
         | 
| 32 | 
            +
            .idea/
         | 
| 33 | 
            +
            *.swp
         | 
| 34 | 
            +
            *.swo
         | 
| 35 | 
            +
            *~
         | 
| 36 | 
            +
            .DS_Store
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            # HuggingFace Space temporary files
         | 
| 39 | 
            +
            input_images_*/
         | 
| 40 | 
            +
            *.glb
         | 
| 41 | 
            +
            *.npz
         | 
| 42 | 
            +
            flagged/
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            # Local model cache (switched to HuggingFace hosting)
         | 
| 45 | 
            +
            models/
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            # Logs
         | 
| 48 | 
            +
            *.log
         | 
| 49 | 
            +
            logs/
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            # Test files
         | 
| 52 | 
            +
            .pytest_cache/
         | 
| 53 | 
            +
            .coverage
         | 
| 54 | 
            +
            htmlcov/
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            # System files
         | 
| 57 | 
            +
            Thumbs.db
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,19 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
HuggingFace Space entry point.

Imports and runs gradio_app_v8 directly.
"""

import sys
from pathlib import Path

# Add the scripts directory to the Python path so `gradio_app_v8` can be
# imported as a top-level module.
scripts_dir = Path(__file__).parent / "scripts"
sys.path.insert(0, str(scripts_dir))

# Import and run the main application
if __name__ == "__main__":
    # Importing gradio_app_v8 launches the demo as an import side effect
    import gradio_app_v8
    	
        mapanything/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        mapanything/__pycache__/__init__.cpython-312.pyc
    ADDED
    
    | Binary file (154 Bytes). View file | 
|  | 
    	
        mapanything/datasets/__init__.py
    ADDED
    
    | @@ -0,0 +1,177 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            MapAnything Datasets
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from mapanything.datasets.wai.ase import ASEWAI  # noqa
         | 
| 13 | 
            +
            from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI  # noqa
         | 
| 14 | 
            +
            from mapanything.datasets.wai.dl3dv import DL3DVWAI  # noqa
         | 
| 15 | 
            +
            from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI  # noqa
         | 
| 16 | 
            +
            from mapanything.datasets.wai.eth3d import ETH3DWAI  # noqa
         | 
| 17 | 
            +
            from mapanything.datasets.wai.megadepth import MegaDepthWAI  # noqa
         | 
| 18 | 
            +
            from mapanything.datasets.wai.mpsd import MPSDWAI  # noqa
         | 
| 19 | 
            +
            from mapanything.datasets.wai.mvs_synth import MVSSynthWAI  # noqa
         | 
| 20 | 
            +
            from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI  # noqa
         | 
| 21 | 
            +
            from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI  # noqa
         | 
| 22 | 
            +
            from mapanything.datasets.wai.scannetpp import ScanNetPPWAI  # noqa
         | 
| 23 | 
            +
            from mapanything.datasets.wai.spring import SpringWAI  # noqa
         | 
| 24 | 
            +
            from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI  # noqa
         | 
| 25 | 
            +
            from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI  # noqa
         | 
| 26 | 
            +
            from mapanything.utils.train_tools import get_rank, get_world_size
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
def get_test_data_loader(
    dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
):
    """
    Build a plain PyTorch DataLoader for a testing dataset.

    Args:
        dataset: Dataset instance, or a string that evaluates to one.
        batch_size (int): Number of samples per batch.
        num_workers (int): Worker processes for data loading. Default: 8.
        shuffle (bool): Randomize sample order. Default: False.
        drop_last (bool): Drop the final incomplete batch. Default: False.
        pin_mem (bool): Use pinned host memory. Default: True.

    Returns:
        torch.utils.data.DataLoader over the dataset.
    """
    # NOTE(review): eval() on a string is a code-injection risk if the config
    # value is ever attacker-controlled — confirm these strings are trusted.
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Pick the sampler: shard across ranks under DDP, otherwise plain
    # random/sequential ordering depending on `shuffle`.
    if torch.distributed.is_initialized():
        test_sampler = torch.utils.data.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
            drop_last=drop_last,
        )
    elif shuffle:
        test_sampler = torch.utils.data.RandomSampler(dataset)
    else:
        test_sampler = torch.utils.data.SequentialSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset,
        sampler=test_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
| 63 | 
            +
             | 
| 64 | 
            +
             | 
def get_test_many_ar_data_loader(
    dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
):
    """
    Build a testing DataLoader that supports many aspect ratios.

    Args:
        dataset: Dataset instance, or a string that evaluates to one. Must
            provide a ``make_sampler`` method.
        batch_size (int): Number of samples per batch.
        num_workers (int): Worker processes for data loading. Default: 8.
        drop_last (bool): Drop the final incomplete batch. Default: False.
        pin_mem (bool): Use pinned host memory. Default: True.

    Returns:
        torch.utils.data.DataLoader driven by the dataset's own sampler.
    """
    # NOTE(review): eval() assumes trusted config strings — verify callers.
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Ask the dataset for its BatchedMultiFeatureRandomSampler (static variant)
    ar_sampler = dataset.make_sampler(
        batch_size,
        shuffle=True,
        world_size=world_size,
        rank=rank,
        drop_last=drop_last,
        use_dynamic_sampler=False,
    )

    # Build the data loader around that sampler
    return torch.utils.data.DataLoader(
        dataset,
        sampler=ar_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
| 97 | 
            +
             | 
| 98 | 
            +
             | 
class DynamicBatchDatasetWrapper:
    """
    Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.

    The dynamic sampler returns batches (lists of index tuples) instead of
    individual samples. This wrapper fans a batch out to the underlying
    dataset's __getitem__ one tuple at a time, so the wrapped dataset only
    ever sees the individual indices it expects.
    """

    def __init__(self, dataset):
        # Wrapped dataset; unknown attribute lookups are delegated to it.
        self.dataset = dataset

    def __getitem__(self, batch_indices):
        """
        Handle a batch of indices from DynamicBatchedMultiFeatureRandomSampler.

        Args:
            batch_indices: List of tuples like
                [(sample_idx, feat_idx_1, feat_idx_2, ...), ...], a single
                tuple, or a plain index.

        Returns:
            List of samples (for a batch of tuples) or a single sample from
            the underlying dataset.
        """
        if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0:
            # If it's a batch (list of tuples), process each item
            if isinstance(batch_indices[0], (list, tuple)):
                return [self.dataset[idx] for idx in batch_indices]
            else:
                # Single tuple, call dataset directly
                return self.dataset[batch_indices]
        else:
            # Fallback for single index (including empty sequences)
            return self.dataset[batch_indices]

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        # Delegate all other attributes to the wrapped dataset.
        # Guard: if 'dataset' itself is missing (instance created via
        # __new__, unpickling, or copy before __init__ runs), looking up
        # self.dataset would re-enter __getattr__ and recurse forever.
        # Raise a proper AttributeError instead.
        if name == "dataset":
            raise AttributeError(name)
        return getattr(self.dataset, name)
| 138 | 
            +
             | 
| 139 | 
            +
             | 
def get_train_data_loader(
    dataset,
    max_num_of_imgs_per_gpu,
    num_workers=8,
    shuffle=True,
    drop_last=True,
    pin_mem=True,
):
    """
    Build a dynamically batched PyTorch DataLoader for a training dataset.

    Args:
        dataset: Dataset instance, or a string that evaluates to one. Must
            provide a ``make_sampler`` method.
        max_num_of_imgs_per_gpu (int): Per-GPU image budget used by the
            dynamic sampler to size each batch.
        num_workers (int): Worker processes for data loading. Default: 8.
        shuffle (bool): Randomize sample order. Default: True.
        drop_last (bool): Drop the final incomplete batch. Default: True.
        pin_mem (bool): Use pinned host memory. Default: True.

    Returns:
        torch.utils.data.DataLoader driven by the dataset's dynamic batch sampler.
    """
    # NOTE(review): eval() assumes trusted config strings — verify callers.
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # DynamicBatchedMultiFeatureRandomSampler: yields whole batches sized by
    # the per-GPU image budget rather than a fixed batch_size.
    dynamic_batch_sampler = dataset.make_sampler(
        shuffle=shuffle,
        world_size=world_size,
        rank=rank,
        drop_last=drop_last,
        max_num_of_images_per_gpu=max_num_of_imgs_per_gpu,
        use_dynamic_sampler=True,
    )

    # The sampler emits lists of index tuples, so wrap the dataset to fan
    # each batch out to individual __getitem__ calls.
    wrapped_dataset = DynamicBatchDatasetWrapper(dataset)

    # Build the dynamic data loader
    return torch.utils.data.DataLoader(
        wrapped_dataset,
        batch_sampler=dynamic_batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_mem,
    )
    	
        mapanything/datasets/base/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        mapanything/datasets/base/base_dataset.py
    ADDED
    
    | @@ -0,0 +1,697 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Base class for MapAnything datasets.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from typing import List, Tuple, Union
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import PIL
         | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            import torchvision.transforms as tvf
         | 
| 16 | 
            +
            from scipy.spatial.transform import Rotation
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from mapanything.datasets.base.easy_dataset import EasyDataset
         | 
| 19 | 
            +
            from mapanything.utils.cropping import (
         | 
| 20 | 
            +
                bbox_from_intrinsics_in_out,
         | 
| 21 | 
            +
                camera_matrix_of_crop,
         | 
| 22 | 
            +
                crop_image_and_other_optional_info,
         | 
| 23 | 
            +
                rescale_image_and_other_optional_info,
         | 
| 24 | 
            +
            )
         | 
| 25 | 
            +
            from mapanything.utils.geometry import (
         | 
| 26 | 
            +
                depthmap_to_camera_coordinates,
         | 
| 27 | 
            +
                get_absolute_pointmaps_and_rays_info,
         | 
| 28 | 
            +
            )
         | 
| 29 | 
            +
            from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            class BaseDataset(EasyDataset):
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
                Define all basic options.
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                Usage:
         | 
| 37 | 
            +
                    class MyDataset(BaseDataset):
         | 
| 38 | 
            +
                        def _get_views(self, idx):
         | 
| 39 | 
            +
                            views = []
         | 
| 40 | 
            +
                            views.append(dict(img=, ...))
         | 
| 41 | 
            +
                            return views
         | 
| 42 | 
            +
                """
         | 
| 43 | 
            +
             | 
| 44 | 
            +
    def __init__(
        self,
        num_views: int,
        variable_num_views: bool = False,
        split: str = None,
        covisibility_thres: float = None,
        resolution: Union[int, Tuple[int, int], List[Tuple[int, int]]] = None,
        principal_point_centered: bool = False,
        transform: str = None,
        data_norm_type: str = None,
        aug_crop: int = 0,
        seed: int = None,
        max_num_retries: int = 5,
    ):
        """
        PyTorch dataset for multi-view images sampled from scenes, where the images form a single connected component.

        Args:
            num_views (int): Number of views.
            variable_num_views (bool): If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2.
                                       On by default for N-view train dataloader (hydra config).
            split (str): 'train', 'val', 'test', etc.
            covisibility_thres (float): Covisibility (%) threshold to determine if another image is a neighbor or not
            resolution (int or tuple or list of tuples): Resolution of the images
            principal_point_centered (bool): If True, the principal point is centered in the image.
            transform (str): Transform to apply to the images. Options:
            - 'colorjitter+grayscale+gaublur':
                tvf.Compose([
                    tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75),
                    tvf.RandomGrayscale(p=0.05),
                    tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
                ]) after ImgNorm
            - 'colorjitter': tvf.ColorJitter(0.5, 0.5, 0.5, 0.1) after ImgNorm
            - 'imgnorm': ImgNorm only
            data_norm_type (str): Image normalization type.
                                  For options, see UniCeption image normalization dict.
            aug_crop (int): Augment crop. If int greater than 0, indicates the number of pixels to increase in target resolution.
            seed (int): Seed for the random number generator.
            max_num_retries (int): Maximum number of retries for loading a different sample from the dataset, if provided idx fails.

        Raises:
            ValueError: If data_norm_type or transform is not one of the supported options.
        """
        self.num_views = num_views
        self.variable_num_views = variable_num_views
        # Lower bound used when variable_num_views expands num_views into a range.
        self.num_views_min = 2
        self.split = split
        self.covisibility_thres = covisibility_thres
        # Validates and normalizes `resolution` into self._resolutions (list of (W, H) tuples).
        self._set_resolutions(resolution)
        self.principal_point_centered = principal_point_centered

        # Update the number of views if necessary and make it a list if variable_num_views is True
        # (num_views becomes e.g. [2, 3, ..., N] so a sampler can pick a per-batch view count).
        if self.variable_num_views and self.num_views > self.num_views_min:
            self.num_views = list(range(self.num_views_min, self.num_views + 1))

        # Initialize the image normalization type
        if data_norm_type in IMAGE_NORMALIZATION_DICT.keys():
            self.data_norm_type = data_norm_type
            image_norm = IMAGE_NORMALIZATION_DICT[data_norm_type]
            ImgNorm = tvf.Compose(
                [
                    tvf.ToTensor(),
                    tvf.Normalize(mean=image_norm.mean, std=image_norm.std),
                ]
            )
        elif data_norm_type == "identity":
            # "identity" skips mean/std normalization; only converts to tensor.
            self.data_norm_type = data_norm_type
            ImgNorm = tvf.Compose([tvf.ToTensor()])
        else:
            raise ValueError(
                f"Unknown data_norm_type: {data_norm_type}. Available options: identity or {list(IMAGE_NORMALIZATION_DICT.keys())}"
            )

        # Initialize torchvision transforms
        # NOTE: despite the docstrings above saying "after ImgNorm", the Compose order below
        # applies the photometric augmentations BEFORE ImgNorm (ImgNorm is the last element).
        if transform == "imgnorm":
            self.transform = ImgNorm
        elif transform == "colorjitter":
            self.transform = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
        elif transform == "colorjitter+grayscale+gaublur":
            self.transform = tvf.Compose(
                [
                    tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75),
                    tvf.RandomGrayscale(p=0.05),
                    tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
                    ImgNorm,
                ]
            )
        else:
            raise ValueError(
                'Unknown transform. Available options: "imgnorm", "colorjitter", "colorjitter+grayscale+gaublur"'
            )

        # Initialize the augmentation parameters
        self.aug_crop = aug_crop

        # Initialize the seed for the random number generator
        # NOTE(review): self._rng is not created here — presumably seeded lazily elsewhere
        # (e.g. per-worker) from self.seed + self._seed_offset; confirm in EasyDataset/loader code.
        self.seed = seed
        self._seed_offset = 0

        # Initialize the maximum number of retries for loading a different sample from the dataset, if the first idx fails
        self.max_num_retries = max_num_retries

        # Initialize the dataset type flags
        self.is_metric_scale = False  # by default a dataset is not metric scale, subclasses can overwrite this
        self.is_synthetic = False  # by default a dataset is not synthetic, subclasses can overwrite this
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                def _load_data(self):
         | 
| 148 | 
            +
                    self.scenes = []
         | 
| 149 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def __len__(self):
         | 
| 152 | 
            +
                    "Length of the dataset is determined by the number of scenes in the dataset split"
         | 
| 153 | 
            +
                    return self.num_of_scenes
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                def get_stats(self):
         | 
| 156 | 
            +
                    "Get the number of scenes in the dataset split"
         | 
| 157 | 
            +
                    return f"{self.num_of_scenes} scenes"
         | 
| 158 | 
            +
             | 
| 159 | 
            +
    def __repr__(self):
        # Single-line summary of the dataset configuration, e.g.
        # "MyDataset(12 scenes, num_views=4 split='train', ...)".
        resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
        # The triple-quoted f-string below intentionally contains newlines and
        # indentation; the chained replaces strip the "self." prefixes, the
        # newlines, and the leading indent (runs of 3 spaces) to flatten it.
        # NOTE(review): the {self.num_views=} entry lacks a trailing comma, so the
        # output reads "num_views=4 split=..." — presumably cosmetic; confirm intended.
        return (
            f"""{type(self).__name__}({self.get_stats()},
            {self.num_views=}
            {self.split=},
            {self.seed=},
            resolutions={resolutions_str},
            {self.transform=})""".replace("self.", "")
            .replace("\n", "")
            .replace("   ", "")
        )
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                def _get_views(self, idx, num_views_to_sample, resolution):
         | 
| 173 | 
            +
                    raise NotImplementedError()
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def _set_seed_offset(self, idx):
         | 
| 176 | 
            +
                    """
         | 
| 177 | 
            +
                    Set the seed offset. This is directly added to self.seed when setting the random seed.
         | 
| 178 | 
            +
                    """
         | 
| 179 | 
            +
                    self._seed_offset = idx
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                def _set_resolutions(self, resolutions):
         | 
| 182 | 
            +
                    assert resolutions is not None, "undefined resolution"
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    if isinstance(resolutions, int):
         | 
| 185 | 
            +
                        resolutions = [resolutions]
         | 
| 186 | 
            +
                    elif isinstance(resolutions, tuple):
         | 
| 187 | 
            +
                        resolutions = [resolutions]
         | 
| 188 | 
            +
                    elif isinstance(resolutions, list):
         | 
| 189 | 
            +
                        assert all(isinstance(res, tuple) for res in resolutions), (
         | 
| 190 | 
            +
                            f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints"
         | 
| 191 | 
            +
                        )
         | 
| 192 | 
            +
                    else:
         | 
| 193 | 
            +
                        raise ValueError(
         | 
| 194 | 
            +
                            f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints"
         | 
| 195 | 
            +
                        )
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    self._resolutions = []
         | 
| 198 | 
            +
                    for resolution in resolutions:
         | 
| 199 | 
            +
                        if isinstance(resolution, int):
         | 
| 200 | 
            +
                            width = height = resolution
         | 
| 201 | 
            +
                        else:
         | 
| 202 | 
            +
                            width, height = resolution
         | 
| 203 | 
            +
                        assert isinstance(width, int), (
         | 
| 204 | 
            +
                            f"Bad type for {width=} {type(width)=}, should be int"
         | 
| 205 | 
            +
                        )
         | 
| 206 | 
            +
                        assert isinstance(height, int), (
         | 
| 207 | 
            +
                            f"Bad type for {height=} {type(height)=}, should be int"
         | 
| 208 | 
            +
                        )
         | 
| 209 | 
            +
                        self._resolutions.append((width, height))
         | 
| 210 | 
            +
             | 
| 211 | 
            +
    def _crop_resize_if_necessary(
        self,
        image,
        resolution,
        depthmap,
        intrinsics,
        additional_quantities=None,
    ):
        """
        Process an image by downsampling and cropping as needed to match the target resolution.

        This method performs the following operations:
        1. Converts the image to PIL.Image if necessary
        2. Crops the image centered on the principal point if requested
        3. Downsamples the image using high-quality Lanczos filtering
        4. Performs final cropping to match the target resolution

        Args:
            image (numpy.ndarray or PIL.Image.Image): Input image to be processed
            resolution (tuple): Target resolution as (width, height)
            depthmap (numpy.ndarray): Depth map corresponding to the image
            intrinsics (numpy.ndarray): Camera intrinsics matrix (3x3)
            additional_quantities (dict, optional): Additional image-related data to be processed
                                                   alongside the main image with nearest interpolation. Defaults to None.

        Returns:
            tuple: Processed image, depthmap, and updated intrinsics matrix.
                  If additional_quantities is provided, it returns those as well.
        """
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)

        # Cropping centered on the principal point if necessary
        if self.principal_point_centered:
            W, H = image.size
            # Principal point (cx, cy) read from the intrinsics matrix, rounded to pixels.
            cx, cy = intrinsics[:2, 2].round().astype(int)
            if cx < 0 or cx >= W or cy < 0 or cy >= H:
                # Skip centered cropping if principal point is outside image bounds
                pass
            else:
                # Largest symmetric margins around (cx, cy) that stay inside the image;
                # the resulting crop_bbox has the principal point exactly at its center.
                min_margin_x = min(cx, W - cx)
                min_margin_y = min(cy, H - cy)
                left, top = cx - min_margin_x, cy - min_margin_y
                right, bottom = cx + min_margin_x, cy + min_margin_y
                crop_bbox = (left, top, right, bottom)
                # Only perform the centered crop if the crop_bbox is larger than the target resolution
                crop_width = right - left
                crop_height = bottom - top
                if crop_width > resolution[0] and crop_height > resolution[1]:
                    image, depthmap, intrinsics, additional_quantities = (
                        crop_image_and_other_optional_info(
                            image=image,
                            crop_bbox=crop_bbox,
                            depthmap=depthmap,
                            camera_intrinsics=intrinsics,
                            additional_quantities=additional_quantities,
                        )
                    )

        # Get the target resolution for re-scaling
        target_rescale_resolution = np.array(resolution)
        if self.aug_crop > 1:
            # Random augmentation: rescale to a slightly larger size so the final
            # crop below introduces translation jitter. Uses self._rng
            # (NOTE(review): assumed to be a numpy Generator seeded elsewhere — confirm).
            target_rescale_resolution += self._rng.integers(0, self.aug_crop)

        # High-quality Lanczos down-scaling if necessary
        image, depthmap, intrinsics, additional_quantities = (
            rescale_image_and_other_optional_info(
                image=image,
                output_resolution=target_rescale_resolution,
                depthmap=depthmap,
                camera_intrinsics=intrinsics,
                additional_quantities_to_be_resized_with_nearest=additional_quantities,
            )
        )

        # Actual cropping (if necessary)
        # Compute the intrinsics the output should have, then derive the crop window
        # that maps the rescaled intrinsics onto them.
        new_intrinsics = camera_matrix_of_crop(
            input_camera_matrix=intrinsics,
            input_resolution=image.size,
            output_resolution=resolution,
            offset_factor=0.5,
        )
        crop_bbox = bbox_from_intrinsics_in_out(
            input_camera_matrix=intrinsics,
            output_camera_matrix=new_intrinsics,
            output_resolution=resolution,
        )
        image, depthmap, new_intrinsics, additional_quantities = (
            crop_image_and_other_optional_info(
                image=image,
                crop_bbox=crop_bbox,
                depthmap=depthmap,
                camera_intrinsics=intrinsics,
                additional_quantities=additional_quantities,
            )
        )

        # Return the output
        # (additional_quantities may have been replaced by the helpers above; it is
        # only returned when the final crop produced a non-None value.)
        if additional_quantities is not None:
            return image, depthmap, new_intrinsics, additional_quantities
        else:
            return image, depthmap, new_intrinsics
         | 
| 313 | 
            +
             | 
| 314 | 
            +
    def _random_walk_sampling(
        self,
        scene_pairwise_covisibility,
        num_of_samples,
        max_retries=4,
        use_bidirectional_covis=True,
    ):
        """
        Randomly samples S indices from an N x N covisibility matrix by forming adjacency edges such that the resulting subgraph (given by the indices) is connected.
        If the current node has no new unvisited neighbors, backtracking occurs.
        Retries with different starting indices if the desired number of samples is not reached, excluding previously visited components.

        Args:
            scene_pairwise_covisibility : np.ndarray (mmap)
                N x N covisibility matrix for the scene, where N is the number of views in the scene.
            num_of_samples : int
                The desired number of nodes to sample (num_of_samples < N).
            max_retries : int
                The maximum number of retries with different starting indices.
            use_bidirectional_covis : bool
                Whether to compute bidirectional covisibility by averaging row and column values.
                If False, uses only row access (faster for large memory-mapped arrays).
                Defaults to True.

        Returns:
            np.ndarray
                An array of sampled indices forming a connected subgraph.
                May be shorter than num_of_samples if all retries are exhausted.
        """
        excluded_nodes = set()
        best_walk = []  # To keep track of the best walk found
        for _ in range(max_retries):
            visited = set()
            walk = []  # List to store the random walk sampling order
            stack = []  # Stack for backtracking

            # Choose a random starting index that is not in the excluded set
            all_nodes = set(range(len(scene_pairwise_covisibility)))
            available_nodes = list(all_nodes - excluded_nodes)
            if not available_nodes:
                break  # No more nodes to try
            start = self._rng.choice(available_nodes)
            walk.append(start)
            visited.add(start)
            stack.append(start)

            # Continue until we have sampled S indices or all expandable nodes are exhausted
            while len(walk) < num_of_samples and stack:
                # Depth-first expansion: always try to grow from the top of the stack.
                current = stack[-1]
                # Get the pairwise covisibility for the current node
                if use_bidirectional_covis:
                    # Use bidirectional covisibility (slower for large memory-mapped arrays)
                    pairwise_covisibility = (
                        scene_pairwise_covisibility[current, :]
                        + scene_pairwise_covisibility[:, current].T
                    ) / 2
                else:
                    # Use only row access (faster for large memory-mapped arrays)
                    pairwise_covisibility = scene_pairwise_covisibility[current, :]
                # Normalize the covisibility using self covisibility
                # (epsilon guards against division by zero on a zero diagonal).
                pairwise_covisibility = pairwise_covisibility / (
                    pairwise_covisibility[current] + 1e-8
                )
                # Assign overlap score of zero to self-pairs
                pairwise_covisibility[current] = 0
                # Threshold the covisibility to get adjacency list for the current node
                adjacency_list_for_current = (
                    pairwise_covisibility > self.covisibility_thres
                ).astype(int)
                adjacency_list_for_current = np.flatnonzero(adjacency_list_for_current)
                # Get all unvisited neighbors
                candidates = [
                    idx for idx in adjacency_list_for_current if idx not in visited
                ]  # Remove visited nodes
                if candidates:
                    # Randomly select one of the unvisited overlapping neighbors
                    next_node = self._rng.choice(candidates)
                    walk.append(next_node)
                    visited.add(next_node)
                    stack.append(next_node)
                else:
                    # If no unvisited neighbor is available, backtrack
                    stack.pop()

            # Update the best walk if the current walk is larger
            if len(walk) > len(best_walk):
                best_walk = walk

            # If we have enough samples, return the result
            if len(walk) >= num_of_samples:
                return np.array(walk)

            # Add all visited nodes to the excluded set
            # so the next retry starts from an unexplored connected component.
            excluded_nodes.update(visited)

        # If all retries are exhausted and we still don't have enough samples, return the best walk found
        return np.array(best_walk)
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                def _sample_view_indices(
         | 
| 412 | 
            +
                    self,
         | 
| 413 | 
            +
                    num_views_to_sample,
         | 
| 414 | 
            +
                    num_views_in_scene,
         | 
| 415 | 
            +
                    scene_pairwise_covisibility,
         | 
| 416 | 
            +
                    use_bidirectional_covis=True,
         | 
| 417 | 
            +
                ):
         | 
| 418 | 
            +
                    """
         | 
| 419 | 
            +
                    Sample view indices from a scene based on the adjacency list and the number of views to sample.
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                    Args:
         | 
| 422 | 
            +
                        num_views_to_sample (int): Number of views to sample.
         | 
| 423 | 
            +
                        num_views_in_scene (int): Total number of views available in the scene.
         | 
| 424 | 
            +
                        scene_pairwise_covisibility (np.ndarray): N x N covisibility matrix for the scene, where N is the number of views in the scene.
         | 
| 425 | 
            +
                        use_bidirectional_covis (bool): Whether to compute bidirectional covisibility by averaging row and column values.
         | 
| 426 | 
            +
                            If False, uses only row access (faster for large memory-mapped arrays).
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    Returns:
         | 
| 429 | 
            +
                        numpy.ndarray: Array of sampled view indices.
         | 
| 430 | 
            +
                    """
         | 
| 431 | 
            +
                    if num_views_to_sample == num_views_in_scene:
         | 
| 432 | 
            +
                        # Select all views in the scene
         | 
| 433 | 
            +
                        view_indices = self._rng.permutation(num_views_in_scene)
         | 
| 434 | 
            +
                    elif num_views_to_sample > num_views_in_scene:
         | 
| 435 | 
            +
                        # Select all views in the scene and repeat them to get the desired number of views
         | 
| 436 | 
            +
                        view_indices = self._rng.choice(
         | 
| 437 | 
            +
                            num_views_in_scene, size=num_views_to_sample, replace=True
         | 
| 438 | 
            +
                        )
         | 
| 439 | 
            +
                    else:
         | 
| 440 | 
            +
                        # Select a subset of single component connected views in the scene using random walk sampling
         | 
| 441 | 
            +
                        view_indices = self._random_walk_sampling(
         | 
| 442 | 
            +
                            scene_pairwise_covisibility,
         | 
| 443 | 
            +
                            num_views_to_sample,
         | 
| 444 | 
            +
                            use_bidirectional_covis=use_bidirectional_covis,
         | 
| 445 | 
            +
                        )
         | 
| 446 | 
            +
                        # If the required num of views can't be obtained even with 4 retries, repeat existing indices to get the desired number of views
         | 
| 447 | 
            +
                        if len(view_indices) < num_views_to_sample:
         | 
| 448 | 
            +
                            view_indices = self._rng.choice(
         | 
| 449 | 
            +
                                view_indices, size=num_views_to_sample, replace=True
         | 
| 450 | 
            +
                            )
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                    return view_indices
         | 
| 453 | 
            +
             | 
| 454 | 
            +
    def _getitem_fn(self, idx):
        """
        Load and post-process one multi-view sample.

        Args:
            idx: Either a plain scene index, or a tuple ``(idx, ar_idx)`` /
                ``(idx, ar_idx, num_views_to_sample_idx)`` that additionally
                selects the aspect-ratio (and optionally the number of views).

        Returns:
            list[dict]: One dict per view, augmented with image tensor, pointmaps,
            rays, masks, pose quaternions/translation, and metadata fields.
        """
        if isinstance(idx, tuple):
            # The idx is a tuple if specifying the aspect-ratio or/and the number of views
            if isinstance(self.num_views, int):
                idx, ar_idx = idx
            else:
                idx, ar_idx, num_views_to_sample_idx = idx
        else:
            assert len(self._resolutions) == 1
            assert isinstance(self.num_views, int)
            ar_idx = 0

        # Setup the rng
        if self.seed:  # reseed for each _getitem_fn
            # Leads to deterministic sampling where repeating self.seed and self._seed_offset yields the same multi-view set again
            # Scenes will be repeated if size of dataset is artificially increased using "N @" or "N *"
            # When scenes are repeated, self._seed_offset is increased to ensure new multi-view sets
            # This is useful for evaluation if the number of dataset scenes is < N, yet we want unique multi-view sets each iter
            self._rng = np.random.default_rng(seed=self.seed + self._seed_offset + idx)
        elif not hasattr(self, "_rng"):
            seed = torch.initial_seed()  # this is different for each dataloader process
            self._rng = np.random.default_rng(seed=seed)

        # Get the views for the given index and check that the number of views is correct
        resolution = self._resolutions[ar_idx]
        if isinstance(self.num_views, int):
            num_views_to_sample = self.num_views
        else:
            num_views_to_sample = self.num_views[num_views_to_sample_idx]
        views = self._get_views(idx, num_views_to_sample, resolution)
        if isinstance(self.num_views, int):
            assert len(views) == self.num_views
        else:
            assert len(views) in self.num_views

        for v, view in enumerate(views):
            # Store the index and other metadata
            view["idx"] = (idx, ar_idx, v)
            view["is_metric_scale"] = self.is_metric_scale
            view["is_synthetic"] = self.is_synthetic

            # Check the depth, intrinsics, and pose data (also other data if present)
            assert "camera_intrinsics" in view
            assert "camera_pose" in view
            assert np.isfinite(view["camera_pose"]).all(), (
                f"NaN or infinite values in camera pose for view {view_name(view)}"
            )
            assert np.isfinite(view["depthmap"]).all(), (
                f"NaN or infinite values in depthmap for view {view_name(view)}"
            )
            assert "valid_mask" not in view
            assert "pts3d" not in view, (
                f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            )
            if "prior_depth_z" in view:
                assert np.isfinite(view["prior_depth_z"]).all(), (
                    f"NaN or infinite values in prior_depth_z for view {view_name(view)}"
                )
            if "non_ambiguous_mask" in view:
                assert np.isfinite(view["non_ambiguous_mask"]).all(), (
                    f"NaN or infinite values in non_ambiguous_mask for view {view_name(view)}"
                )

            # Encode the image
            # NOTE(review): view["img"] appears to be a PIL image here (has .size
            # returning (width, height)); self.transform converts it — confirm
            width, height = view["img"].size
            view["true_shape"] = np.int32((height, width))
            view["img"] = self.transform(view["img"])
            view["data_norm_type"] = self.data_norm_type

            # Compute the pointmaps, raymap and depth along ray
            (
                pts3d,
                valid_mask,
                ray_origins_world,
                ray_directions_world,
                depth_along_ray,
                ray_directions_cam,
                pts3d_cam,
            ) = get_absolute_pointmaps_and_rays_info(**view)
            view["pts3d"] = pts3d
            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
            view["depth_along_ray"] = depth_along_ray
            view["ray_directions_cam"] = ray_directions_cam
            view["pts3d_cam"] = pts3d_cam

            # Compute the prior depth along ray if present
            if "prior_depth_z" in view:
                prior_pts3d, _ = depthmap_to_camera_coordinates(
                    view["prior_depth_z"], view["camera_intrinsics"]
                )
                # Euclidean distance from camera center, stored with a trailing channel dim
                view["prior_depth_along_ray"] = np.linalg.norm(prior_pts3d, axis=-1)
                view["prior_depth_along_ray"] = view["prior_depth_along_ray"][..., None]
                del view["prior_depth_z"]

            # Convert ambiguous mask dtype to match valid mask dtype
            if "non_ambiguous_mask" in view:
                view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
                    view["valid_mask"].dtype
                )
            else:
                # Negative depth values mark ambiguous pixels when no explicit mask is given
                ambiguous_mask = view["depthmap"] < 0
                view["non_ambiguous_mask"] = ~ambiguous_mask
                view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
                    view["valid_mask"].dtype
                )

            # Check all datatypes
            for key, val in view.items():
                res, err_msg = is_good_type(val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"

            # Check shapes
            assert view["depthmap"].shape == view["img"].shape[1:]
            assert view["depthmap"].shape == view["pts3d"].shape[:2]
            assert view["depthmap"].shape == view["valid_mask"].shape
            assert view["depthmap"].shape == view["depth_along_ray"].shape[:2]
            assert view["depthmap"].shape == view["ray_directions_cam"].shape[:2]
            assert view["depthmap"].shape == view["pts3d_cam"].shape[:2]
            if "prior_depth_along_ray" in view:
                assert view["depthmap"].shape == view["prior_depth_along_ray"].shape[:2]
            if "non_ambiguous_mask" in view:
                assert view["depthmap"].shape == view["non_ambiguous_mask"].shape

            # Expand the last dimension of the depthmap
            view["depthmap"] = view["depthmap"][..., None]

            # Append RNG state to the views, this allows to check whether the RNG is in the same state each time
            view["rng"] = int.from_bytes(self._rng.bytes(4), "big")

            # Compute and store the quaternions and translation for the camera poses
            # Notation is (x, y, z, w) for quaternions
            # This also ensures that the camera poses have a positive determinant (right-handed coordinate system)
            view["camera_pose_quats"] = (
                Rotation.from_matrix(view["camera_pose"][:3, :3])
                .as_quat()
                .astype(view["camera_pose"].dtype)
            )
            view["camera_pose_trans"] = view["camera_pose"][:3, 3].astype(
                view["camera_pose"].dtype
            )

            # Check the pointmaps, rays, depth along ray, and camera pose quaternions and translation to ensure they are finite
            assert np.isfinite(view["pts3d"]).all(), (
                f"NaN in pts3d for view {view_name(view)}"
            )
            assert np.isfinite(view["valid_mask"]).all(), (
                f"NaN in valid_mask for view {view_name(view)}"
            )
            assert np.isfinite(view["depth_along_ray"]).all(), (
                f"NaN in depth_along_ray for view {view_name(view)}"
            )
            assert np.isfinite(view["ray_directions_cam"]).all(), (
                f"NaN in ray_directions_cam for view {view_name(view)}"
            )
            assert np.isfinite(view["pts3d_cam"]).all(), (
                f"NaN in pts3d_cam for view {view_name(view)}"
            )
            assert np.isfinite(view["camera_pose_quats"]).all(), (
                f"NaN in camera_pose_quats for view {view_name(view)}"
            )
            assert np.isfinite(view["camera_pose_trans"]).all(), (
                f"NaN in camera_pose_trans for view {view_name(view)}"
            )
            if "prior_depth_along_ray" in view:
                assert np.isfinite(view["prior_depth_along_ray"]).all(), (
                    f"NaN in prior_depth_along_ray for view {view_name(view)}"
                )

        return views
         | 
| 623 | 
            +
             | 
| 624 | 
            +
                def __getitem__(self, idx):
         | 
| 625 | 
            +
                    if self.max_num_retries == 0:
         | 
| 626 | 
            +
                        return self._getitem_fn(idx)
         | 
| 627 | 
            +
             | 
| 628 | 
            +
                    num_retries = 0
         | 
| 629 | 
            +
                    while num_retries <= self.max_num_retries:
         | 
| 630 | 
            +
                        try:
         | 
| 631 | 
            +
                            return self._getitem_fn(idx)
         | 
| 632 | 
            +
                        except Exception as e:
         | 
| 633 | 
            +
                            scene_idx = idx[0] if isinstance(idx, tuple) else idx
         | 
| 634 | 
            +
                            print(
         | 
| 635 | 
            +
                                f"Error in {type(self).__name__}.__getitem__ for scene_idx={scene_idx}: {e}"
         | 
| 636 | 
            +
                            )
         | 
| 637 | 
            +
             | 
| 638 | 
            +
                            if num_retries >= self.max_num_retries:
         | 
| 639 | 
            +
                                print(
         | 
| 640 | 
            +
                                    f"Max retries ({self.max_num_retries}) reached, raising the exception"
         | 
| 641 | 
            +
                                )
         | 
| 642 | 
            +
                                raise e
         | 
| 643 | 
            +
             | 
| 644 | 
            +
                            # Retry with a different scene index
         | 
| 645 | 
            +
                            num_retries += 1
         | 
| 646 | 
            +
                            if isinstance(idx, tuple):
         | 
| 647 | 
            +
                                # The scene index is the first element of the tuple
         | 
| 648 | 
            +
                                idx_list = list(idx)
         | 
| 649 | 
            +
                                idx_list[0] = np.random.randint(0, len(self))
         | 
| 650 | 
            +
                                idx = tuple(idx_list)
         | 
| 651 | 
            +
                            else:
         | 
| 652 | 
            +
                                # The scene index is idx
         | 
| 653 | 
            +
                                idx = np.random.randint(0, len(self))
         | 
| 654 | 
            +
                            scene_idx = idx[0] if isinstance(idx, tuple) else idx
         | 
| 655 | 
            +
                            print(
         | 
| 656 | 
            +
                                f"Retrying with scene_idx={scene_idx} ({num_retries} of {self.max_num_retries})"
         | 
| 657 | 
            +
                            )
         | 
| 658 | 
            +
             | 
| 659 | 
            +
             | 
| 660 | 
            +
            def is_good_type(v):
         | 
| 661 | 
            +
                """
         | 
| 662 | 
            +
                Check if a value has an acceptable data type for processing in the dataset.
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                Args:
         | 
| 665 | 
            +
                    v: The value to check.
         | 
| 666 | 
            +
             | 
| 667 | 
            +
                Returns:
         | 
| 668 | 
            +
                    tuple: A tuple containing:
         | 
| 669 | 
            +
                        - bool: True if the type is acceptable, False otherwise.
         | 
| 670 | 
            +
                        - str or None: Error message if the type is not acceptable, None otherwise.
         | 
| 671 | 
            +
                """
         | 
| 672 | 
            +
                if isinstance(v, (str, int, tuple)):
         | 
| 673 | 
            +
                    return True, None
         | 
| 674 | 
            +
                if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
         | 
| 675 | 
            +
                    return False, f"bad {v.dtype=}"
         | 
| 676 | 
            +
                return True, None
         | 
| 677 | 
            +
             | 
| 678 | 
            +
             | 
| 679 | 
            +
            def view_name(view, batch_index=None):
         | 
| 680 | 
            +
                """
         | 
| 681 | 
            +
                Generate a string identifier for a view based on its dataset, label, and instance.
         | 
| 682 | 
            +
             | 
| 683 | 
            +
                Args:
         | 
| 684 | 
            +
                    view (dict): Dictionary containing view information with 'dataset', 'label', and 'instance' keys.
         | 
| 685 | 
            +
                    batch_index (int, optional): Index to select from batched data. Defaults to None.
         | 
| 686 | 
            +
             | 
| 687 | 
            +
                Returns:
         | 
| 688 | 
            +
                    str: A formatted string in the form "dataset/label/instance".
         | 
| 689 | 
            +
                """
         | 
| 690 | 
            +
             | 
| 691 | 
            +
                def sel(x):
         | 
| 692 | 
            +
                    return x[batch_index] if batch_index not in (None, slice(None)) else x
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                db = sel(view["dataset"])
         | 
| 695 | 
            +
                label = sel(view["label"])
         | 
| 696 | 
            +
                instance = sel(view["instance"])
         | 
| 697 | 
            +
                return f"{db}/{label}/{instance}"
         | 
    	
        mapanything/datasets/base/batched_sampler.py
    ADDED
    
    | @@ -0,0 +1,431 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Utilities for random sampling under a single or multiple constraints
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            References: DUSt3R
         | 
| 10 | 
            +
            """
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def round_by(total, multiple, up=False):
         | 
| 17 | 
            +
                """
         | 
| 18 | 
            +
                Round a number to the nearest multiple of another number.
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                Args:
         | 
| 21 | 
            +
                    total (int): The number to round
         | 
| 22 | 
            +
                    multiple (int): The multiple to round to
         | 
| 23 | 
            +
                    up (bool, optional): Whether to round up. Defaults to False.
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                Returns:
         | 
| 26 | 
            +
                    int: The rounded number
         | 
| 27 | 
            +
                """
         | 
| 28 | 
            +
                if up:
         | 
| 29 | 
            +
                    total = total + multiple - 1
         | 
| 30 | 
            +
                return (total // multiple) * multiple
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class BatchedRandomSampler:
         | 
| 34 | 
            +
                """
         | 
| 35 | 
            +
                Random sampling under a constraint: each sample in the batch has the same feature,
         | 
| 36 | 
            +
                which is chosen randomly from a known pool of 'features' for each batch.
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                For instance, the 'feature' could be the image aspect-ratio.
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                The index returned is a tuple (sample_idx, feat_idx).
         | 
| 41 | 
            +
                This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
         | 
| 42 | 
            +
                """
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def __init__(
         | 
| 45 | 
            +
                    self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True
         | 
| 46 | 
            +
                ):
         | 
| 47 | 
            +
                    """
         | 
| 48 | 
            +
                    Args:
         | 
| 49 | 
            +
                        dataset: Dataset to sample from
         | 
| 50 | 
            +
                        batch_size: Number of samples per batch
         | 
| 51 | 
            +
                        pool_size: Integer representing the size of feature pool
         | 
| 52 | 
            +
                        world_size: Number of distributed processes
         | 
| 53 | 
            +
                        rank: Rank of the current process
         | 
| 54 | 
            +
                        drop_last: Whether to drop the last incomplete batch
         | 
| 55 | 
            +
                    """
         | 
| 56 | 
            +
                    self.batch_size = batch_size
         | 
| 57 | 
            +
                    self.pool_size = pool_size
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    self.len_dataset = N = len(dataset)
         | 
| 60 | 
            +
                    self.total_size = round_by(N, batch_size * world_size) if drop_last else N
         | 
| 61 | 
            +
                    assert world_size == 1 or drop_last, (
         | 
| 62 | 
            +
                        "must drop the last batch in distributed mode"
         | 
| 63 | 
            +
                    )
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    # Distributed sampler
         | 
| 66 | 
            +
                    self.world_size = world_size
         | 
| 67 | 
            +
                    self.rank = rank
         | 
| 68 | 
            +
                    self.epoch = None
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def __len__(self):
         | 
| 71 | 
            +
                    """
         | 
| 72 | 
            +
                    Get the length of the sampler.
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    Returns:
         | 
| 75 | 
            +
                        int: The number of samples in the sampler for the current process
         | 
| 76 | 
            +
                    """
         | 
| 77 | 
            +
                    return self.total_size // self.world_size
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                def set_epoch(self, epoch):
         | 
| 80 | 
            +
                    """
         | 
| 81 | 
            +
                    Set the epoch for this sampler.
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    This should be called before each epoch to ensure proper shuffling of the data.
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    Args:
         | 
| 86 | 
            +
                        epoch (int): The current epoch number
         | 
| 87 | 
            +
                    """
         | 
| 88 | 
            +
                    self.epoch = epoch
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                def __iter__(self):
         | 
| 91 | 
            +
                    """
         | 
| 92 | 
            +
                    Iterator over the indices.
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    This method generates random indices for each batch, ensuring that all samples
         | 
| 95 | 
            +
                    within a batch have the same feature index for the given feature pool.
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    Yields:
         | 
| 98 | 
            +
                        tuple: A tuple containing (sample_idx, feat_idx)
         | 
| 99 | 
            +
                    """
         | 
| 100 | 
            +
                    # Prepare RNG
         | 
| 101 | 
            +
                    if self.epoch is None:
         | 
| 102 | 
            +
                        assert self.world_size == 1 and self.rank == 0, (
         | 
| 103 | 
            +
                            "use set_epoch() if distributed mode is used"
         | 
| 104 | 
            +
                        )
         | 
| 105 | 
            +
                        seed = int(torch.empty((), dtype=torch.int64).random_().item())
         | 
| 106 | 
            +
                    else:
         | 
| 107 | 
            +
                        seed = self.epoch + 777
         | 
| 108 | 
            +
                    rng = np.random.default_rng(seed=seed)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    # Random indices (will restart from 0 if not drop_last)
         | 
| 111 | 
            +
                    sample_idxs = np.arange(self.total_size)
         | 
| 112 | 
            +
                    rng.shuffle(sample_idxs)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    # Random feat_idxs (same across each batch)
         | 
| 115 | 
            +
                    n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
         | 
| 116 | 
            +
                    feat_idxs = rng.integers(self.pool_size, size=n_batches)
         | 
| 117 | 
            +
                    feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
         | 
| 118 | 
            +
                    feat_idxs = feat_idxs.ravel()[: self.total_size]
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    # Put them together
         | 
| 121 | 
            +
                    idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    # Distributed sampler: we select a subset of batches
         | 
| 124 | 
            +
                    # Make sure the slice for each node is aligned with batch_size
         | 
| 125 | 
            +
                    size_per_proc = self.batch_size * (
         | 
| 126 | 
            +
                        (self.total_size + self.world_size * self.batch_size - 1)
         | 
| 127 | 
            +
                        // (self.world_size * self.batch_size)
         | 
| 128 | 
            +
                    )
         | 
| 129 | 
            +
                    idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    yield from (tuple(idx) for idx in idxs)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            class BatchedMultiFeatureRandomSampler:
         | 
| 135 | 
            +
                """
         | 
| 136 | 
            +
                Random sampling under multiple constraints: each sample in the batch has the same features,
         | 
| 137 | 
            +
                which are chosen randomly from known pools of 'features' for each batch.
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                For instance, the 'features' could be the image aspect-ratio and scene type.
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                The index returned is a tuple (sample_idx, feat_idx_1, feat_idx_2, ...).
         | 
| 142 | 
            +
                This sampler ensures that each series of `batch_size` indices has the same feature indices.
         | 
| 143 | 
            +
                """
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                def __init__(
         | 
| 146 | 
            +
                    self, dataset, batch_size, pool_sizes, world_size=1, rank=0, drop_last=True
         | 
| 147 | 
            +
                ):
         | 
| 148 | 
            +
                    """
         | 
| 149 | 
            +
                    Args:
         | 
| 150 | 
            +
                        dataset: Dataset to sample from
         | 
| 151 | 
            +
                        batch_size: Number of samples per batch
         | 
| 152 | 
            +
                        pool_sizes: List of integers representing the size of each feature pool
         | 
| 153 | 
            +
                        world_size: Number of distributed processes
         | 
| 154 | 
            +
                        rank: Rank of the current process
         | 
| 155 | 
            +
                        drop_last: Whether to drop the last incomplete batch
         | 
| 156 | 
            +
                    """
         | 
| 157 | 
            +
                    self.batch_size = batch_size
         | 
| 158 | 
            +
                    self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    self.len_dataset = N = len(dataset)
         | 
| 161 | 
            +
                    self.total_size = round_by(N, batch_size * world_size) if drop_last else N
         | 
| 162 | 
            +
                    assert world_size == 1 or drop_last, (
         | 
| 163 | 
            +
                        "must drop the last batch in distributed mode"
         | 
| 164 | 
            +
                    )
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    # Distributed sampler
         | 
| 167 | 
            +
                    self.world_size = world_size
         | 
| 168 | 
            +
                    self.rank = rank
         | 
| 169 | 
            +
                    self.epoch = None
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                def __len__(self):
         | 
| 172 | 
            +
                    """
         | 
| 173 | 
            +
                    Get the length of the sampler.
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    Returns:
         | 
| 176 | 
            +
                        int: The number of samples in the sampler for the current process
         | 
| 177 | 
            +
                    """
         | 
| 178 | 
            +
                    return self.total_size // self.world_size
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def set_epoch(self, epoch):
         | 
| 181 | 
            +
                    """
         | 
| 182 | 
            +
                    Set the epoch for this sampler.
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    This should be called before each epoch to ensure proper shuffling of the data.
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    Args:
         | 
| 187 | 
            +
                        epoch (int): The current epoch number
         | 
| 188 | 
            +
                    """
         | 
| 189 | 
            +
                    self.epoch = epoch
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                def __iter__(self):
         | 
| 192 | 
            +
                    """
         | 
| 193 | 
            +
                    Iterator over the indices.
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    This method generates random indices for each batch, ensuring that all samples
         | 
| 196 | 
            +
                    within a batch have the same feature indices for multiple features.
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    Yields:
         | 
| 199 | 
            +
                        tuple: A tuple containing (sample_idx, feat_idx_1, feat_idx_2, ...)
         | 
| 200 | 
            +
                    """
         | 
| 201 | 
            +
                    # Prepare RNG
         | 
| 202 | 
            +
                    if self.epoch is None:
         | 
| 203 | 
            +
                        assert self.world_size == 1 and self.rank == 0, (
         | 
| 204 | 
            +
                            "use set_epoch() if distributed mode is used"
         | 
| 205 | 
            +
                        )
         | 
| 206 | 
            +
                        seed = int(torch.empty((), dtype=torch.int64).random_().item())
         | 
| 207 | 
            +
                    else:
         | 
| 208 | 
            +
                        seed = self.epoch + 777
         | 
| 209 | 
            +
                    rng = np.random.default_rng(seed=seed)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    # Random indices (will restart from 0 if not drop_last)
         | 
| 212 | 
            +
                    sample_idxs = np.arange(self.total_size)
         | 
| 213 | 
            +
                    rng.shuffle(sample_idxs)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    # Random feat_idxs (same across each batch)
         | 
| 216 | 
            +
                    n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    # Generate feature indices for each feature pool
         | 
| 219 | 
            +
                    all_feat_idxs = []
         | 
| 220 | 
            +
                    for pool_size in self.pool_sizes:
         | 
| 221 | 
            +
                        feat_idxs = rng.integers(pool_size, size=n_batches)
         | 
| 222 | 
            +
                        feat_idxs = np.broadcast_to(
         | 
| 223 | 
            +
                            feat_idxs[:, None], (n_batches, self.batch_size)
         | 
| 224 | 
            +
                        )
         | 
| 225 | 
            +
                        feat_idxs = feat_idxs.ravel()[: self.total_size]
         | 
| 226 | 
            +
                        all_feat_idxs.append(feat_idxs)
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    # Put them together
         | 
| 229 | 
            +
                    idxs = np.column_stack(
         | 
| 230 | 
            +
                        [sample_idxs] + all_feat_idxs
         | 
| 231 | 
            +
                    )  # shape = (total_size, 1 + len(pool_sizes))
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    # Distributed sampler: we select a subset of batches
         | 
| 234 | 
            +
                    # Make sure the slice for each node is aligned with batch_size
         | 
| 235 | 
            +
                    size_per_proc = self.batch_size * (
         | 
| 236 | 
            +
                        (self.total_size + self.world_size * self.batch_size - 1)
         | 
| 237 | 
            +
                        // (self.world_size * self.batch_size)
         | 
| 238 | 
            +
                    )
         | 
| 239 | 
            +
                    idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                    yield from (tuple(idx) for idx in idxs)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
             | 
| 244 | 
            +
            class DynamicBatchedMultiFeatureRandomSampler:
         | 
| 245 | 
            +
                """
         | 
| 246 | 
            +
                Random sampling under multiple constraints with dynamic batch size:
         | 
| 247 | 
            +
                each sample in the batch has the same features, which are chosen randomly
         | 
| 248 | 
            +
                from known pools of 'features' for each batch.
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                The batch size is dynamically determined based on a specified feature index,
         | 
| 251 | 
            +
                using a direct mapping from feature values to batch sizes.
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                For instance, if one of the features is the number of images in a multi-view set,
         | 
| 254 | 
            +
                you can specify different batch sizes for different numbers of images to optimize
         | 
| 255 | 
            +
                GPU memory usage. This is achieved by using the feature_to_batch_size_map parameter
         | 
| 256 | 
            +
                to directly specify what batch size to use for each feature value.
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                The returned index is a list of tuples [(sample_idx, feat_idx_1, feat_idx_2, ...), ...].
         | 
| 259 | 
            +
                """
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def __init__(
         | 
| 262 | 
            +
                    self,
         | 
| 263 | 
            +
                    dataset,
         | 
| 264 | 
            +
                    pool_sizes,
         | 
| 265 | 
            +
                    scaling_feature_idx=0,
         | 
| 266 | 
            +
                    feature_to_batch_size_map=None,
         | 
| 267 | 
            +
                    world_size=1,
         | 
| 268 | 
            +
                    rank=0,
         | 
| 269 | 
            +
                    drop_last=True,
         | 
| 270 | 
            +
                ):
         | 
| 271 | 
            +
                    """
         | 
| 272 | 
            +
                    Args:
         | 
| 273 | 
            +
                        dataset: Dataset to sample from
         | 
| 274 | 
            +
                        pool_sizes: List of integers representing the size of each feature pool
         | 
| 275 | 
            +
                        scaling_feature_idx: Index of the feature to use for determining batch size (0-based index into pool_sizes)
         | 
| 276 | 
            +
                        feature_to_batch_size_map: Optional function or dict that maps feature values directly to batch sizes.
         | 
| 277 | 
            +
                                             For example, if the feature represents number of views, this maps number of views
         | 
| 278 | 
            +
                                             to appropriate batch size that can fit in GPU memory.
         | 
| 279 | 
            +
                                             If None, uses a default batch size of 1 for all feature values.
         | 
| 280 | 
            +
                        world_size: Number of distributed processes
         | 
| 281 | 
            +
                        rank: Rank of the current process
         | 
| 282 | 
            +
                        drop_last: Whether to drop the last incomplete batch
         | 
| 283 | 
            +
                    """
         | 
| 284 | 
            +
                    self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
         | 
| 285 | 
            +
                    self.scaling_feature_idx = scaling_feature_idx
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    # Ensure scaling_feature_idx is valid
         | 
| 288 | 
            +
                    if scaling_feature_idx < 0 or scaling_feature_idx >= len(self.pool_sizes):
         | 
| 289 | 
            +
                        raise ValueError(
         | 
| 290 | 
            +
                            f"scaling_feature_idx must be between 0 and {len(self.pool_sizes) - 1}"
         | 
| 291 | 
            +
                        )
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    # Set up mapping from feature values to batch sizes
         | 
| 294 | 
            +
                    self.feature_to_batch_size_map = feature_to_batch_size_map
         | 
| 295 | 
            +
                    if self.feature_to_batch_size_map is None:
         | 
| 296 | 
            +
                        # Default: batch size of 1 for all feature values
         | 
| 297 | 
            +
                        self.feature_to_batch_size_map = {
         | 
| 298 | 
            +
                            i: 1 for i in range(self.pool_sizes[scaling_feature_idx])
         | 
| 299 | 
            +
                        }
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    self.len_dataset = N = len(dataset)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    # We don't know the exact batch size yet, so we use a large number for total_size
         | 
| 304 | 
            +
                    # This will be adjusted during iteration
         | 
| 305 | 
            +
                    self.total_size = N
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    # Distributed sampler
         | 
| 308 | 
            +
                    self.world_size = world_size
         | 
| 309 | 
            +
                    self.rank = rank
         | 
| 310 | 
            +
                    self.epoch = None
         | 
| 311 | 
            +
                    self.drop_last = drop_last
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                def __len__(self):
         | 
| 314 | 
            +
                    """
         | 
| 315 | 
            +
                    Get the approximate length of the sampler.
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    Since batch size varies, this is an estimate based on the largest batch size
         | 
| 318 | 
            +
                    in the mapping, which provides a lower bound on the number of batches.
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    Returns:
         | 
| 321 | 
            +
                        int: The estimated minimum number of samples in the sampler for the current process
         | 
| 322 | 
            +
                    """
         | 
| 323 | 
            +
                    # Find the largest batch size in the mapping
         | 
| 324 | 
            +
                    if callable(self.feature_to_batch_size_map):
         | 
| 325 | 
            +
                        # If it's a function, sample some values to find the maximum
         | 
| 326 | 
            +
                        batch_sizes = [
         | 
| 327 | 
            +
                            self.feature_to_batch_size_map(i)
         | 
| 328 | 
            +
                            for i in range(self.pool_sizes[self.scaling_feature_idx])
         | 
| 329 | 
            +
                        ]
         | 
| 330 | 
            +
                        max_batch_size = max(batch_sizes)
         | 
| 331 | 
            +
                    else:
         | 
| 332 | 
            +
                        # If it's a dict or similar, find the maximum directly
         | 
| 333 | 
            +
                        max_batch_size = max(self.feature_to_batch_size_map.values())
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                    # Ensure minimum batch size of 1
         | 
| 336 | 
            +
                    max_batch_size = max(1, max_batch_size)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    # Estimate total batches using the largest batch size
         | 
| 339 | 
            +
                    # This gives a lower bound on the number of batches
         | 
| 340 | 
            +
                    total_batches = self.total_size // max_batch_size
         | 
| 341 | 
            +
                    if not self.drop_last and self.total_size % max_batch_size > 0:
         | 
| 342 | 
            +
                        total_batches += 1
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                    # Distribute among processes
         | 
| 345 | 
            +
                    return total_batches // self.world_size
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                def set_epoch(self, epoch):
         | 
| 348 | 
            +
                    """
         | 
| 349 | 
            +
                    Set the epoch for this sampler.
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                    This should be called before each epoch to ensure proper shuffling of the data.
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                    Args:
         | 
| 354 | 
            +
                        epoch (int): The current epoch number
         | 
| 355 | 
            +
                    """
         | 
| 356 | 
            +
                    self.epoch = epoch
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                def __iter__(self):
         | 
| 359 | 
            +
                    """
         | 
| 360 | 
            +
                    Iterator over the indices with dynamic batch sizes.
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                    This method generates random indices for each batch, ensuring that all samples
         | 
| 363 | 
            +
                    within a batch have the same feature indices for multiple features.
         | 
| 364 | 
            +
                    The batch size is determined directly from the feature_to_batch_size_map.
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    The iterator enforces the length returned by __len__() by stopping after
         | 
| 367 | 
            +
                    exactly that many batches have been yielded for this process.
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    Yields:
         | 
| 370 | 
            +
                        list of tuples: A batch of tuples, each containing (sample_idx, feat_idx_1, feat_idx_2, ...)
         | 
| 371 | 
            +
                    """
         | 
| 372 | 
            +
                    # Prepare RNG
         | 
| 373 | 
            +
                    if self.epoch is None:
         | 
| 374 | 
            +
                        assert self.world_size == 1 and self.rank == 0, (
         | 
| 375 | 
            +
                            "use set_epoch() if distributed mode is used"
         | 
| 376 | 
            +
                        )
         | 
| 377 | 
            +
                        seed = int(torch.empty((), dtype=torch.int64).random_().item())
         | 
| 378 | 
            +
                    else:
         | 
| 379 | 
            +
                        seed = self.epoch + 777
         | 
| 380 | 
            +
                    rng = np.random.default_rng(seed=seed)
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                    # Random indices for the entire dataset
         | 
| 383 | 
            +
                    sample_idxs = np.arange(self.total_size)
         | 
| 384 | 
            +
                    rng.shuffle(sample_idxs)
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                    # Get the target number of batches for this process (enforce strict length)
         | 
| 387 | 
            +
                    target_batches_for_process = len(self)
         | 
| 388 | 
            +
                    batches_yielded_for_process = 0
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                    # Process indices in batches with dynamic sizing
         | 
| 391 | 
            +
                    idx = 0
         | 
| 392 | 
            +
                    batch_idx = 0  # Track batch index for even distribution
         | 
| 393 | 
            +
                    while idx < len(sample_idxs) and (
         | 
| 394 | 
            +
                        batches_yielded_for_process < target_batches_for_process
         | 
| 395 | 
            +
                    ):
         | 
| 396 | 
            +
                        # Randomly select feature indices for this batch
         | 
| 397 | 
            +
                        feat_idxs = [rng.integers(pool_size) for pool_size in self.pool_sizes]
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                        # Get the scaling feature value
         | 
| 400 | 
            +
                        scaling_feat = feat_idxs[self.scaling_feature_idx]
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                        # Get the batch size directly from the mapping
         | 
| 403 | 
            +
                        if callable(self.feature_to_batch_size_map):
         | 
| 404 | 
            +
                            batch_size = self.feature_to_batch_size_map(scaling_feat)
         | 
| 405 | 
            +
                        else:
         | 
| 406 | 
            +
                            batch_size = self.feature_to_batch_size_map.get(scaling_feat, 1)
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                        # Ensure minimum batch size of 1
         | 
| 409 | 
            +
                        batch_size = max(1, batch_size)
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                        # Ensure we don't go beyond available samples
         | 
| 412 | 
            +
                        remaining = len(sample_idxs) - idx
         | 
| 413 | 
            +
                        if remaining < batch_size:
         | 
| 414 | 
            +
                            if self.drop_last:
         | 
| 415 | 
            +
                                break
         | 
| 416 | 
            +
                            batch_size = remaining
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                        # Create batch with consistent feature indices
         | 
| 419 | 
            +
                        batch = []
         | 
| 420 | 
            +
                        for i in range(batch_size):
         | 
| 421 | 
            +
                            if idx + i < len(sample_idxs):
         | 
| 422 | 
            +
                                sample_idx = sample_idxs[idx + i]
         | 
| 423 | 
            +
                                batch.append(tuple([sample_idx] + feat_idxs))
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                        # Distribute batches among processes in round-robin fashion
         | 
| 426 | 
            +
                        if len(batch) > 0 and (batch_idx % self.world_size == self.rank):
         | 
| 427 | 
            +
                            yield batch
         | 
| 428 | 
            +
                            batches_yielded_for_process += 1
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                        batch_idx += 1  # Increment batch index
         | 
| 431 | 
            +
                        idx += batch_size
         | 
    	
        mapanything/datasets/base/easy_dataset.py
    ADDED
    
    | @@ -0,0 +1,478 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Base dataset class that enables easy resizing and combining
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            References: DUSt3R
         | 
| 10 | 
            +
            """
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.batched_sampler import (
         | 
| 15 | 
            +
                BatchedMultiFeatureRandomSampler,
         | 
| 16 | 
            +
                DynamicBatchedMultiFeatureRandomSampler,
         | 
| 17 | 
            +
            )
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            class EasyDataset:
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                Dataset that can be easily resized and combined.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                Examples:
         | 
| 25 | 
            +
                ---------
         | 
| 26 | 
            +
                    2 * dataset ==> Duplicate each element 2x
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    10 @ dataset ==> Set the size to 10 (random sampling, duplicates if necessary)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    Dataset1 + Dataset2 ==> Concatenate datasets
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                def __add__(self, other):
         | 
| 34 | 
            +
                    """
         | 
| 35 | 
            +
                    Concatenate this dataset with another dataset.
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    Args:
         | 
| 38 | 
            +
                        other (EasyDataset): Another dataset to concatenate with this one
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    Returns:
         | 
| 41 | 
            +
                        CatDataset: A new dataset that is the concatenation of this dataset and the other
         | 
| 42 | 
            +
                    """
         | 
| 43 | 
            +
                    return CatDataset([self, other])
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                def __rmul__(self, factor):
         | 
| 46 | 
            +
                    """
         | 
| 47 | 
            +
                    Multiply the dataset by a factor, duplicating each element.
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    Args:
         | 
| 50 | 
            +
                        factor (int): Number of times to duplicate each element
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    Returns:
         | 
| 53 | 
            +
                        MulDataset: A new dataset with each element duplicated 'factor' times
         | 
| 54 | 
            +
                    """
         | 
| 55 | 
            +
                    return MulDataset(factor, self)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def __rmatmul__(self, factor):
         | 
| 58 | 
            +
                    """
         | 
| 59 | 
            +
                    Resize the dataset to a specific size using random sampling.
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    Args:
         | 
| 62 | 
            +
                        factor (int): The new size of the dataset
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    Returns:
         | 
| 65 | 
            +
                        ResizedDataset: A new dataset with the specified size
         | 
| 66 | 
            +
                    """
         | 
| 67 | 
            +
                    return ResizedDataset(factor, self)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                def set_epoch(self, epoch):
         | 
| 70 | 
            +
                    """
         | 
| 71 | 
            +
                    Set the current epoch for all constituent datasets.
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    Args:
         | 
| 74 | 
            +
                        epoch (int): The current epoch number
         | 
| 75 | 
            +
                    """
         | 
| 76 | 
            +
                    pass  # nothing to do by default
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def make_sampler(
         | 
| 79 | 
            +
                    self,
         | 
| 80 | 
            +
                    batch_size=None,
         | 
| 81 | 
            +
                    shuffle=True,
         | 
| 82 | 
            +
                    world_size=1,
         | 
| 83 | 
            +
                    rank=0,
         | 
| 84 | 
            +
                    drop_last=True,
         | 
| 85 | 
            +
                    max_num_of_images_per_gpu=None,
         | 
| 86 | 
            +
                    use_dynamic_sampler=True,
         | 
| 87 | 
            +
                ):
         | 
| 88 | 
            +
                    """
         | 
| 89 | 
            +
                    Create a sampler for this dataset.
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    Args:
         | 
| 92 | 
            +
                        batch_size (int, optional): Number of samples per batch (used for non-dynamic sampler). Defaults to None.
         | 
| 93 | 
            +
                        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
         | 
| 94 | 
            +
                        world_size (int, optional): Number of distributed processes. Defaults to 1.
         | 
| 95 | 
            +
                        rank (int, optional): Rank of the current process. Defaults to 0.
         | 
| 96 | 
            +
                        drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
         | 
| 97 | 
            +
                        max_num_of_images_per_gpu (int, optional): Maximum number of images per GPU for dynamic batching. Defaults to None.
         | 
| 98 | 
            +
                        use_dynamic_sampler (bool, optional): Whether to use the dynamic sampler. Defaults to True.
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    Returns:
         | 
| 101 | 
            +
                        DynamicBatchedMultiFeatureRandomSampler or BatchedMultiFeatureRandomSampler: A sampler for this dataset
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    Raises:
         | 
| 104 | 
            +
                        NotImplementedError: If shuffle is False
         | 
| 105 | 
            +
                        ValueError: If num_views has an invalid type or required parameters are missing
         | 
| 106 | 
            +
                    """
         | 
| 107 | 
            +
                    if not (shuffle):
         | 
| 108 | 
            +
                        raise NotImplementedError()  # cannot deal yet
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    if isinstance(self.num_views, int):
         | 
| 111 | 
            +
                        num_of_aspect_ratios = len(self._resolutions)
         | 
| 112 | 
            +
                        feature_pool_sizes = [num_of_aspect_ratios]
         | 
| 113 | 
            +
                        scaling_feature_idx = 0  # Use aspect ratio as scaling feature
         | 
| 114 | 
            +
                    elif isinstance(self.num_views, list):
         | 
| 115 | 
            +
                        num_of_aspect_ratios = len(self._resolutions)
         | 
| 116 | 
            +
                        num_of_num_views = len(self.num_views)
         | 
| 117 | 
            +
                        feature_pool_sizes = [num_of_aspect_ratios, num_of_num_views]
         | 
| 118 | 
            +
                        scaling_feature_idx = 1  # Use num_views as scaling feature
         | 
| 119 | 
            +
                    else:
         | 
| 120 | 
            +
                        raise ValueError(
         | 
| 121 | 
            +
                            f"Bad type for {self.num_views=}, should be int or list of ints"
         | 
| 122 | 
            +
                        )
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    if use_dynamic_sampler:
         | 
| 125 | 
            +
                        if max_num_of_images_per_gpu is None:
         | 
| 126 | 
            +
                            raise ValueError(
         | 
| 127 | 
            +
                                "max_num_of_images_per_gpu must be provided when using dynamic sampler"
         | 
| 128 | 
            +
                            )
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                        # Create feature-to-batch-size mapping
         | 
| 131 | 
            +
                        if isinstance(self.num_views, list):
         | 
| 132 | 
            +
                            # Map num_views_idx to batch size: max(1, max_num_of_images_per_gpu // (num_views_idx + dataset.num_views_min))
         | 
| 133 | 
            +
                            feature_to_batch_size_map = {}
         | 
| 134 | 
            +
                            for num_views_idx, num_views in enumerate(self.num_views):
         | 
| 135 | 
            +
                                batch_size_for_multi_view_sets = max(
         | 
| 136 | 
            +
                                    1, max_num_of_images_per_gpu // num_views
         | 
| 137 | 
            +
                                )
         | 
| 138 | 
            +
                                feature_to_batch_size_map[num_views_idx] = (
         | 
| 139 | 
            +
                                    batch_size_for_multi_view_sets
         | 
| 140 | 
            +
                                )
         | 
| 141 | 
            +
                        else:
         | 
| 142 | 
            +
                            # For fixed num_views, use a simple mapping
         | 
| 143 | 
            +
                            feature_to_batch_size_map = {
         | 
| 144 | 
            +
                                0: max(1, max_num_of_images_per_gpu // self.num_views)
         | 
| 145 | 
            +
                            }
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                        return DynamicBatchedMultiFeatureRandomSampler(
         | 
| 148 | 
            +
                            self,
         | 
| 149 | 
            +
                            pool_sizes=feature_pool_sizes,
         | 
| 150 | 
            +
                            scaling_feature_idx=scaling_feature_idx,
         | 
| 151 | 
            +
                            feature_to_batch_size_map=feature_to_batch_size_map,
         | 
| 152 | 
            +
                            world_size=world_size,
         | 
| 153 | 
            +
                            rank=rank,
         | 
| 154 | 
            +
                            drop_last=drop_last,
         | 
| 155 | 
            +
                        )
         | 
| 156 | 
            +
                    else:
         | 
| 157 | 
            +
                        if batch_size is None:
         | 
| 158 | 
            +
                            raise ValueError(
         | 
| 159 | 
            +
                                "batch_size must be provided when not using dynamic sampler"
         | 
| 160 | 
            +
                            )
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                        return BatchedMultiFeatureRandomSampler(
         | 
| 163 | 
            +
                            self,
         | 
| 164 | 
            +
                            batch_size,
         | 
| 165 | 
            +
                            feature_pool_sizes,
         | 
| 166 | 
            +
                            world_size=world_size,
         | 
| 167 | 
            +
                            rank=rank,
         | 
| 168 | 
            +
                            drop_last=drop_last,
         | 
| 169 | 
            +
                        )
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
            class MulDataset(EasyDataset):
         | 
| 173 | 
            +
                """Artificially augmenting the size of a dataset."""
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                multiplicator: int
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                def __init__(self, multiplicator, dataset):
         | 
| 178 | 
            +
                    """
         | 
| 179 | 
            +
                    Initialize a dataset that artificially augments the size of another dataset.
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    Args:
         | 
| 182 | 
            +
                        multiplicator (int): Factor by which to multiply the dataset size
         | 
| 183 | 
            +
                        dataset (EasyDataset): The dataset to augment
         | 
| 184 | 
            +
                    """
         | 
| 185 | 
            +
                    assert isinstance(multiplicator, int) and multiplicator > 0
         | 
| 186 | 
            +
                    self.multiplicator = multiplicator
         | 
| 187 | 
            +
                    self.dataset = dataset
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                def __len__(self):
         | 
| 190 | 
            +
                    """
         | 
| 191 | 
            +
                    Get the length of the dataset.
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    Returns:
         | 
| 194 | 
            +
                        int: The number of samples in the dataset
         | 
| 195 | 
            +
                    """
         | 
| 196 | 
            +
                    return self.multiplicator * len(self.dataset)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def __repr__(self):
         | 
| 199 | 
            +
                    """
         | 
| 200 | 
            +
                    Get a string representation of the dataset.
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    Returns:
         | 
| 203 | 
            +
                        str: String representation showing the multiplication factor and the original dataset
         | 
| 204 | 
            +
                    """
         | 
| 205 | 
            +
                    return f"{self.multiplicator}*{repr(self.dataset)}"
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                def __getitem__(self, idx):
         | 
| 208 | 
            +
                    """
         | 
| 209 | 
            +
                    Get an item from the dataset.
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    Args:
         | 
| 212 | 
            +
                        idx: Index or tuple of indices to retrieve
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    Returns:
         | 
| 215 | 
            +
                        The item at the specified index from the original dataset
         | 
| 216 | 
            +
                    """
         | 
| 217 | 
            +
                    if isinstance(idx, tuple):
         | 
| 218 | 
            +
                        other = idx[1:]
         | 
| 219 | 
            +
                        idx = idx[0]
         | 
| 220 | 
            +
                        new_idx = (idx // self.multiplicator, *other)
         | 
| 221 | 
            +
                        return self.dataset[new_idx]
         | 
| 222 | 
            +
                    else:
         | 
| 223 | 
            +
                        return self.dataset[idx // self.multiplicator]
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                @property
         | 
| 226 | 
            +
                def _resolutions(self):
         | 
| 227 | 
            +
                    """
         | 
| 228 | 
            +
                    Get the resolutions of the dataset.
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    Returns:
         | 
| 231 | 
            +
                        The resolutions from the original dataset
         | 
| 232 | 
            +
                    """
         | 
| 233 | 
            +
                    return self.dataset._resolutions
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                @property
         | 
| 236 | 
            +
                def num_views(self):
         | 
| 237 | 
            +
                    """
         | 
| 238 | 
            +
                    Get the number of views used for the dataset.
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    Returns:
         | 
| 241 | 
            +
                        int or list: The number of views parameter from the original dataset
         | 
| 242 | 
            +
                    """
         | 
| 243 | 
            +
                    return self.dataset.num_views
         | 
| 244 | 
            +
             | 
| 245 | 
            +
             | 
| 246 | 
            +
            class ResizedDataset(EasyDataset):
         | 
| 247 | 
            +
                """Artificially changing the size of a dataset."""
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                new_size: int
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                def __init__(self, new_size, dataset):
         | 
| 252 | 
            +
                    """
         | 
| 253 | 
            +
                    Initialize a dataset with an artificially changed size.
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    Args:
         | 
| 256 | 
            +
                        new_size (int): The new size of the dataset
         | 
| 257 | 
            +
                        dataset (EasyDataset): The original dataset
         | 
| 258 | 
            +
                    """
         | 
| 259 | 
            +
                    assert isinstance(new_size, int) and new_size > 0
         | 
| 260 | 
            +
                    self.new_size = new_size
         | 
| 261 | 
            +
                    self.dataset = dataset
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                def __len__(self):
         | 
| 264 | 
            +
                    """
         | 
| 265 | 
            +
                    Get the length of the dataset.
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    Returns:
         | 
| 268 | 
            +
                        int: The new size of the dataset
         | 
| 269 | 
            +
                    """
         | 
| 270 | 
            +
                    return self.new_size
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                def __repr__(self):
         | 
| 273 | 
            +
                    """
         | 
| 274 | 
            +
                    Get a string representation of the dataset.
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    Returns:
         | 
| 277 | 
            +
                        str: String representation showing the new size and the original dataset
         | 
| 278 | 
            +
                    """
         | 
| 279 | 
            +
                    size_str = str(self.new_size)
         | 
| 280 | 
            +
                    for i in range((len(size_str) - 1) // 3):
         | 
| 281 | 
            +
                        sep = -4 * i - 3
         | 
| 282 | 
            +
                        size_str = size_str[:sep] + "_" + size_str[sep:]
         | 
| 283 | 
            +
                    return f"{size_str} @ {repr(self.dataset)}"
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                def set_epoch(self, epoch):
         | 
| 286 | 
            +
                    """
         | 
| 287 | 
            +
                    Set the current epoch and generate a new random mapping of indices.
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                    This method must be called before using __getitem__.
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    Args:
         | 
| 292 | 
            +
                        epoch (int): The current epoch number
         | 
| 293 | 
            +
                    """
         | 
| 294 | 
            +
                    # This random shuffle only depends on the epoch
         | 
| 295 | 
            +
                    rng = np.random.default_rng(seed=epoch + 777)
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    # Shuffle all indices
         | 
| 298 | 
            +
                    perm = rng.permutation(len(self.dataset))
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    # Calculate how many repetitions we need
         | 
| 301 | 
            +
                    num_repetitions = 1 + (len(self) - 1) // len(self.dataset)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    # Rotary extension until target size is met
         | 
| 304 | 
            +
                    shuffled_idxs = np.concatenate([perm] * num_repetitions)
         | 
| 305 | 
            +
                    self._idxs_mapping = shuffled_idxs[: self.new_size]
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    # Generate the seed offset for each repetition
         | 
| 308 | 
            +
                    # This is needed to ensure we see unique samples when we repeat a scene
         | 
| 309 | 
            +
                    seed_offset_per_repetition = [
         | 
| 310 | 
            +
                        np.full(len(self.dataset), i) for i in range(num_repetitions)
         | 
| 311 | 
            +
                    ]
         | 
| 312 | 
            +
                    seed_offset_idxs = np.concatenate(seed_offset_per_repetition)
         | 
| 313 | 
            +
                    self._idxs_seed_offset = seed_offset_idxs[: self.new_size]
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    assert len(self._idxs_mapping) == self.new_size
         | 
| 316 | 
            +
                    assert len(self._idxs_seed_offset) == self.new_size
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                def __getitem__(self, idx):
         | 
| 319 | 
            +
                    """
         | 
| 320 | 
            +
                    Get an item from the dataset.
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    Args:
         | 
| 323 | 
            +
                        idx: Index or tuple of indices to retrieve
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                    Returns:
         | 
| 326 | 
            +
                        The item at the mapped index from the original dataset
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    Raises:
         | 
| 329 | 
            +
                        AssertionError: If set_epoch has not been called
         | 
| 330 | 
            +
                    """
         | 
| 331 | 
            +
                    assert hasattr(self, "_idxs_mapping"), (
         | 
| 332 | 
            +
                        "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
         | 
| 333 | 
            +
                    )
         | 
| 334 | 
            +
                    if isinstance(idx, tuple):
         | 
| 335 | 
            +
                        other = idx[1:]
         | 
| 336 | 
            +
                        idx = idx[0]
         | 
| 337 | 
            +
                        self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
         | 
| 338 | 
            +
                        new_idx = (self._idxs_mapping[idx], *other)
         | 
| 339 | 
            +
                        return self.dataset[new_idx]
         | 
| 340 | 
            +
                    else:
         | 
| 341 | 
            +
                        self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
         | 
| 342 | 
            +
                        return self.dataset[self._idxs_mapping[idx]]
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                @property
         | 
| 345 | 
            +
                def _resolutions(self):
         | 
| 346 | 
            +
                    """
         | 
| 347 | 
            +
                    Get the resolutions of the dataset.
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                    Returns:
         | 
| 350 | 
            +
                        The resolutions from the original dataset
         | 
| 351 | 
            +
                    """
         | 
| 352 | 
            +
                    return self.dataset._resolutions
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                @property
         | 
| 355 | 
            +
                def num_views(self):
         | 
| 356 | 
            +
                    """
         | 
| 357 | 
            +
                    Get the number of views used for the dataset.
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                    Returns:
         | 
| 360 | 
            +
                        int or list: The number of views parameter from the original dataset
         | 
| 361 | 
            +
                    """
         | 
| 362 | 
            +
                    return self.dataset.num_views
         | 
| 363 | 
            +
             | 
| 364 | 
            +
             | 
| 365 | 
            +
            class CatDataset(EasyDataset):
         | 
| 366 | 
            +
                """Concatenation of several datasets"""
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                def __init__(self, datasets):
         | 
| 369 | 
            +
                    """
         | 
| 370 | 
            +
                    Initialize a dataset that is a concatenation of several datasets.
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    Args:
         | 
| 373 | 
            +
                        datasets (list): List of EasyDataset instances to concatenate
         | 
| 374 | 
            +
                    """
         | 
| 375 | 
            +
                    for dataset in datasets:
         | 
| 376 | 
            +
                        assert isinstance(dataset, EasyDataset)
         | 
| 377 | 
            +
                    self.datasets = datasets
         | 
| 378 | 
            +
                    self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                def __len__(self):
         | 
| 381 | 
            +
                    """
         | 
| 382 | 
            +
                    Get the length of the concatenated dataset.
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                    Returns:
         | 
| 385 | 
            +
                        int: Total number of samples across all datasets
         | 
| 386 | 
            +
                    """
         | 
| 387 | 
            +
                    return self._cum_sizes[-1]
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                def __repr__(self):
         | 
| 390 | 
            +
                    """
         | 
| 391 | 
            +
                    Get a string representation of the concatenated dataset.
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                    Returns:
         | 
| 394 | 
            +
                        str: String representation showing all concatenated datasets joined by '+'
         | 
| 395 | 
            +
                    """
         | 
| 396 | 
            +
                    # Remove uselessly long transform
         | 
| 397 | 
            +
                    return " + ".join(
         | 
| 398 | 
            +
                        repr(dataset).replace(
         | 
| 399 | 
            +
                            ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
         | 
| 400 | 
            +
                            "",
         | 
| 401 | 
            +
                        )
         | 
| 402 | 
            +
                        for dataset in self.datasets
         | 
| 403 | 
            +
                    )
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                def set_epoch(self, epoch):
         | 
| 406 | 
            +
                    """
         | 
| 407 | 
            +
                    Set the current epoch for all constituent datasets.
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                    Args:
         | 
| 410 | 
            +
                        epoch (int): The current epoch number
         | 
| 411 | 
            +
                    """
         | 
| 412 | 
            +
                    for dataset in self.datasets:
         | 
| 413 | 
            +
                        dataset.set_epoch(epoch)
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                def __getitem__(self, idx):
         | 
| 416 | 
            +
                    """
         | 
| 417 | 
            +
                    Get an item from the concatenated dataset.
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    Args:
         | 
| 420 | 
            +
                        idx: Index or tuple of indices to retrieve
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                    Returns:
         | 
| 423 | 
            +
                        The item at the specified index from the appropriate constituent dataset
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                    Raises:
         | 
| 426 | 
            +
                        IndexError: If the index is out of range
         | 
| 427 | 
            +
                    """
         | 
| 428 | 
            +
                    other = None
         | 
| 429 | 
            +
                    if isinstance(idx, tuple):
         | 
| 430 | 
            +
                        other = idx[1:]
         | 
| 431 | 
            +
                        idx = idx[0]
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    if not (0 <= idx < len(self)):
         | 
| 434 | 
            +
                        raise IndexError()
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                    db_idx = np.searchsorted(self._cum_sizes, idx, "right")
         | 
| 437 | 
            +
                    dataset = self.datasets[db_idx]
         | 
| 438 | 
            +
                    new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                    if other is not None:
         | 
| 441 | 
            +
                        new_idx = (new_idx, *other)
         | 
| 442 | 
            +
                    return dataset[new_idx]
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                @property
         | 
| 445 | 
            +
                def _resolutions(self):
         | 
| 446 | 
            +
                    """
         | 
| 447 | 
            +
                    Get the resolutions of the dataset.
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    Returns:
         | 
| 450 | 
            +
                        The resolutions from the first dataset (all datasets must have the same resolutions)
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                    Raises:
         | 
| 453 | 
            +
                        AssertionError: If datasets have different resolutions
         | 
| 454 | 
            +
                    """
         | 
| 455 | 
            +
                    resolutions = self.datasets[0]._resolutions
         | 
| 456 | 
            +
                    for dataset in self.datasets[1:]:
         | 
| 457 | 
            +
                        assert tuple(dataset._resolutions) == tuple(resolutions), (
         | 
| 458 | 
            +
                            "All datasets must have the same resolutions"
         | 
| 459 | 
            +
                        )
         | 
| 460 | 
            +
                    return resolutions
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                @property
         | 
| 463 | 
            +
                def num_views(self):
         | 
| 464 | 
            +
                    """
         | 
| 465 | 
            +
                    Get the number of views used for the dataset.
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                    Returns:
         | 
| 468 | 
            +
                        int or list: The number of views parameter from the first dataset
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    Raises:
         | 
| 471 | 
            +
                        AssertionError: If datasets have different num_views
         | 
| 472 | 
            +
                    """
         | 
| 473 | 
            +
                    num_views = self.datasets[0].num_views
         | 
| 474 | 
            +
                    for dataset in self.datasets[1:]:
         | 
| 475 | 
            +
                        assert dataset.num_views == num_views, (
         | 
| 476 | 
            +
                            "All datasets must have the same num_views and variable_num_views parameters"
         | 
| 477 | 
            +
                        )
         | 
| 478 | 
            +
                    return num_views
         | 
    	
        mapanything/datasets/utils/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        mapanything/datasets/utils/data_splits.py
    ADDED
    
    | @@ -0,0 +1,1734 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Modules containing dataset split information
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class BlendedMVSSplits:
         | 
| 12 | 
            +
                """
         | 
| 13 | 
            +
                This class contains the information about the BlendedMVS dataset splits.
         | 
| 14 | 
            +
                """
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                def __init__(self):
         | 
| 17 | 
            +
                    """
         | 
| 18 | 
            +
                    The splits are generated using the following logic:
         | 
| 19 | 
            +
                        # Get all seqls and seqhs using self.blendedmvs_info.all_sequences
         | 
| 20 | 
            +
                        all_sequences = self.blendedmvs_info.all_sequences
         | 
| 21 | 
            +
                        all_seqls = [int(seq[8:], 16) for seq in all_sequences]
         | 
| 22 | 
            +
                        all_seqhs = [int(seq[:8], 16) for seq in all_sequences]
         | 
| 23 | 
            +
                        # Split the seqls (& corresponding seqhs) using the DUSt3R train/val split logic
         | 
| 24 | 
            +
                        if split is None:
         | 
| 25 | 
            +
                            selection = slice(None)
         | 
| 26 | 
            +
                        elif split in ["train", "overfit"]:
         | 
| 27 | 
            +
                            # select 90% of all scenes
         | 
| 28 | 
            +
                            selection = [(seql % 10) > 0 for seql in all_seqls]
         | 
| 29 | 
            +
                        elif split == "val":
         | 
| 30 | 
            +
                            # select 10% of all scenes
         | 
| 31 | 
            +
                            selection = [(seql % 10) == 0 for seql in all_seqls]
         | 
| 32 | 
            +
                        else:
         | 
| 33 | 
            +
                            raise ValueError(f"Unknown split {split}, must be None, train, val or overfit")
         | 
| 34 | 
            +
                        # Filter sequences based on the selection
         | 
| 35 | 
            +
                        selected_seqls = [seql for seql, sel in zip(all_seqls, selection) if sel]
         | 
| 36 | 
            +
                        selected_seqhs = [seqh for seqh, sel in zip(all_seqhs, selection) if sel]
         | 
| 37 | 
            +
                        # Put them back into sequence names f"{seqh:08x}{seql:016x}"
         | 
| 38 | 
            +
                        sequence_names = [f"{seqh:08x}{seql:016x}" for seqh, seql in zip(selected_seqhs, selected_seqls)]
         | 
| 39 | 
            +
                        # Remove invalid sequence names which don't exist in self.blendedmvs_info.sequences
         | 
| 40 | 
            +
                        valid_sequences = set(self.blendedmvs_info.sequences)
         | 
| 41 | 
            +
                        valid_sequence_names = [name for name in sequence_names if name in valid_sequences]
         | 
| 42 | 
            +
                    """
         | 
| 43 | 
            +
                    # All the 502 sequences in the dataset (totals to 115k images)
         | 
| 44 | 
            +
                    self.all_scenes = [
         | 
| 45 | 
            +
                        "000000000000000000000000",
         | 
| 46 | 
            +
                        "00000000000000000000000a",
         | 
| 47 | 
            +
                        "00000000000000000000000b",
         | 
| 48 | 
            +
                        "00000000000000000000000c",
         | 
| 49 | 
            +
                        "00000000000000000000000d",
         | 
| 50 | 
            +
                        "00000000000000000000000e",
         | 
| 51 | 
            +
                        "00000000000000000000000f",
         | 
| 52 | 
            +
                        "000000000000000000000001",
         | 
| 53 | 
            +
                        "00000000000000000000001a",
         | 
| 54 | 
            +
                        "00000000000000000000001b",
         | 
| 55 | 
            +
                        "00000000000000000000001d",
         | 
| 56 | 
            +
                        "000000000000000000000002",
         | 
| 57 | 
            +
                        "000000000000000000000003",
         | 
| 58 | 
            +
                        "000000000000000000000004",
         | 
| 59 | 
            +
                        "000000000000000000000005",
         | 
| 60 | 
            +
                        "5a2a95f032a1c655cfe3de62",
         | 
| 61 | 
            +
                        "5a2af22b32a1c655cfe46013",
         | 
| 62 | 
            +
                        "5a2ba6de32a1c655cfe51b79",
         | 
| 63 | 
            +
                        "5a3b9731e24cd76dad1a5f1b",
         | 
| 64 | 
            +
                        "5a3ca9cb270f0e3f14d0eddb",
         | 
| 65 | 
            +
                        "5a3cb4e4270f0e3f14d12f43",
         | 
| 66 | 
            +
                        "5a03e732454a8a7ec672776c",
         | 
| 67 | 
            +
                        "5a3f4aba5889373fbbc5d3b5",
         | 
| 68 | 
            +
                        "5a4a38dad38c8a075495b5d2",
         | 
| 69 | 
            +
                        "5a5a1e48d62c7a12d5d00e47",
         | 
| 70 | 
            +
                        "5a6b1c418d100c2f8fdc4411",
         | 
| 71 | 
            +
                        "5a6feeb54a7fbc3f874f9db7",
         | 
| 72 | 
            +
                        "5a7cb1d6fe5c0d6fb53e64fb",
         | 
| 73 | 
            +
                        "5a7d3db14989e929563eb153",
         | 
| 74 | 
            +
                        "5a8aa0fab18050187cbe060e",
         | 
| 75 | 
            +
                        "5a9e5df65baeef72b4a021cd",
         | 
| 76 | 
            +
                        "5a48ba95c7dab83a7d7b44ed",
         | 
| 77 | 
            +
                        "5a48c4e9c7dab83a7d7b5cc7",
         | 
| 78 | 
            +
                        "5a48d4b2c7dab83a7d7b9851",
         | 
| 79 | 
            +
                        "5a69c47d0d5d0a7f3b2e9752",
         | 
| 80 | 
            +
                        "5a77b46b318efe6c6736e68a",
         | 
| 81 | 
            +
                        "5a355c271b63f53d5970f362",
         | 
| 82 | 
            +
                        "5a489fb1c7dab83a7d7b1070",
         | 
| 83 | 
            +
                        "5a533e8034d7582116e34209",
         | 
| 84 | 
            +
                        "5a562fc7425d0f5186314725",
         | 
| 85 | 
            +
                        "5a572fd9fc597b0478a81d14",
         | 
| 86 | 
            +
                        "5a588a8193ac3d233f77fbca",
         | 
| 87 | 
            +
                        "5a618c72784780334bc1972d",
         | 
| 88 | 
            +
                        "5a752d42acc41e2423f17674",
         | 
| 89 | 
            +
                        "5a969eea91dfc339a9a3ad2c",
         | 
| 90 | 
            +
                        "5a8315f624b8e938486e0bd8",
         | 
| 91 | 
            +
                        "5a57542f333d180827dfc132",
         | 
| 92 | 
            +
                        "5a0271884e62597cdee0d0eb",
         | 
| 93 | 
            +
                        "5a6400933d809f1d8200af15",
         | 
| 94 | 
            +
                        "5a6464143d809f1d8208c43c",
         | 
| 95 | 
            +
                        "5a563183425d0f5186314855",
         | 
| 96 | 
            +
                        "5aa0f9d7a9efce63548c69a1",
         | 
| 97 | 
            +
                        "5aa0f478a9efce63548c1cb4",
         | 
| 98 | 
            +
                        "5aa7db90bfdd572271e95246",
         | 
| 99 | 
            +
                        "5aa235f64a17b335eeaf9609",
         | 
| 100 | 
            +
                        "5aa515e613d42d091d29d300",
         | 
| 101 | 
            +
                        "5aa1196ea9efce63548ed649",
         | 
| 102 | 
            +
                        "5aaadd4cbc13235570d178a7",
         | 
| 103 | 
            +
                        "5ab6af12ac4291329b1072ab",
         | 
| 104 | 
            +
                        "5ab7e00aac4291329b15864d",
         | 
| 105 | 
            +
                        "5ab8b8e029f5351f7f2ccf59",
         | 
| 106 | 
            +
                        "5ab74bf2ac4291329b11e879",
         | 
| 107 | 
            +
                        "5ab85f1dac4291329b17cb50",
         | 
| 108 | 
            +
                        "5ab8713ba3799a1d138bd69a",
         | 
| 109 | 
            +
                        "5abc2506b53b042ead637d86",
         | 
| 110 | 
            +
                        "5acc7459a7853c4b5ebbef59",
         | 
| 111 | 
            +
                        "5acf8ca0f3d8a750097e4b15",
         | 
| 112 | 
            +
                        "5adc6bd52430a05ecb2ffb85",
         | 
| 113 | 
            +
                        "5ae2e9c5fe405c5076abc6b2",
         | 
| 114 | 
            +
                        "5af02e904c8216544b4ab5a2",
         | 
| 115 | 
            +
                        "5af28cea59bc705737003253",
         | 
| 116 | 
            +
                        "5af545d0559359053d25dcf5",
         | 
| 117 | 
            +
                        "5afacb69ab00705d0cefdd5b",
         | 
| 118 | 
            +
                        "5b2c67b5e0878c381608b8d8",
         | 
| 119 | 
            +
                        "5b3b2b9e8d46a939f933fdc0",
         | 
| 120 | 
            +
                        "5b3b353d8d46a939f93524b9",
         | 
| 121 | 
            +
                        "5b6e716d67b396324c2d77cb",
         | 
| 122 | 
            +
                        "5b6eff8b67b396324c5b2672",
         | 
| 123 | 
            +
                        "5b7a3890fc8fcf6781e2593a",
         | 
| 124 | 
            +
                        "5b21e18c58e2823a67a10dd8",
         | 
| 125 | 
            +
                        "5b60fa0c764f146feef84df0",
         | 
| 126 | 
            +
                        "5b69cc0cb44b61786eb959bf",
         | 
| 127 | 
            +
                        "5b78e57afc8fcf6781d0c3ba",
         | 
| 128 | 
            +
                        "5b192eb2170cf166458ff886",
         | 
| 129 | 
            +
                        "5b558a928bbfb62204e77ba2",
         | 
| 130 | 
            +
                        "5b864d850d072a699b32f4ae",
         | 
| 131 | 
            +
                        "5b908d3dc6ab78485f3d24a9",
         | 
| 132 | 
            +
                        "5b950c71608de421b1e7318f",
         | 
| 133 | 
            +
                        "5b4933abf2b5f44e95de482a",
         | 
| 134 | 
            +
                        "5b08286b2775267d5b0634ba",
         | 
| 135 | 
            +
                        "5b37189a35304b6f75e7583e",
         | 
| 136 | 
            +
                        "5b271079e0878c3816dacca4",
         | 
| 137 | 
            +
                        "5b22269758e2823a67a3bd03",
         | 
| 138 | 
            +
                        "5b62647143840965efc0dbde",
         | 
| 139 | 
            +
                        "5ba19a8a360c7c30c1c169df",
         | 
| 140 | 
            +
                        "5ba75d79d76ffa2c86cf2f05",
         | 
| 141 | 
            +
                        "5bb7a08aea1cfa39f1a947ab",
         | 
| 142 | 
            +
                        "5bb8a49aea1cfa39f1aa7f75",
         | 
| 143 | 
            +
                        "5bbb6eb2ea1cfa39f1af7e0c",
         | 
| 144 | 
            +
                        "5bc5f0e896b66a2cd8f9bd36",
         | 
| 145 | 
            +
                        "5bccd6beca24970bce448134",
         | 
| 146 | 
            +
                        "5bce7ac9ca24970bce4934b6",
         | 
| 147 | 
            +
                        "5bcf979a6d5f586b95c258cd",
         | 
| 148 | 
            +
                        "5bd43b4ba6b28b1ee86b92dd",
         | 
| 149 | 
            +
                        "5be3a5fb8cfdd56947f6b67c",
         | 
| 150 | 
            +
                        "5be3ae47f44e235bdbbc9771",
         | 
| 151 | 
            +
                        "5be4ab93870d330ff2dce134",
         | 
| 152 | 
            +
                        "5be47bf9b18881428d8fbc1d",
         | 
| 153 | 
            +
                        "5be883a4f98cee15019d5b83",
         | 
| 154 | 
            +
                        "5bea87f4abd34c35e1860ab5",
         | 
| 155 | 
            +
                        "5beb6e66abd34c35e18e66b9",
         | 
| 156 | 
            +
                        "5bf3a82cd439231948877aed",
         | 
| 157 | 
            +
                        "5bf7d63575c26f32dbf7413b",
         | 
| 158 | 
            +
                        "5bf17c0fd439231948355385",
         | 
| 159 | 
            +
                        "5bf26cbbd43923194854b270",
         | 
| 160 | 
            +
                        "5bf03590d4392319481971dc",
         | 
| 161 | 
            +
                        "5bf18642c50e6f7f8bdbd492",
         | 
| 162 | 
            +
                        "5bf21799d43923194842c001",
         | 
| 163 | 
            +
                        "5bfc9d5aec61ca1dd69132a2",
         | 
| 164 | 
            +
                        "5bfd0f32ec61ca1dd69dc77b",
         | 
| 165 | 
            +
                        "5bfe5ae0fe0ea555e6a969ca",
         | 
| 166 | 
            +
                        "5bff3c5cfe0ea555e6bcbf3a",
         | 
| 167 | 
            +
                        "5c0d13b795da9479e12e2ee9",
         | 
| 168 | 
            +
                        "5c1af2e2bee9a723c963d019",
         | 
| 169 | 
            +
                        "5c1b1500bee9a723c96c3e78",
         | 
| 170 | 
            +
                        "5c1dbf200843bc542d8ef8c4",
         | 
| 171 | 
            +
                        "5c1f33f1d33e1f2e4aa6dda4",
         | 
| 172 | 
            +
                        "5c2b3ed5e611832e8aed46bf",
         | 
| 173 | 
            +
                        "5c20ca3a0843bc542d94e3e2",
         | 
| 174 | 
            +
                        "5c062d84a96e33018ff6f0a6",
         | 
| 175 | 
            +
                        "5c189f2326173c3a09ed7ef3",
         | 
| 176 | 
            +
                        "5c1892f726173c3a09ea9aeb",
         | 
| 177 | 
            +
                        "5c34300a73a8df509add216d",
         | 
| 178 | 
            +
                        "5c34529873a8df509ae57b58",
         | 
| 179 | 
            +
                        "000000000000000000000006",
         | 
| 180 | 
            +
                        "000000000000000000000007",
         | 
| 181 | 
            +
                        "000000000000000000000008",
         | 
| 182 | 
            +
                        "000000000000000000000009",
         | 
| 183 | 
            +
                        "000000000000000000000010",
         | 
| 184 | 
            +
                        "000000000000000000000011",
         | 
| 185 | 
            +
                        "000000000000000000000012",
         | 
| 186 | 
            +
                        "000000000000000000000015",
         | 
| 187 | 
            +
                        "000000000000000000000016",
         | 
| 188 | 
            +
                        "000000000000000000000017",
         | 
| 189 | 
            +
                        "000000000000000000000018",
         | 
| 190 | 
            +
                        "000000000000000000000019",
         | 
| 191 | 
            +
                        "56d73ba74bd29b8c35abade2",
         | 
| 192 | 
            +
                        "56f34064e296120e10484dc4",
         | 
| 193 | 
            +
                        "57a4a7bb6b9272286e26dc18",
         | 
| 194 | 
            +
                        "57f8d9bbe73f6760f10e916a",
         | 
| 195 | 
            +
                        "58a0a2f33d0b4542479a11b1",
         | 
| 196 | 
            +
                        "58a0dd1a3d0b4542479a28f3",
         | 
| 197 | 
            +
                        "58a1a7914a4d262a170b1101",
         | 
| 198 | 
            +
                        "58a1bc804a4d262a170b2f01",
         | 
| 199 | 
            +
                        "58a1d9d14a4d262a170b58fe",
         | 
| 200 | 
            +
                        "58a01dea38486e3c98475871",
         | 
| 201 | 
            +
                        "58a1f5d74a4d262a170b65fc",
         | 
| 202 | 
            +
                        "58a2a09e156b87103d3d668c",
         | 
| 203 | 
            +
                        "58a2d9c3156b87103d3da90f",
         | 
| 204 | 
            +
                        "58a3ccb0156b87103d3e4332",
         | 
| 205 | 
            +
                        "58a3f2f8156b87103d3e5838",
         | 
| 206 | 
            +
                        "58a3f6c0156b87103d3e5971",
         | 
| 207 | 
            +
                        "58a3fc95156b87103d3e5d9b",
         | 
| 208 | 
            +
                        "58a07ce53d0b45424799fdde",
         | 
| 209 | 
            +
                        "58a07f233d0b45424799ffe7",
         | 
| 210 | 
            +
                        "58a44df2156b87103d3ee239",
         | 
| 211 | 
            +
                        "58a164f73d0b4542479a7a8e",
         | 
| 212 | 
            +
                        "58a0365e38486e3c984783eb",
         | 
| 213 | 
            +
                        "58a439cf156b87103d3ec885",
         | 
| 214 | 
            +
                        "58a464aa156b87103d3eec04",
         | 
| 215 | 
            +
                        "58a4452f156b87103d3ed55b",
         | 
| 216 | 
            +
                        "58a160983d0b4542479a7347",
         | 
| 217 | 
            +
                        "58a186444a4d262a170ae3ae",
         | 
| 218 | 
            +
                        "58a285424a4d262a170baf3e",
         | 
| 219 | 
            +
                        "58a41819156b87103d3e92a5",
         | 
| 220 | 
            +
                        "58a44463156b87103d3ed45e",
         | 
| 221 | 
            +
                        "58a47552156b87103d3f00a4",
         | 
| 222 | 
            +
                        "58c4bb4f4a69c55606122be4",
         | 
| 223 | 
            +
                        "58c6451e4a69c556061894f1",
         | 
| 224 | 
            +
                        "58ca7014affdfd07c70a95ce",
         | 
| 225 | 
            +
                        "58cf4771d0f5fb221defe6da",
         | 
| 226 | 
            +
                        "58d36897f387231e6c929903",
         | 
| 227 | 
            +
                        "58eaf1513353456af3a1682a",
         | 
| 228 | 
            +
                        "58f7f7299f5b5647873cb110",
         | 
| 229 | 
            +
                        "58f73e7c9f5b56478738929f",
         | 
| 230 | 
            +
                        "59a8f851597729752c31e7e0",
         | 
| 231 | 
            +
                        "59a452bf9b460239aa5d1c72",
         | 
| 232 | 
            +
                        "59a9619a825418241fb88191",
         | 
| 233 | 
            +
                        "59acd2f4b891807f439c8992",
         | 
| 234 | 
            +
                        "59bf97fe7e7b31545da34439",
         | 
| 235 | 
            +
                        "59c1c3e2fd6e3d4ead9f1013",
         | 
| 236 | 
            +
                        "59d2657f82ca7774b1ec081d",
         | 
| 237 | 
            +
                        "59da1fb88a126011d0394ae9",
         | 
| 238 | 
            +
                        "59e75a2ca9e91f2c5526005d",
         | 
| 239 | 
            +
                        "59e864b2a9e91f2c5529325f",
         | 
| 240 | 
            +
                        "59ecfd02e225f6492d20fcc9",
         | 
| 241 | 
            +
                        "59f37f74b45be2233001ba18",
         | 
| 242 | 
            +
                        "59f70ab1e5c5d366af29bf3e",
         | 
| 243 | 
            +
                        "59f87d0bfa6280566fb38c9a",
         | 
| 244 | 
            +
                        "59f363a8b45be22330016cad",
         | 
| 245 | 
            +
                        "564a27b26d07883f460d8ab0",
         | 
| 246 | 
            +
                        "565fb1dead14d4154dae2b94",
         | 
| 247 | 
            +
                        "567a0fb0a825d2fb79ac9a20",
         | 
| 248 | 
            +
                        "569b92eb826bcba945ca002b",
         | 
| 249 | 
            +
                        "576fefa017ce5a16397e87fd",
         | 
| 250 | 
            +
                        "584a7333fe3cb463906c9fe6",
         | 
| 251 | 
            +
                        "584aa8e9fe3cb463906cc7d0",
         | 
| 252 | 
            +
                        "584ad76bfe3cb463906ce6dc",
         | 
| 253 | 
            +
                        "584af003fe3cb463906d0e9b",
         | 
| 254 | 
            +
                        "584b9a747072670e72bfc49d",
         | 
| 255 | 
            +
                        "584b671f7072670e72bfaaf8",
         | 
| 256 | 
            +
                        "584b81747072670e72bfbbfd",
         | 
| 257 | 
            +
                        "584ba35f7072670e72bfca4d",
         | 
| 258 | 
            +
                        "584ba5977072670e72bfcc2d",
         | 
| 259 | 
            +
                        "584bc53c7072670e72bfe85f",
         | 
| 260 | 
            +
                        "584bc3997072670e72bfe58d",
         | 
| 261 | 
            +
                        "584bc4407072670e72bfe665",
         | 
| 262 | 
            +
                        "584bd5587072670e72bffe39",
         | 
| 263 | 
            +
                        "584bdadf7072670e72c0005c",
         | 
| 264 | 
            +
                        "584be5ed7072670e72c007b3",
         | 
| 265 | 
            +
                        "584c9ad27072670e72c060c5",
         | 
| 266 | 
            +
                        "584c9cc67072670e72c063a1",
         | 
| 267 | 
            +
                        "584c58b77072670e72c03990",
         | 
| 268 | 
            +
                        "584cea557072670e72c07fb4",
         | 
| 269 | 
            +
                        "584d19d47072670e72c0c6c0",
         | 
| 270 | 
            +
                        "584dfe467072670e72c1665a",
         | 
| 271 | 
            +
                        "584e875c7072670e72c1ec94",
         | 
| 272 | 
            +
                        "584e05667072670e72c17167",
         | 
| 273 | 
            +
                        "584f94e87072670e72c2d3f7",
         | 
| 274 | 
            +
                        "584fdffd7072670e72c32dc7",
         | 
| 275 | 
            +
                        "584fe07f7072670e72c32e59",
         | 
| 276 | 
            +
                        "585a2a71b338a62ad50138dc",
         | 
| 277 | 
            +
                        "585a206ab338a62ad501298f",
         | 
| 278 | 
            +
                        "585a217cb338a62ad5012b38",
         | 
| 279 | 
            +
                        "585b34afb338a62ad501e836",
         | 
| 280 | 
            +
                        "585bb25fc49c8507c3ce7812",
         | 
| 281 | 
            +
                        "585bbe55c49c8507c3ce81cd",
         | 
| 282 | 
            +
                        "585d6c8a2a57cc11d4920a1e",
         | 
| 283 | 
            +
                        "585e54c72a57cc11d492f71a",
         | 
| 284 | 
            +
                        "585e34302a57cc11d492be30",
         | 
| 285 | 
            +
                        "585ee0632a57cc11d4933608",
         | 
| 286 | 
            +
                        "585f9661712e2761468dabca",
         | 
| 287 | 
            +
                        "585ffe9a712e2761468df643",
         | 
| 288 | 
            +
                        "586a37ec9d1b5e34c28184fc",
         | 
| 289 | 
            +
                        "586a515a9d1b5e34c281b431",
         | 
| 290 | 
            +
                        "586a94939d1b5e34c2823b5d",
         | 
| 291 | 
            +
                        "586abc689d1b5e34c2826360",
         | 
| 292 | 
            +
                        "586b0e219d1b5e34c2828862",
         | 
| 293 | 
            +
                        "586b3db89d1b5e34c282cd52",
         | 
| 294 | 
            +
                        "586b4c459d1b5e34c282e66d",
         | 
| 295 | 
            +
                        "586b7d7d9d1b5e34c283359e",
         | 
| 296 | 
            +
                        "586b8f149d1b5e34c283497c",
         | 
| 297 | 
            +
                        "586b8f629d1b5e34c28349d6",
         | 
| 298 | 
            +
                        "586c4c4d9d1b5e34c28391a1",
         | 
| 299 | 
            +
                        "586c5b5b9d1b5e34c2839a5b",
         | 
| 300 | 
            +
                        "586c9fdf9d1b5e34c283b657",
         | 
| 301 | 
            +
                        "586c48329d1b5e34c2838e80",
         | 
| 302 | 
            +
                        "586caab99d1b5e34c283c213",
         | 
| 303 | 
            +
                        "586cd0779d1b5e34c28403a7",
         | 
| 304 | 
            +
                        "586d6d249d1b5e34c284b80e",
         | 
| 305 | 
            +
                        "586d8a029d1b5e34c284c948",
         | 
| 306 | 
            +
                        "586d55af9d1b5e34c284a999",
         | 
| 307 | 
            +
                        "586d07869d1b5e34c2842e5b",
         | 
| 308 | 
            +
                        "586d27489d1b5e34c28453af",
         | 
| 309 | 
            +
                        "586df9849d1b5e34c28506de",
         | 
| 310 | 
            +
                        "586e279c9d1b5e34c2852180",
         | 
| 311 | 
            +
                        "587bc5ec2366dd5d06e262c1",
         | 
| 312 | 
            +
                        "587c1abf2366dd5d06e28901",
         | 
| 313 | 
            +
                        "587c03f12366dd5d06e27722",
         | 
| 314 | 
            +
                        "587c19da2366dd5d06e2877b",
         | 
| 315 | 
            +
                        "587c31b92366dd5d06e2a9dc",
         | 
| 316 | 
            +
                        "587c87d02366dd5d06e2f989",
         | 
| 317 | 
            +
                        "587c97a52366dd5d06e30a96",
         | 
| 318 | 
            +
                        "587c45192366dd5d06e2c0eb",
         | 
| 319 | 
            +
                        "587cec702366dd5d06e37862",
         | 
| 320 | 
            +
                        "587cef0a2366dd5d06e379e3",
         | 
| 321 | 
            +
                        "587db5872366dd5d06e3e0af",
         | 
| 322 | 
            +
                        "587e2b1d2366dd5d06e41af0",
         | 
| 323 | 
            +
                        "587e2ea62366dd5d06e41f2e",
         | 
| 324 | 
            +
                        "587e5cb52366dd5d06e4486e",
         | 
| 325 | 
            +
                        "587eb1822366dd5d06e45f29",
         | 
| 326 | 
            +
                        "587f365d2366dd5d06e4906e",
         | 
| 327 | 
            +
                        "588a9c5fec4d5a1c088ec350",
         | 
| 328 | 
            +
                        "588a34cfec4d5a1c088ea8d1",
         | 
| 329 | 
            +
                        "588ab5bdec4d5a1c088ed60f",
         | 
| 330 | 
            +
                        "588aff9d90414422fbe7885a",
         | 
| 331 | 
            +
                        "588b20d290414422fbe79f40",
         | 
| 332 | 
            +
                        "588c08d590414422fbe8200b",
         | 
| 333 | 
            +
                        "588c203d90414422fbe8319e",
         | 
| 334 | 
            +
                        "588c989a90414422fbe86d96",
         | 
| 335 | 
            +
                        "588ca09d90414422fbe871a1",
         | 
| 336 | 
            +
                        "588cce2190414422fbe88520",
         | 
| 337 | 
            +
                        "588cd5ef90414422fbe8875c",
         | 
| 338 | 
            +
                        "588cf0ad90414422fbe8a20f",
         | 
| 339 | 
            +
                        "588e0d8c90414422fbe8f8b2",
         | 
| 340 | 
            +
                        "588e01c490414422fbe8ee2a",
         | 
| 341 | 
            +
                        "588e35e690414422fbe90a53",
         | 
| 342 | 
            +
                        "588f017e90414422fbe9b74b",
         | 
| 343 | 
            +
                        "588f095190414422fbe9c1ee",
         | 
| 344 | 
            +
                        "589aca717dc3d323d55671c4",
         | 
| 345 | 
            +
                        "589af2c97dc3d323d55691e8",
         | 
| 346 | 
            +
                        "589b49ea7dc3d323d556d9b4",
         | 
| 347 | 
            +
                        "589b04287dc3d323d556a185",
         | 
| 348 | 
            +
                        "589bf6a57dc3d323d55743ab",
         | 
| 349 | 
            +
                        "589c3c497dc3d323d5578468",
         | 
| 350 | 
            +
                        "589c3c577dc3d323d5578480",
         | 
| 351 | 
            +
                        "589c300f7dc3d323d5577926",
         | 
| 352 | 
            +
                        "589c24527dc3d323d5577126",
         | 
| 353 | 
            +
                        "589c35457dc3d323d5577d8d",
         | 
| 354 | 
            +
                        "589ca6a6b896147a1b73aff7",
         | 
| 355 | 
            +
                        "589d1e1fb896147a1b73ee5b",
         | 
| 356 | 
            +
                        "589d5c58b896147a1b742256",
         | 
| 357 | 
            +
                        "589d95538fa2cf375df3317b",
         | 
| 358 | 
            +
                        "589df0ffb504a864ad63521a",
         | 
| 359 | 
            +
                        "589ea316b504a864ad639a2b",
         | 
| 360 | 
            +
                        "589ec97cb504a864ad63adc3",
         | 
| 361 | 
            +
                        "589f214338486e3c9846f123",
         | 
| 362 | 
            +
                        "589fdfe738486e3c984736cf",
         | 
| 363 | 
            +
                        "590c2d70336bb52a190be886",
         | 
| 364 | 
            +
                        "590f91851225725be9e25d4e",
         | 
| 365 | 
            +
                        "591a467a6109e14d4f09b776",
         | 
| 366 | 
            +
                        "591cf3033162411cf9047f37",
         | 
| 367 | 
            +
                        "591ea44850991c70dc99a207",
         | 
| 368 | 
            +
                        "599aa591d5b41f366fed0d58",
         | 
| 369 | 
            +
                        "5643df56138263b51db1b5f3",
         | 
| 370 | 
            +
                        "5644bdac138263b51db9f669",
         | 
| 371 | 
            +
                        "5692a4c2adafac1f14201821",
         | 
| 372 | 
            +
                        "5850d4f97072670e72c425d6",
         | 
| 373 | 
            +
                        "5854c405804be105852330fe",
         | 
| 374 | 
            +
                        "5855a4fc804be1058523bd75",
         | 
| 375 | 
            +
                        "5856ac15804be105852419d8",
         | 
| 376 | 
            +
                        "5856ae8b804be10585241bae",
         | 
| 377 | 
            +
                        "5856b460804be10585242059",
         | 
| 378 | 
            +
                        "5857aa5ab338a62ad5ff4dbe",
         | 
| 379 | 
            +
                        "5857acf8b338a62ad5ff5107",
         | 
| 380 | 
            +
                        "5858db6cb338a62ad500103b",
         | 
| 381 | 
            +
                        "5858dbcab338a62ad5001081",
         | 
| 382 | 
            +
                        "5859d84fb338a62ad500e5cf",
         | 
| 383 | 
            +
                        "5861d8ea712e2761468f3cb3",
         | 
| 384 | 
            +
                        "5863edf8712e27614690cce0",
         | 
| 385 | 
            +
                        "5864a935712e2761469111b4",
         | 
| 386 | 
            +
                        "5864b076712e27614691197e",
         | 
| 387 | 
            +
                        "5864da88712e276146913d8b",
         | 
| 388 | 
            +
                        "5865f4a8712e27614691e39b",
         | 
| 389 | 
            +
                        "5867a434833dfe3f7b88edaf",
         | 
| 390 | 
            +
                        "5868cd15833dfe3f7b89bfa3",
         | 
| 391 | 
            +
                        "5880b3692366dd5d06e5d534",
         | 
| 392 | 
            +
                        "5880e3422366dd5d06e5ff8e",
         | 
| 393 | 
            +
                        "5880f0ef2366dd5d06e6166e",
         | 
| 394 | 
            +
                        "5881d2bfb6844814c136a119",
         | 
| 395 | 
            +
                        "5881f11d8ce2c2754d0714c3",
         | 
| 396 | 
            +
                        "5881fee18ce2c2754d0723f8",
         | 
| 397 | 
            +
                        "5882cda2b116682b4adebd25",
         | 
| 398 | 
            +
                        "5882d58fb116682b4adec7db",
         | 
| 399 | 
            +
                        "5884c256932ba84fbed70bf5",
         | 
| 400 | 
            +
                        "5884cc13932ba84fbed71ec4",
         | 
| 401 | 
            +
                        "5885bc5296fa095e0671a7f0",
         | 
| 402 | 
            +
                        "5886d14cb791366d617a362c",
         | 
| 403 | 
            +
                        "5888becfc02346100f4b0b21",
         | 
| 404 | 
            +
                        "5888e408c02346100f4b1a29",
         | 
| 405 | 
            +
                        "5889da66ec4d5a1c088e5187",
         | 
| 406 | 
            +
                        "5889e344ec4d5a1c088e59be",
         | 
| 407 | 
            +
                        "5889e754ec4d5a1c088e60ba",
         | 
| 408 | 
            +
                        "5890c16b90414422fbeb0262",
         | 
| 409 | 
            +
                        "5891d8ae9a8c0314c5cd30ab",
         | 
| 410 | 
            +
                        "5891d0479a8c0314c5cd2abd",
         | 
| 411 | 
            +
                        "5891ecf19a8c0314c5cd490a",
         | 
| 412 | 
            +
                        "5892c0cd9a8c0314c5cdc977",
         | 
| 413 | 
            +
                        "5894ab309a8c0314c5cee57d",
         | 
| 414 | 
            +
                        "5895a6a89a8c0314c5cfca7c",
         | 
| 415 | 
            +
                        "5895b8c29a8c0314c5cfd051",
         | 
| 416 | 
            +
                        "5895d38f9a8c0314c5cfe50c",
         | 
| 417 | 
            +
                        "5895f2329a8c0314c5d00117",
         | 
| 418 | 
            +
                        "5896bb989a8c0314c5d086b6",
         | 
| 419 | 
            +
                        "5896ebf39a8c0314c5d0a8c4",
         | 
| 420 | 
            +
                        "5898b1bac9dccc22987b7f74",
         | 
| 421 | 
            +
                        "5898b6ffc9dccc22987b8a03",
         | 
| 422 | 
            +
                        "5898b31cc9dccc22987b82ec",
         | 
| 423 | 
            +
                        "5898bbaac9dccc22987b8eba",
         | 
| 424 | 
            +
                        "5899cfa6b76d7a3780a4cb64",
         | 
| 425 | 
            +
                        "5899e5dcb76d7a3780a4ecc1",
         | 
| 426 | 
            +
                        "5947b62af1b45630bd0c2a02",
         | 
| 427 | 
            +
                        "57102be2877e1421026358af",
         | 
| 428 | 
            +
                        "57153d4031bb9900425bde85",
         | 
| 429 | 
            +
                        "57177cd7fb8d93461afc4527",
         | 
| 430 | 
            +
                        "58497cdf97b73e0b090c4273",
         | 
| 431 | 
            +
                        "58500b007072670e72c35588",
         | 
| 432 | 
            +
                        "58510bf97072670e72c46ddf",
         | 
| 433 | 
            +
                        "58522bd56789802282f2ecb3",
         | 
| 434 | 
            +
                        "58524a2e0e7012308944bcf3",
         | 
| 435 | 
            +
                        "58524a080e7012308944bcbf",
         | 
| 436 | 
            +
                        "58524c1d0e7012308944bfda",
         | 
| 437 | 
            +
                        "58524f170e7012308944c200",
         | 
| 438 | 
            +
                        "58529a4e0e70123089454c6f",
         | 
| 439 | 
            +
                        "58551bdf804be1058523556d",
         | 
| 440 | 
            +
                        "58568c9a804be10585240b03",
         | 
| 441 | 
            +
                        "58574b35804be105852455fd",
         | 
| 442 | 
            +
                        "58577c60b338a62ad5ff1564",
         | 
| 443 | 
            +
                        "58592d69b338a62ad5007a74",
         | 
| 444 | 
            +
                        "58598db2b338a62ad500bc38",
         | 
| 445 | 
            +
                        "58625f42712e2761468fb44c",
         | 
| 446 | 
            +
                        "58651bcc712e2761469166dc",
         | 
| 447 | 
            +
                        "58660e79712e27614691fe3d",
         | 
| 448 | 
            +
                        "58669aad712e27614692834c",
         | 
| 449 | 
            +
                        "58669c02712e27614692851a",
         | 
| 450 | 
            +
                        "58676c36833dfe3f7b88b7f2",
         | 
| 451 | 
            +
                        "58678b2d833dfe3f7b88e244",
         | 
| 452 | 
            +
                        "58790c82ce911104a3467c88",
         | 
| 453 | 
            +
                        "58800b0b2366dd5d06e5312d",
         | 
| 454 | 
            +
                        "58805eac2366dd5d06e56460",
         | 
| 455 | 
            +
                        "58806e422366dd5d06e57bb6",
         | 
| 456 | 
            +
                        "58831d060db9bf59bf8ab98b",
         | 
| 457 | 
            +
                        "58851ebb932ba84fbed7abad",
         | 
| 458 | 
            +
                        "58871dc3b791366d617a55ff",
         | 
| 459 | 
            +
                        "58873cabb791366d617a65a7",
         | 
| 460 | 
            +
                        "58873d44b791366d617a65dd",
         | 
| 461 | 
            +
                        "58888b3dc02346100f4af665",
         | 
| 462 | 
            +
                        "58897f62c02346100f4b8ee6",
         | 
| 463 | 
            +
                        "58933bac9a8c0314c5ce3508",
         | 
| 464 | 
            +
                        "58938e6d9a8c0314c5ce726f",
         | 
| 465 | 
            +
                        "58951cb49a8c0314c5cf4d5e",
         | 
| 466 | 
            +
                        "58970fd09a8c0314c5d0e383",
         | 
| 467 | 
            +
                        "58977ef09a8c0314c5d17b26",
         | 
| 468 | 
            +
                        "59056e6760bb961de55f3501",
         | 
| 469 | 
            +
                        "59071f2e5a6dbd3af4130f98",
         | 
| 470 | 
            +
                        "59102c811225725be9e64149",
         | 
| 471 | 
            +
                        "59338e76772c3e6384afbb15",
         | 
| 472 | 
            +
                        "59350ca084b7f26bf5ce6eb8",
         | 
| 473 | 
            +
                        "59397e493a87372f2c9e882b",
         | 
| 474 | 
            +
                        "59521e0b9096412211c2aa9d",
         | 
| 475 | 
            +
                        "59817e4a1bd4b175e7038d19",
         | 
| 476 | 
            +
                        "567884f58d2828b95e3c8eba",
         | 
| 477 | 
            +
                        "585559d9804be10585238ddf",
         | 
| 478 | 
            +
                        "585834cdb338a62ad5ffab4d",
         | 
| 479 | 
            +
                        "586082d8712e2761468e2877",
         | 
| 480 | 
            +
                        "586133c2712e2761468ecfe3",
         | 
| 481 | 
            +
                        "586281d2712e2761468fcaa2",
         | 
| 482 | 
            +
                        "586316e5712e276146903c4d",
         | 
| 483 | 
            +
                        "586326ad712e276146904571",
         | 
| 484 | 
            +
                        "586375c9712e276146907429",
         | 
| 485 | 
            +
                        "586389c9712e276146908da6",
         | 
| 486 | 
            +
                        "586496fa712e2761469108e7",
         | 
| 487 | 
            +
                        "586669c6712e27614692597a",
         | 
| 488 | 
            +
                        "586913a49d1b5e34c2808b02",
         | 
| 489 | 
            +
                        "586922da9d1b5e34c2809ff3",
         | 
| 490 | 
            +
                        "588185d8dfb7a15588a114a3",
         | 
| 491 | 
            +
                        "588305ed0db9bf59bf8a8c80",
         | 
| 492 | 
            +
                        "588315c60db9bf59bf8aa928",
         | 
| 493 | 
            +
                        "588332ee0db9bf59bf8ae9c3",
         | 
| 494 | 
            +
                        "588457b8932ba84fbed69942",
         | 
| 495 | 
            +
                        "588519d5932ba84fbed7a04a",
         | 
| 496 | 
            +
                        "588824d1b791366d617adeef",
         | 
| 497 | 
            +
                        "588857f6c02346100f4ac09f",
         | 
| 498 | 
            +
                        "589145ef90414422fbeb2e08",
         | 
| 499 | 
            +
                        "589433fa9a8c0314c5ce9656",
         | 
| 500 | 
            +
                        "589765d39a8c0314c5d16b12",
         | 
| 501 | 
            +
                        "5851165f7072670e72c4860d",
         | 
| 502 | 
            +
                        "5859341ab338a62ad500848d",
         | 
| 503 | 
            +
                        "5862388b712e2761468f84aa",
         | 
| 504 | 
            +
                        "5863915b712e276146909135",
         | 
| 505 | 
            +
                        "5866445b712e27614692383e",
         | 
| 506 | 
            +
                        "5866500d712e2761469240fd",
         | 
| 507 | 
            +
                        "5867785a833dfe3f7b88c764",
         | 
| 508 | 
            +
                        "5867969c833dfe3f7b88e8bc",
         | 
| 509 | 
            +
                        "5868040c833dfe3f7b8934f7",
         | 
| 510 | 
            +
                        "5880675a2366dd5d06e570ca",
         | 
| 511 | 
            +
                        "5882372c8ce2c2754d076af0",
         | 
| 512 | 
            +
                        "5883535e932ba84fbed5ad07",
         | 
| 513 | 
            +
                        "5888358cb791366d617af69d",
         | 
| 514 | 
            +
                        "5890330d90414422fbeaa0cb",
         | 
| 515 | 
            +
                        "5897076e9a8c0314c5d0d31b",
         | 
| 516 | 
            +
                        "5940564ec2d9527ab869f7e2",
         | 
| 517 | 
            +
                        "5947719bf1b45630bd096665",
         | 
| 518 | 
            +
                        "5948194ff1b45630bd0f47e3",
         | 
| 519 | 
            +
                        "5950206a41b158666ac50506",
         | 
| 520 | 
            +
                        "5983012d1bd4b175e70c985a",
         | 
| 521 | 
            +
                        "58586810b338a62ad5ffc20c",
         | 
| 522 | 
            +
                        "58592046b338a62ad5006b33",
         | 
| 523 | 
            +
                        "58592854b338a62ad500750a",
         | 
| 524 | 
            +
                        "58596531b338a62ad500aace",
         | 
| 525 | 
            +
                        "58818685dfb7a15588a11626",
         | 
| 526 | 
            +
                        "58829563f42b1d3ee3ec835f",
         | 
| 527 | 
            +
                        "58894345c02346100f4b51ca",
         | 
| 528 | 
            +
                        "585289980e7012308945276a",
         | 
| 529 | 
            +
                        "585369770e7012308945c709",
         | 
| 530 | 
            +
                        "585373640e7012308945cab9",
         | 
| 531 | 
            +
                        "588230658ce2c2754d076728",
         | 
| 532 | 
            +
                        "589388059a8c0314c5ce718b",
         | 
| 533 | 
            +
                        "595979485ec6a95e86a58c8d",
         | 
| 534 | 
            +
                        "5841206219d291325678ca90",
         | 
| 535 | 
            +
                        "58563650804be1058523da55",
         | 
| 536 | 
            +
                        "58564084804be1058523e116",
         | 
| 537 | 
            +
                        "58636467712e27614690661f",
         | 
| 538 | 
            +
                        "58647495712e27614690f36d",
         | 
| 539 | 
            +
                        "58654563712e276146918643",
         | 
| 540 | 
            +
                        "58664251712e276146923738",
         | 
| 541 | 
            +
                        "588084032366dd5d06e59e82",
         | 
| 542 | 
            +
                        "588159582366dd5d06e66877",
         | 
| 543 | 
            +
                        "5890279190414422fbea9734",
         | 
| 544 | 
            +
                        "5890523090414422fbeab3f0",
         | 
| 545 | 
            +
                        "5890641690414422fbeabbe7",
         | 
| 546 | 
            +
                        "585203546789802282f2aaf5",
         | 
| 547 | 
            +
                    ]
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    # Final sequences to be used after filtering (some of the sequences have incorrect/low quality depth)
         | 
| 550 | 
            +
                    # Generally water bodies like lakes have incorrect depth
         | 
| 551 | 
            +
                    # Filtered out sequences:
         | 
| 552 | 
            +
                    # "5692a4c2adafac1f14201821" # Incorrect Depth
         | 
| 553 | 
            +
                    # "5864a935712e2761469111b4" # Noisy Depth and artifacts near horizon
         | 
| 554 | 
            +
                    # "59f87d0bfa6280566fb38c9a" # Object-centric, noise with background and sometimes in front of object
         | 
| 555 | 
            +
                    # "58a44463156b87103d3ed45e" # Very noisy depth in background
         | 
| 556 | 
            +
                    # "5c2b3ed5e611832e8aed46bf" # Depth occluded by artifacts
         | 
| 557 | 
            +
                    # "5bf03590d4392319481971dc" # Depth occluded by artifacts
         | 
| 558 | 
            +
                    # "00000000000000000000001a" # Largely incomplete depth
         | 
| 559 | 
            +
                    # "00000000000000000000000c" # Imprecise depth for buildings
         | 
| 560 | 
            +
                    # "000000000000000000000000" # Incorrect depth for planar terrain
         | 
| 561 | 
            +
                    self.scenes = [
         | 
| 562 | 
            +
                        "00000000000000000000000a",
         | 
| 563 | 
            +
                        "00000000000000000000000b",
         | 
| 564 | 
            +
                        "00000000000000000000000d",
         | 
| 565 | 
            +
                        "00000000000000000000000e",
         | 
| 566 | 
            +
                        "00000000000000000000000f",
         | 
| 567 | 
            +
                        "000000000000000000000001",
         | 
| 568 | 
            +
                        "00000000000000000000001b",
         | 
| 569 | 
            +
                        "00000000000000000000001d",
         | 
| 570 | 
            +
                        "000000000000000000000002",
         | 
| 571 | 
            +
                        "000000000000000000000003",
         | 
| 572 | 
            +
                        "000000000000000000000004",
         | 
| 573 | 
            +
                        "000000000000000000000005",
         | 
| 574 | 
            +
                        "5a2a95f032a1c655cfe3de62",
         | 
| 575 | 
            +
                        "5a2af22b32a1c655cfe46013",
         | 
| 576 | 
            +
                        "5a2ba6de32a1c655cfe51b79",
         | 
| 577 | 
            +
                        "5a3b9731e24cd76dad1a5f1b",
         | 
| 578 | 
            +
                        "5a3ca9cb270f0e3f14d0eddb",
         | 
| 579 | 
            +
                        "5a3cb4e4270f0e3f14d12f43",
         | 
| 580 | 
            +
                        "5a03e732454a8a7ec672776c",
         | 
| 581 | 
            +
                        "5a3f4aba5889373fbbc5d3b5",
         | 
| 582 | 
            +
                        "5a4a38dad38c8a075495b5d2",
         | 
| 583 | 
            +
                        "5a5a1e48d62c7a12d5d00e47",
         | 
| 584 | 
            +
                        "5a6b1c418d100c2f8fdc4411",
         | 
| 585 | 
            +
                        "5a6feeb54a7fbc3f874f9db7",
         | 
| 586 | 
            +
                        "5a7cb1d6fe5c0d6fb53e64fb",
         | 
| 587 | 
            +
                        "5a7d3db14989e929563eb153",
         | 
| 588 | 
            +
                        "5a8aa0fab18050187cbe060e",
         | 
| 589 | 
            +
                        "5a9e5df65baeef72b4a021cd",
         | 
| 590 | 
            +
                        "5a48ba95c7dab83a7d7b44ed",
         | 
| 591 | 
            +
                        "5a48c4e9c7dab83a7d7b5cc7",
         | 
| 592 | 
            +
                        "5a48d4b2c7dab83a7d7b9851",
         | 
| 593 | 
            +
                        "5a69c47d0d5d0a7f3b2e9752",
         | 
| 594 | 
            +
                        "5a77b46b318efe6c6736e68a",
         | 
| 595 | 
            +
                        "5a355c271b63f53d5970f362",
         | 
| 596 | 
            +
                        "5a489fb1c7dab83a7d7b1070",
         | 
| 597 | 
            +
                        "5a533e8034d7582116e34209",
         | 
| 598 | 
            +
                        "5a562fc7425d0f5186314725",
         | 
| 599 | 
            +
                        "5a572fd9fc597b0478a81d14",
         | 
| 600 | 
            +
                        "5a588a8193ac3d233f77fbca",
         | 
| 601 | 
            +
                        "5a618c72784780334bc1972d",
         | 
| 602 | 
            +
                        "5a752d42acc41e2423f17674",
         | 
| 603 | 
            +
                        "5a969eea91dfc339a9a3ad2c",
         | 
| 604 | 
            +
                        "5a8315f624b8e938486e0bd8",
         | 
| 605 | 
            +
                        "5a57542f333d180827dfc132",
         | 
| 606 | 
            +
                        "5a0271884e62597cdee0d0eb",
         | 
| 607 | 
            +
                        "5a6400933d809f1d8200af15",
         | 
| 608 | 
            +
                        "5a6464143d809f1d8208c43c",
         | 
| 609 | 
            +
                        "5a563183425d0f5186314855",
         | 
| 610 | 
            +
                        "5aa0f9d7a9efce63548c69a1",
         | 
| 611 | 
            +
                        "5aa0f478a9efce63548c1cb4",
         | 
| 612 | 
            +
                        "5aa7db90bfdd572271e95246",
         | 
| 613 | 
            +
                        "5aa235f64a17b335eeaf9609",
         | 
| 614 | 
            +
                        "5aa515e613d42d091d29d300",
         | 
| 615 | 
            +
                        "5aa1196ea9efce63548ed649",
         | 
| 616 | 
            +
                        "5aaadd4cbc13235570d178a7",
         | 
| 617 | 
            +
                        "5ab6af12ac4291329b1072ab",
         | 
| 618 | 
            +
                        "5ab7e00aac4291329b15864d",
         | 
| 619 | 
            +
                        "5ab8b8e029f5351f7f2ccf59",
         | 
| 620 | 
            +
                        "5ab74bf2ac4291329b11e879",
         | 
| 621 | 
            +
                        "5ab85f1dac4291329b17cb50",
         | 
| 622 | 
            +
                        "5ab8713ba3799a1d138bd69a",
         | 
| 623 | 
            +
                        "5abc2506b53b042ead637d86",
         | 
| 624 | 
            +
                        "5acc7459a7853c4b5ebbef59",
         | 
| 625 | 
            +
                        "5acf8ca0f3d8a750097e4b15",
         | 
| 626 | 
            +
                        "5adc6bd52430a05ecb2ffb85",
         | 
| 627 | 
            +
                        "5ae2e9c5fe405c5076abc6b2",
         | 
| 628 | 
            +
                        "5af02e904c8216544b4ab5a2",
         | 
| 629 | 
            +
                        "5af28cea59bc705737003253",
         | 
| 630 | 
            +
                        "5af545d0559359053d25dcf5",
         | 
| 631 | 
            +
                        "5afacb69ab00705d0cefdd5b",
         | 
| 632 | 
            +
                        "5b2c67b5e0878c381608b8d8",
         | 
| 633 | 
            +
                        "5b3b2b9e8d46a939f933fdc0",
         | 
| 634 | 
            +
                        "5b3b353d8d46a939f93524b9",
         | 
| 635 | 
            +
                        "5b6e716d67b396324c2d77cb",
         | 
| 636 | 
            +
                        "5b6eff8b67b396324c5b2672",
         | 
| 637 | 
            +
                        "5b7a3890fc8fcf6781e2593a",
         | 
| 638 | 
            +
                        "5b21e18c58e2823a67a10dd8",
         | 
| 639 | 
            +
                        "5b60fa0c764f146feef84df0",
         | 
| 640 | 
            +
                        "5b69cc0cb44b61786eb959bf",
         | 
| 641 | 
            +
                        "5b78e57afc8fcf6781d0c3ba",
         | 
| 642 | 
            +
                        "5b192eb2170cf166458ff886",
         | 
| 643 | 
            +
                        "5b558a928bbfb62204e77ba2",
         | 
| 644 | 
            +
                        "5b864d850d072a699b32f4ae",
         | 
| 645 | 
            +
                        "5b908d3dc6ab78485f3d24a9",
         | 
| 646 | 
            +
                        "5b950c71608de421b1e7318f",
         | 
| 647 | 
            +
                        "5b4933abf2b5f44e95de482a",
         | 
| 648 | 
            +
                        "5b08286b2775267d5b0634ba",
         | 
| 649 | 
            +
                        "5b37189a35304b6f75e7583e",
         | 
| 650 | 
            +
                        "5b271079e0878c3816dacca4",
         | 
| 651 | 
            +
                        "5b22269758e2823a67a3bd03",
         | 
| 652 | 
            +
                        "5b62647143840965efc0dbde",
         | 
| 653 | 
            +
                        "5ba19a8a360c7c30c1c169df",
         | 
| 654 | 
            +
                        "5ba75d79d76ffa2c86cf2f05",
         | 
| 655 | 
            +
                        "5bb7a08aea1cfa39f1a947ab",
         | 
| 656 | 
            +
                        "5bb8a49aea1cfa39f1aa7f75",
         | 
| 657 | 
            +
                        "5bbb6eb2ea1cfa39f1af7e0c",
         | 
| 658 | 
            +
                        "5bc5f0e896b66a2cd8f9bd36",
         | 
| 659 | 
            +
                        "5bccd6beca24970bce448134",
         | 
| 660 | 
            +
                        "5bce7ac9ca24970bce4934b6",
         | 
| 661 | 
            +
                        "5bcf979a6d5f586b95c258cd",
         | 
| 662 | 
            +
                        "5bd43b4ba6b28b1ee86b92dd",
         | 
| 663 | 
            +
                        "5be3a5fb8cfdd56947f6b67c",
         | 
| 664 | 
            +
                        "5be3ae47f44e235bdbbc9771",
         | 
| 665 | 
            +
                        "5be4ab93870d330ff2dce134",
         | 
| 666 | 
            +
                        "5be47bf9b18881428d8fbc1d",
         | 
| 667 | 
            +
                        "5be883a4f98cee15019d5b83",
         | 
| 668 | 
            +
                        "5bea87f4abd34c35e1860ab5",
         | 
| 669 | 
            +
                        "5beb6e66abd34c35e18e66b9",
         | 
| 670 | 
            +
                        "5bf3a82cd439231948877aed",
         | 
| 671 | 
            +
                        "5bf7d63575c26f32dbf7413b",
         | 
| 672 | 
            +
                        "5bf17c0fd439231948355385",
         | 
| 673 | 
            +
                        "5bf26cbbd43923194854b270",
         | 
| 674 | 
            +
                        "5bf18642c50e6f7f8bdbd492",
         | 
| 675 | 
            +
                        "5bf21799d43923194842c001",
         | 
| 676 | 
            +
                        "5bfc9d5aec61ca1dd69132a2",
         | 
| 677 | 
            +
                        "5bfd0f32ec61ca1dd69dc77b",
         | 
| 678 | 
            +
                        "5bfe5ae0fe0ea555e6a969ca",
         | 
| 679 | 
            +
                        "5bff3c5cfe0ea555e6bcbf3a",
         | 
| 680 | 
            +
                        "5c0d13b795da9479e12e2ee9",
         | 
| 681 | 
            +
                        "5c1af2e2bee9a723c963d019",
         | 
| 682 | 
            +
                        "5c1b1500bee9a723c96c3e78",
         | 
| 683 | 
            +
                        "5c1dbf200843bc542d8ef8c4",
         | 
| 684 | 
            +
                        "5c1f33f1d33e1f2e4aa6dda4",
         | 
| 685 | 
            +
                        "5c20ca3a0843bc542d94e3e2",
         | 
| 686 | 
            +
                        "5c062d84a96e33018ff6f0a6",
         | 
| 687 | 
            +
                        "5c189f2326173c3a09ed7ef3",
         | 
| 688 | 
            +
                        "5c1892f726173c3a09ea9aeb",
         | 
| 689 | 
            +
                        "5c34300a73a8df509add216d",
         | 
| 690 | 
            +
                        "5c34529873a8df509ae57b58",
         | 
| 691 | 
            +
                        "000000000000000000000006",
         | 
| 692 | 
            +
                        "000000000000000000000007",
         | 
| 693 | 
            +
                        "000000000000000000000008",
         | 
| 694 | 
            +
                        "000000000000000000000009",
         | 
| 695 | 
            +
                        "000000000000000000000010",
         | 
| 696 | 
            +
                        "000000000000000000000011",
         | 
| 697 | 
            +
                        "000000000000000000000012",
         | 
| 698 | 
            +
                        "000000000000000000000015",
         | 
| 699 | 
            +
                        "000000000000000000000016",
         | 
| 700 | 
            +
                        "000000000000000000000017",
         | 
| 701 | 
            +
                        "000000000000000000000018",
         | 
| 702 | 
            +
                        "000000000000000000000019",
         | 
| 703 | 
            +
                        "56d73ba74bd29b8c35abade2",
         | 
| 704 | 
            +
                        "56f34064e296120e10484dc4",
         | 
| 705 | 
            +
                        "57a4a7bb6b9272286e26dc18",
         | 
| 706 | 
            +
                        "57f8d9bbe73f6760f10e916a",
         | 
| 707 | 
            +
                        "58a0a2f33d0b4542479a11b1",
         | 
| 708 | 
            +
                        "58a0dd1a3d0b4542479a28f3",
         | 
| 709 | 
            +
                        "58a1a7914a4d262a170b1101",
         | 
| 710 | 
            +
                        "58a1bc804a4d262a170b2f01",
         | 
| 711 | 
            +
                        "58a1d9d14a4d262a170b58fe",
         | 
| 712 | 
            +
                        "58a01dea38486e3c98475871",
         | 
| 713 | 
            +
                        "58a1f5d74a4d262a170b65fc",
         | 
| 714 | 
            +
                        "58a2a09e156b87103d3d668c",
         | 
| 715 | 
            +
                        "58a2d9c3156b87103d3da90f",
         | 
| 716 | 
            +
                        "58a3ccb0156b87103d3e4332",
         | 
| 717 | 
            +
                        "58a3f2f8156b87103d3e5838",
         | 
| 718 | 
            +
                        "58a3f6c0156b87103d3e5971",
         | 
| 719 | 
            +
                        "58a3fc95156b87103d3e5d9b",
         | 
| 720 | 
            +
                        "58a07ce53d0b45424799fdde",
         | 
| 721 | 
            +
                        "58a07f233d0b45424799ffe7",
         | 
| 722 | 
            +
                        "58a44df2156b87103d3ee239",
         | 
| 723 | 
            +
                        "58a164f73d0b4542479a7a8e",
         | 
| 724 | 
            +
                        "58a0365e38486e3c984783eb",
         | 
| 725 | 
            +
                        "58a439cf156b87103d3ec885",
         | 
| 726 | 
            +
                        "58a464aa156b87103d3eec04",
         | 
| 727 | 
            +
                        "58a4452f156b87103d3ed55b",
         | 
| 728 | 
            +
                        "58a160983d0b4542479a7347",
         | 
| 729 | 
            +
                        "58a186444a4d262a170ae3ae",
         | 
| 730 | 
            +
                        "58a285424a4d262a170baf3e",
         | 
| 731 | 
            +
                        "58a41819156b87103d3e92a5",
         | 
| 732 | 
            +
                        "58a47552156b87103d3f00a4",
         | 
| 733 | 
            +
                        "58c4bb4f4a69c55606122be4",
         | 
| 734 | 
            +
                        "58c6451e4a69c556061894f1",
         | 
| 735 | 
            +
                        "58ca7014affdfd07c70a95ce",
         | 
| 736 | 
            +
                        "58cf4771d0f5fb221defe6da",
         | 
| 737 | 
            +
                        "58d36897f387231e6c929903",
         | 
| 738 | 
            +
                        "58eaf1513353456af3a1682a",
         | 
| 739 | 
            +
                        "58f7f7299f5b5647873cb110",
         | 
| 740 | 
            +
                        "58f73e7c9f5b56478738929f",
         | 
| 741 | 
            +
                        "59a8f851597729752c31e7e0",
         | 
| 742 | 
            +
                        "59a452bf9b460239aa5d1c72",
         | 
| 743 | 
            +
                        "59a9619a825418241fb88191",
         | 
| 744 | 
            +
                        "59acd2f4b891807f439c8992",
         | 
| 745 | 
            +
                        "59bf97fe7e7b31545da34439",
         | 
| 746 | 
            +
                        "59c1c3e2fd6e3d4ead9f1013",
         | 
| 747 | 
            +
                        "59d2657f82ca7774b1ec081d",
         | 
| 748 | 
            +
                        "59da1fb88a126011d0394ae9",
         | 
| 749 | 
            +
                        "59e75a2ca9e91f2c5526005d",
         | 
| 750 | 
            +
                        "59e864b2a9e91f2c5529325f",
         | 
| 751 | 
            +
                        "59ecfd02e225f6492d20fcc9",
         | 
| 752 | 
            +
                        "59f37f74b45be2233001ba18",
         | 
| 753 | 
            +
                        "59f70ab1e5c5d366af29bf3e",
         | 
| 754 | 
            +
                        "59f363a8b45be22330016cad",
         | 
| 755 | 
            +
                        "564a27b26d07883f460d8ab0",
         | 
| 756 | 
            +
                        "565fb1dead14d4154dae2b94",
         | 
| 757 | 
            +
                        "567a0fb0a825d2fb79ac9a20",
         | 
| 758 | 
            +
                        "569b92eb826bcba945ca002b",
         | 
| 759 | 
            +
                        "576fefa017ce5a16397e87fd",
         | 
| 760 | 
            +
                        "584a7333fe3cb463906c9fe6",
         | 
| 761 | 
            +
                        "584aa8e9fe3cb463906cc7d0",
         | 
| 762 | 
            +
                        "584ad76bfe3cb463906ce6dc",
         | 
| 763 | 
            +
                        "584af003fe3cb463906d0e9b",
         | 
| 764 | 
            +
                        "584b9a747072670e72bfc49d",
         | 
| 765 | 
            +
                        "584b671f7072670e72bfaaf8",
         | 
| 766 | 
            +
                        "584b81747072670e72bfbbfd",
         | 
| 767 | 
            +
                        "584ba35f7072670e72bfca4d",
         | 
| 768 | 
            +
                        "584ba5977072670e72bfcc2d",
         | 
| 769 | 
            +
                        "584bc53c7072670e72bfe85f",
         | 
| 770 | 
            +
                        "584bc3997072670e72bfe58d",
         | 
| 771 | 
            +
                        "584bc4407072670e72bfe665",
         | 
| 772 | 
            +
                        "584bd5587072670e72bffe39",
         | 
| 773 | 
            +
                        "584bdadf7072670e72c0005c",
         | 
| 774 | 
            +
                        "584be5ed7072670e72c007b3",
         | 
| 775 | 
            +
                        "584c9ad27072670e72c060c5",
         | 
| 776 | 
            +
                        "584c9cc67072670e72c063a1",
         | 
| 777 | 
            +
                        "584c58b77072670e72c03990",
         | 
| 778 | 
            +
                        "584cea557072670e72c07fb4",
         | 
| 779 | 
            +
                        "584d19d47072670e72c0c6c0",
         | 
| 780 | 
            +
                        "584dfe467072670e72c1665a",
         | 
| 781 | 
            +
                        "584e875c7072670e72c1ec94",
         | 
| 782 | 
            +
                        "584e05667072670e72c17167",
         | 
| 783 | 
            +
                        "584f94e87072670e72c2d3f7",
         | 
| 784 | 
            +
                        "584fdffd7072670e72c32dc7",
         | 
| 785 | 
            +
                        "584fe07f7072670e72c32e59",
         | 
| 786 | 
            +
                        "585a2a71b338a62ad50138dc",
         | 
| 787 | 
            +
                        "585a206ab338a62ad501298f",
         | 
| 788 | 
            +
                        "585a217cb338a62ad5012b38",
         | 
| 789 | 
            +
                        "585b34afb338a62ad501e836",
         | 
| 790 | 
            +
                        "585bb25fc49c8507c3ce7812",
         | 
| 791 | 
            +
                        "585bbe55c49c8507c3ce81cd",
         | 
| 792 | 
            +
                        "585d6c8a2a57cc11d4920a1e",
         | 
| 793 | 
            +
                        "585e54c72a57cc11d492f71a",
         | 
| 794 | 
            +
                        "585e34302a57cc11d492be30",
         | 
| 795 | 
            +
                        "585ee0632a57cc11d4933608",
         | 
| 796 | 
            +
                        "585f9661712e2761468dabca",
         | 
| 797 | 
            +
                        "585ffe9a712e2761468df643",
         | 
| 798 | 
            +
                        "586a37ec9d1b5e34c28184fc",
         | 
| 799 | 
            +
                        "586a515a9d1b5e34c281b431",
         | 
| 800 | 
            +
                        "586a94939d1b5e34c2823b5d",
         | 
| 801 | 
            +
                        "586abc689d1b5e34c2826360",
         | 
| 802 | 
            +
                        "586b0e219d1b5e34c2828862",
         | 
| 803 | 
            +
                        "586b3db89d1b5e34c282cd52",
         | 
| 804 | 
            +
                        "586b4c459d1b5e34c282e66d",
         | 
| 805 | 
            +
                        "586b7d7d9d1b5e34c283359e",
         | 
| 806 | 
            +
                        "586b8f149d1b5e34c283497c",
         | 
| 807 | 
            +
                        "586b8f629d1b5e34c28349d6",
         | 
| 808 | 
            +
                        "586c4c4d9d1b5e34c28391a1",
         | 
| 809 | 
            +
                        "586c5b5b9d1b5e34c2839a5b",
         | 
| 810 | 
            +
                        "586c9fdf9d1b5e34c283b657",
         | 
| 811 | 
            +
                        "586c48329d1b5e34c2838e80",
         | 
| 812 | 
            +
                        "586caab99d1b5e34c283c213",
         | 
| 813 | 
            +
                        "586cd0779d1b5e34c28403a7",
         | 
| 814 | 
            +
                        "586d6d249d1b5e34c284b80e",
         | 
| 815 | 
            +
                        "586d8a029d1b5e34c284c948",
         | 
| 816 | 
            +
                        "586d55af9d1b5e34c284a999",
         | 
| 817 | 
            +
                        "586d07869d1b5e34c2842e5b",
         | 
| 818 | 
            +
                        "586d27489d1b5e34c28453af",
         | 
| 819 | 
            +
                        "586df9849d1b5e34c28506de",
         | 
| 820 | 
            +
                        "586e279c9d1b5e34c2852180",
         | 
| 821 | 
            +
                        "587bc5ec2366dd5d06e262c1",
         | 
| 822 | 
            +
                        "587c1abf2366dd5d06e28901",
         | 
| 823 | 
            +
                        "587c03f12366dd5d06e27722",
         | 
| 824 | 
            +
                        "587c19da2366dd5d06e2877b",
         | 
| 825 | 
            +
                        "587c31b92366dd5d06e2a9dc",
         | 
| 826 | 
            +
                        "587c87d02366dd5d06e2f989",
         | 
| 827 | 
            +
                        "587c97a52366dd5d06e30a96",
         | 
| 828 | 
            +
                        "587c45192366dd5d06e2c0eb",
         | 
| 829 | 
            +
                        "587cec702366dd5d06e37862",
         | 
| 830 | 
            +
                        "587cef0a2366dd5d06e379e3",
         | 
| 831 | 
            +
                        "587db5872366dd5d06e3e0af",
         | 
| 832 | 
            +
                        "587e2b1d2366dd5d06e41af0",
         | 
| 833 | 
            +
                        "587e2ea62366dd5d06e41f2e",
         | 
| 834 | 
            +
                        "587e5cb52366dd5d06e4486e",
         | 
| 835 | 
            +
                        "587eb1822366dd5d06e45f29",
         | 
| 836 | 
            +
                        "587f365d2366dd5d06e4906e",
         | 
| 837 | 
            +
                        "588a9c5fec4d5a1c088ec350",
         | 
| 838 | 
            +
                        "588a34cfec4d5a1c088ea8d1",
         | 
| 839 | 
            +
                        "588ab5bdec4d5a1c088ed60f",
         | 
| 840 | 
            +
                        "588aff9d90414422fbe7885a",
         | 
| 841 | 
            +
                        "588b20d290414422fbe79f40",
         | 
| 842 | 
            +
                        "588c08d590414422fbe8200b",
         | 
| 843 | 
            +
                        "588c203d90414422fbe8319e",
         | 
| 844 | 
            +
                        "588c989a90414422fbe86d96",
         | 
| 845 | 
            +
                        "588ca09d90414422fbe871a1",
         | 
| 846 | 
            +
                        "588cce2190414422fbe88520",
         | 
| 847 | 
            +
                        "588cd5ef90414422fbe8875c",
         | 
| 848 | 
            +
                        "588cf0ad90414422fbe8a20f",
         | 
| 849 | 
            +
                        "588e0d8c90414422fbe8f8b2",
         | 
| 850 | 
            +
                        "588e01c490414422fbe8ee2a",
         | 
| 851 | 
            +
                        "588e35e690414422fbe90a53",
         | 
| 852 | 
            +
                        "588f017e90414422fbe9b74b",
         | 
| 853 | 
            +
                        "588f095190414422fbe9c1ee",
         | 
| 854 | 
            +
                        "589aca717dc3d323d55671c4",
         | 
| 855 | 
            +
                        "589af2c97dc3d323d55691e8",
         | 
| 856 | 
            +
                        "589b49ea7dc3d323d556d9b4",
         | 
| 857 | 
            +
                        "589b04287dc3d323d556a185",
         | 
| 858 | 
            +
                        "589bf6a57dc3d323d55743ab",
         | 
| 859 | 
            +
                        "589c3c497dc3d323d5578468",
         | 
| 860 | 
            +
                        "589c3c577dc3d323d5578480",
         | 
| 861 | 
            +
                        "589c300f7dc3d323d5577926",
         | 
| 862 | 
            +
                        "589c24527dc3d323d5577126",
         | 
| 863 | 
            +
                        "589c35457dc3d323d5577d8d",
         | 
| 864 | 
            +
                        "589ca6a6b896147a1b73aff7",
         | 
| 865 | 
            +
                        "589d1e1fb896147a1b73ee5b",
         | 
| 866 | 
            +
                        "589d5c58b896147a1b742256",
         | 
| 867 | 
            +
                        "589d95538fa2cf375df3317b",
         | 
| 868 | 
            +
                        "589df0ffb504a864ad63521a",
         | 
| 869 | 
            +
                        "589ea316b504a864ad639a2b",
         | 
| 870 | 
            +
                        "589ec97cb504a864ad63adc3",
         | 
| 871 | 
            +
                        "589f214338486e3c9846f123",
         | 
| 872 | 
            +
                        "589fdfe738486e3c984736cf",
         | 
| 873 | 
            +
                        "590c2d70336bb52a190be886",
         | 
| 874 | 
            +
                        "590f91851225725be9e25d4e",
         | 
| 875 | 
            +
                        "591a467a6109e14d4f09b776",
         | 
| 876 | 
            +
                        "591cf3033162411cf9047f37",
         | 
| 877 | 
            +
                        "591ea44850991c70dc99a207",
         | 
| 878 | 
            +
                        "599aa591d5b41f366fed0d58",
         | 
| 879 | 
            +
                        "5643df56138263b51db1b5f3",
         | 
| 880 | 
            +
                        "5644bdac138263b51db9f669",
         | 
| 881 | 
            +
                        "5850d4f97072670e72c425d6",
         | 
| 882 | 
            +
                        "5854c405804be105852330fe",
         | 
| 883 | 
            +
                        "5855a4fc804be1058523bd75",
         | 
| 884 | 
            +
                        "5856ac15804be105852419d8",
         | 
| 885 | 
            +
                        "5856ae8b804be10585241bae",
         | 
| 886 | 
            +
                        "5856b460804be10585242059",
         | 
| 887 | 
            +
                        "5857aa5ab338a62ad5ff4dbe",
         | 
| 888 | 
            +
                        "5857acf8b338a62ad5ff5107",
         | 
| 889 | 
            +
                        "5858db6cb338a62ad500103b",
         | 
| 890 | 
            +
                        "5858dbcab338a62ad5001081",
         | 
| 891 | 
            +
                        "5859d84fb338a62ad500e5cf",
         | 
| 892 | 
            +
                        "5861d8ea712e2761468f3cb3",
         | 
| 893 | 
            +
                        "5863edf8712e27614690cce0",
         | 
| 894 | 
            +
                        "5864b076712e27614691197e",
         | 
| 895 | 
            +
                        "5864da88712e276146913d8b",
         | 
| 896 | 
            +
                        "5865f4a8712e27614691e39b",
         | 
| 897 | 
            +
                        "5867a434833dfe3f7b88edaf",
         | 
| 898 | 
            +
                        "5868cd15833dfe3f7b89bfa3",
         | 
| 899 | 
            +
                        "5880b3692366dd5d06e5d534",
         | 
| 900 | 
            +
                        "5880e3422366dd5d06e5ff8e",
         | 
| 901 | 
            +
                        "5880f0ef2366dd5d06e6166e",
         | 
| 902 | 
            +
                        "5881d2bfb6844814c136a119",
         | 
| 903 | 
            +
                        "5881f11d8ce2c2754d0714c3",
         | 
| 904 | 
            +
                        "5881fee18ce2c2754d0723f8",
         | 
| 905 | 
            +
                        "5882cda2b116682b4adebd25",
         | 
| 906 | 
            +
                        "5882d58fb116682b4adec7db",
         | 
| 907 | 
            +
                        "5884c256932ba84fbed70bf5",
         | 
| 908 | 
            +
                        "5884cc13932ba84fbed71ec4",
         | 
| 909 | 
            +
                        "5885bc5296fa095e0671a7f0",
         | 
| 910 | 
            +
                        "5886d14cb791366d617a362c",
         | 
| 911 | 
            +
                        "5888becfc02346100f4b0b21",
         | 
| 912 | 
            +
                        "5888e408c02346100f4b1a29",
         | 
| 913 | 
            +
                        "5889da66ec4d5a1c088e5187",
         | 
| 914 | 
            +
                        "5889e344ec4d5a1c088e59be",
         | 
| 915 | 
            +
                        "5889e754ec4d5a1c088e60ba",
         | 
| 916 | 
            +
                        "5890c16b90414422fbeb0262",
         | 
| 917 | 
            +
                        "5891d8ae9a8c0314c5cd30ab",
         | 
| 918 | 
            +
                        "5891d0479a8c0314c5cd2abd",
         | 
| 919 | 
            +
                        "5891ecf19a8c0314c5cd490a",
         | 
| 920 | 
            +
                        "5892c0cd9a8c0314c5cdc977",
         | 
| 921 | 
            +
                        "5894ab309a8c0314c5cee57d",
         | 
| 922 | 
            +
                        "5895a6a89a8c0314c5cfca7c",
         | 
| 923 | 
            +
                        "5895b8c29a8c0314c5cfd051",
         | 
| 924 | 
            +
                        "5895d38f9a8c0314c5cfe50c",
         | 
| 925 | 
            +
                        "5895f2329a8c0314c5d00117",
         | 
| 926 | 
            +
                        "5896bb989a8c0314c5d086b6",
         | 
| 927 | 
            +
                        "5896ebf39a8c0314c5d0a8c4",
         | 
| 928 | 
            +
                        "5898b1bac9dccc22987b7f74",
         | 
| 929 | 
            +
                        "5898b6ffc9dccc22987b8a03",
         | 
| 930 | 
            +
                        "5898b31cc9dccc22987b82ec",
         | 
| 931 | 
            +
                        "5898bbaac9dccc22987b8eba",
         | 
| 932 | 
            +
                        "5899cfa6b76d7a3780a4cb64",
         | 
| 933 | 
            +
                        "5899e5dcb76d7a3780a4ecc1",
         | 
| 934 | 
            +
                        "5947b62af1b45630bd0c2a02",
         | 
| 935 | 
            +
                        "57102be2877e1421026358af",
         | 
| 936 | 
            +
                        "57153d4031bb9900425bde85",
         | 
| 937 | 
            +
                        "57177cd7fb8d93461afc4527",
         | 
| 938 | 
            +
                        "58497cdf97b73e0b090c4273",
         | 
| 939 | 
            +
                        "58500b007072670e72c35588",
         | 
| 940 | 
            +
                        "58510bf97072670e72c46ddf",
         | 
| 941 | 
            +
                        "58522bd56789802282f2ecb3",
         | 
| 942 | 
            +
                        "58524a2e0e7012308944bcf3",
         | 
| 943 | 
            +
                        "58524a080e7012308944bcbf",
         | 
| 944 | 
            +
                        "58524c1d0e7012308944bfda",
         | 
| 945 | 
            +
                        "58524f170e7012308944c200",
         | 
| 946 | 
            +
                        "58529a4e0e70123089454c6f",
         | 
| 947 | 
            +
                        "58551bdf804be1058523556d",
         | 
| 948 | 
            +
                        "58568c9a804be10585240b03",
         | 
| 949 | 
            +
                        "58574b35804be105852455fd",
         | 
| 950 | 
            +
                        "58577c60b338a62ad5ff1564",
         | 
| 951 | 
            +
                        "58592d69b338a62ad5007a74",
         | 
| 952 | 
            +
                        "58598db2b338a62ad500bc38",
         | 
| 953 | 
            +
                        "58625f42712e2761468fb44c",
         | 
| 954 | 
            +
                        "58651bcc712e2761469166dc",
         | 
| 955 | 
            +
                        "58660e79712e27614691fe3d",
         | 
| 956 | 
            +
                        "58669aad712e27614692834c",
         | 
| 957 | 
            +
                        "58669c02712e27614692851a",
         | 
| 958 | 
            +
                        "58676c36833dfe3f7b88b7f2",
         | 
| 959 | 
            +
                        "58678b2d833dfe3f7b88e244",
         | 
| 960 | 
            +
                        "58790c82ce911104a3467c88",
         | 
| 961 | 
            +
                        "58800b0b2366dd5d06e5312d",
         | 
| 962 | 
            +
                        "58805eac2366dd5d06e56460",
         | 
| 963 | 
            +
                        "58806e422366dd5d06e57bb6",
         | 
| 964 | 
            +
                        "58831d060db9bf59bf8ab98b",
         | 
| 965 | 
            +
                        "58851ebb932ba84fbed7abad",
         | 
| 966 | 
            +
                        "58871dc3b791366d617a55ff",
         | 
| 967 | 
            +
                        "58873cabb791366d617a65a7",
         | 
| 968 | 
            +
                        "58873d44b791366d617a65dd",
         | 
| 969 | 
            +
                        "58888b3dc02346100f4af665",
         | 
| 970 | 
            +
                        "58897f62c02346100f4b8ee6",
         | 
| 971 | 
            +
                        "58933bac9a8c0314c5ce3508",
         | 
| 972 | 
            +
                        "58938e6d9a8c0314c5ce726f",
         | 
| 973 | 
            +
                        "58951cb49a8c0314c5cf4d5e",
         | 
| 974 | 
            +
                        "58970fd09a8c0314c5d0e383",
         | 
| 975 | 
            +
                        "58977ef09a8c0314c5d17b26",
         | 
| 976 | 
            +
                        "59056e6760bb961de55f3501",
         | 
| 977 | 
            +
                        "59071f2e5a6dbd3af4130f98",
         | 
| 978 | 
            +
                        "59102c811225725be9e64149",
         | 
| 979 | 
            +
                        "59338e76772c3e6384afbb15",
         | 
| 980 | 
            +
                        "59350ca084b7f26bf5ce6eb8",
         | 
| 981 | 
            +
                        "59397e493a87372f2c9e882b",
         | 
| 982 | 
            +
                        "59521e0b9096412211c2aa9d",
         | 
| 983 | 
            +
                        "59817e4a1bd4b175e7038d19",
         | 
| 984 | 
            +
                        "567884f58d2828b95e3c8eba",
         | 
| 985 | 
            +
                        "585559d9804be10585238ddf",
         | 
| 986 | 
            +
                        "585834cdb338a62ad5ffab4d",
         | 
| 987 | 
            +
                        "586082d8712e2761468e2877",
         | 
| 988 | 
            +
                        "586133c2712e2761468ecfe3",
         | 
| 989 | 
            +
                        "586281d2712e2761468fcaa2",
         | 
| 990 | 
            +
                        "586316e5712e276146903c4d",
         | 
| 991 | 
            +
                        "586326ad712e276146904571",
         | 
| 992 | 
            +
                        "586375c9712e276146907429",
         | 
| 993 | 
            +
                        "586389c9712e276146908da6",
         | 
| 994 | 
            +
                        "586496fa712e2761469108e7",
         | 
| 995 | 
            +
                        "586669c6712e27614692597a",
         | 
| 996 | 
            +
                        "586913a49d1b5e34c2808b02",
         | 
| 997 | 
            +
                        "586922da9d1b5e34c2809ff3",
         | 
| 998 | 
            +
                        "588185d8dfb7a15588a114a3",
         | 
| 999 | 
            +
                        "588305ed0db9bf59bf8a8c80",
         | 
| 1000 | 
            +
                        "588315c60db9bf59bf8aa928",
         | 
| 1001 | 
            +
                        "588332ee0db9bf59bf8ae9c3",
         | 
| 1002 | 
            +
                        "588457b8932ba84fbed69942",
         | 
| 1003 | 
            +
                        "588519d5932ba84fbed7a04a",
         | 
| 1004 | 
            +
                        "588824d1b791366d617adeef",
         | 
| 1005 | 
            +
                        "588857f6c02346100f4ac09f",
         | 
| 1006 | 
            +
                        "589145ef90414422fbeb2e08",
         | 
| 1007 | 
            +
                        "589433fa9a8c0314c5ce9656",
         | 
| 1008 | 
            +
                        "589765d39a8c0314c5d16b12",
         | 
| 1009 | 
            +
                        "5851165f7072670e72c4860d",
         | 
| 1010 | 
            +
                        "5859341ab338a62ad500848d",
         | 
| 1011 | 
            +
                        "5862388b712e2761468f84aa",
         | 
| 1012 | 
            +
                        "5863915b712e276146909135",
         | 
| 1013 | 
            +
                        "5866445b712e27614692383e",
         | 
| 1014 | 
            +
                        "5866500d712e2761469240fd",
         | 
| 1015 | 
            +
                        "5867785a833dfe3f7b88c764",
         | 
| 1016 | 
            +
                        "5867969c833dfe3f7b88e8bc",
         | 
| 1017 | 
            +
                        "5868040c833dfe3f7b8934f7",
         | 
| 1018 | 
            +
                        "5880675a2366dd5d06e570ca",
         | 
| 1019 | 
            +
                        "5882372c8ce2c2754d076af0",
         | 
| 1020 | 
            +
                        "5883535e932ba84fbed5ad07",
         | 
| 1021 | 
            +
                        "5888358cb791366d617af69d",
         | 
| 1022 | 
            +
                        "5890330d90414422fbeaa0cb",
         | 
| 1023 | 
            +
                        "5897076e9a8c0314c5d0d31b",
         | 
| 1024 | 
            +
                        "5940564ec2d9527ab869f7e2",
         | 
| 1025 | 
            +
                        "5947719bf1b45630bd096665",
         | 
| 1026 | 
            +
                        "5948194ff1b45630bd0f47e3",
         | 
| 1027 | 
            +
                        "5950206a41b158666ac50506",
         | 
| 1028 | 
            +
                        "5983012d1bd4b175e70c985a",
         | 
| 1029 | 
            +
                        "58586810b338a62ad5ffc20c",
         | 
| 1030 | 
            +
                        "58592046b338a62ad5006b33",
         | 
| 1031 | 
            +
                        "58592854b338a62ad500750a",
         | 
| 1032 | 
            +
                        "58596531b338a62ad500aace",
         | 
| 1033 | 
            +
                        "58818685dfb7a15588a11626",
         | 
| 1034 | 
            +
                        "58829563f42b1d3ee3ec835f",
         | 
| 1035 | 
            +
                        "58894345c02346100f4b51ca",
         | 
| 1036 | 
            +
                        "585289980e7012308945276a",
         | 
| 1037 | 
            +
                        "585369770e7012308945c709",
         | 
| 1038 | 
            +
                        "585373640e7012308945cab9",
         | 
| 1039 | 
            +
                        "588230658ce2c2754d076728",
         | 
| 1040 | 
            +
                        "589388059a8c0314c5ce718b",
         | 
| 1041 | 
            +
                        "595979485ec6a95e86a58c8d",
         | 
| 1042 | 
            +
                        "5841206219d291325678ca90",
         | 
| 1043 | 
            +
                        "58563650804be1058523da55",
         | 
| 1044 | 
            +
                        "58564084804be1058523e116",
         | 
| 1045 | 
            +
                        "58636467712e27614690661f",
         | 
| 1046 | 
            +
                        "58647495712e27614690f36d",
         | 
| 1047 | 
            +
                        "58654563712e276146918643",
         | 
| 1048 | 
            +
                        "58664251712e276146923738",
         | 
| 1049 | 
            +
                        "588084032366dd5d06e59e82",
         | 
| 1050 | 
            +
                        "588159582366dd5d06e66877",
         | 
| 1051 | 
            +
                        "5890279190414422fbea9734",
         | 
| 1052 | 
            +
                        "5890523090414422fbeab3f0",
         | 
| 1053 | 
            +
                        "5890641690414422fbeabbe7",
         | 
| 1054 | 
            +
                        "585203546789802282f2aaf5",
         | 
| 1055 | 
            +
                    ]
         | 
| 1056 | 
            +
             | 
| 1057 | 
            +
                    # Train set sequences after filtering
         | 
| 1058 | 
            +
                    self.train_split_scenes = [
         | 
| 1059 | 
            +
                        "00000000000000000000000b",
         | 
| 1060 | 
            +
                        "00000000000000000000000d",
         | 
| 1061 | 
            +
                        "00000000000000000000000e",
         | 
| 1062 | 
            +
                        "00000000000000000000000f",
         | 
| 1063 | 
            +
                        "000000000000000000000001",
         | 
| 1064 | 
            +
                        "00000000000000000000001b",
         | 
| 1065 | 
            +
                        "00000000000000000000001d",
         | 
| 1066 | 
            +
                        "000000000000000000000002",
         | 
| 1067 | 
            +
                        "000000000000000000000003",
         | 
| 1068 | 
            +
                        "000000000000000000000004",
         | 
| 1069 | 
            +
                        "000000000000000000000005",
         | 
| 1070 | 
            +
                        "5a2a95f032a1c655cfe3de62",
         | 
| 1071 | 
            +
                        "5a2af22b32a1c655cfe46013",
         | 
| 1072 | 
            +
                        "5a2ba6de32a1c655cfe51b79",
         | 
| 1073 | 
            +
                        "5a3b9731e24cd76dad1a5f1b",
         | 
| 1074 | 
            +
                        "5a3ca9cb270f0e3f14d0eddb",
         | 
| 1075 | 
            +
                        "5a3cb4e4270f0e3f14d12f43",
         | 
| 1076 | 
            +
                        "5a03e732454a8a7ec672776c",
         | 
| 1077 | 
            +
                        "5a3f4aba5889373fbbc5d3b5",
         | 
| 1078 | 
            +
                        "5a5a1e48d62c7a12d5d00e47",
         | 
| 1079 | 
            +
                        "5a6b1c418d100c2f8fdc4411",
         | 
| 1080 | 
            +
                        "5a6feeb54a7fbc3f874f9db7",
         | 
| 1081 | 
            +
                        "5a7cb1d6fe5c0d6fb53e64fb",
         | 
| 1082 | 
            +
                        "5a7d3db14989e929563eb153",
         | 
| 1083 | 
            +
                        "5a8aa0fab18050187cbe060e",
         | 
| 1084 | 
            +
                        "5a9e5df65baeef72b4a021cd",
         | 
| 1085 | 
            +
                        "5a48ba95c7dab83a7d7b44ed",
         | 
| 1086 | 
            +
                        "5a48c4e9c7dab83a7d7b5cc7",
         | 
| 1087 | 
            +
                        "5a48d4b2c7dab83a7d7b9851",
         | 
| 1088 | 
            +
                        "5a69c47d0d5d0a7f3b2e9752",
         | 
| 1089 | 
            +
                        "5a77b46b318efe6c6736e68a",
         | 
| 1090 | 
            +
                        "5a355c271b63f53d5970f362",
         | 
| 1091 | 
            +
                        "5a533e8034d7582116e34209",
         | 
| 1092 | 
            +
                        "5a562fc7425d0f5186314725",
         | 
| 1093 | 
            +
                        "5a618c72784780334bc1972d",
         | 
| 1094 | 
            +
                        "5a752d42acc41e2423f17674",
         | 
| 1095 | 
            +
                        "5a969eea91dfc339a9a3ad2c",
         | 
| 1096 | 
            +
                        "5a8315f624b8e938486e0bd8",
         | 
| 1097 | 
            +
                        "5a57542f333d180827dfc132",
         | 
| 1098 | 
            +
                        "5a0271884e62597cdee0d0eb",
         | 
| 1099 | 
            +
                        "5a6400933d809f1d8200af15",
         | 
| 1100 | 
            +
                        "5a6464143d809f1d8208c43c",
         | 
| 1101 | 
            +
                        "5a563183425d0f5186314855",
         | 
| 1102 | 
            +
                        "5aa0f9d7a9efce63548c69a1",
         | 
| 1103 | 
            +
                        "5aa7db90bfdd572271e95246",
         | 
| 1104 | 
            +
                        "5aa235f64a17b335eeaf9609",
         | 
| 1105 | 
            +
                        "5aa515e613d42d091d29d300",
         | 
| 1106 | 
            +
                        "5aa1196ea9efce63548ed649",
         | 
| 1107 | 
            +
                        "5aaadd4cbc13235570d178a7",
         | 
| 1108 | 
            +
                        "5ab6af12ac4291329b1072ab",
         | 
| 1109 | 
            +
                        "5ab7e00aac4291329b15864d",
         | 
| 1110 | 
            +
                        "5ab8b8e029f5351f7f2ccf59",
         | 
| 1111 | 
            +
                        "5ab74bf2ac4291329b11e879",
         | 
| 1112 | 
            +
                        "5ab85f1dac4291329b17cb50",
         | 
| 1113 | 
            +
                        "5ab8713ba3799a1d138bd69a",
         | 
| 1114 | 
            +
                        "5abc2506b53b042ead637d86",
         | 
| 1115 | 
            +
                        "5acc7459a7853c4b5ebbef59",
         | 
| 1116 | 
            +
                        "5acf8ca0f3d8a750097e4b15",
         | 
| 1117 | 
            +
                        "5adc6bd52430a05ecb2ffb85",
         | 
| 1118 | 
            +
                        "5af02e904c8216544b4ab5a2",
         | 
| 1119 | 
            +
                        "5af28cea59bc705737003253",
         | 
| 1120 | 
            +
                        "5af545d0559359053d25dcf5",
         | 
| 1121 | 
            +
                        "5afacb69ab00705d0cefdd5b",
         | 
| 1122 | 
            +
                        "5b3b2b9e8d46a939f933fdc0",
         | 
| 1123 | 
            +
                        "5b3b353d8d46a939f93524b9",
         | 
| 1124 | 
            +
                        "5b6e716d67b396324c2d77cb",
         | 
| 1125 | 
            +
                        "5b6eff8b67b396324c5b2672",
         | 
| 1126 | 
            +
                        "5b7a3890fc8fcf6781e2593a",
         | 
| 1127 | 
            +
                        "5b60fa0c764f146feef84df0",
         | 
| 1128 | 
            +
                        "5b69cc0cb44b61786eb959bf",
         | 
| 1129 | 
            +
                        "5b78e57afc8fcf6781d0c3ba",
         | 
| 1130 | 
            +
                        "5b192eb2170cf166458ff886",
         | 
| 1131 | 
            +
                        "5b558a928bbfb62204e77ba2",
         | 
| 1132 | 
            +
                        "5b908d3dc6ab78485f3d24a9",
         | 
| 1133 | 
            +
                        "5b950c71608de421b1e7318f",
         | 
| 1134 | 
            +
                        "5b08286b2775267d5b0634ba",
         | 
| 1135 | 
            +
                        "5b271079e0878c3816dacca4",
         | 
| 1136 | 
            +
                        "5b22269758e2823a67a3bd03",
         | 
| 1137 | 
            +
                        "5b62647143840965efc0dbde",
         | 
| 1138 | 
            +
                        "5ba19a8a360c7c30c1c169df",
         | 
| 1139 | 
            +
                        "5ba75d79d76ffa2c86cf2f05",
         | 
| 1140 | 
            +
                        "5bb7a08aea1cfa39f1a947ab",
         | 
| 1141 | 
            +
                        "5bb8a49aea1cfa39f1aa7f75",
         | 
| 1142 | 
            +
                        "5bbb6eb2ea1cfa39f1af7e0c",
         | 
| 1143 | 
            +
                        "5bce7ac9ca24970bce4934b6",
         | 
| 1144 | 
            +
                        "5bcf979a6d5f586b95c258cd",
         | 
| 1145 | 
            +
                        "5bd43b4ba6b28b1ee86b92dd",
         | 
| 1146 | 
            +
                        "5be3a5fb8cfdd56947f6b67c",
         | 
| 1147 | 
            +
                        "5be3ae47f44e235bdbbc9771",
         | 
| 1148 | 
            +
                        "5be4ab93870d330ff2dce134",
         | 
| 1149 | 
            +
                        "5be47bf9b18881428d8fbc1d",
         | 
| 1150 | 
            +
                        "5be883a4f98cee15019d5b83",
         | 
| 1151 | 
            +
                        "5bea87f4abd34c35e1860ab5",
         | 
| 1152 | 
            +
                        "5beb6e66abd34c35e18e66b9",
         | 
| 1153 | 
            +
                        "5bf3a82cd439231948877aed",
         | 
| 1154 | 
            +
                        "5bf7d63575c26f32dbf7413b",
         | 
| 1155 | 
            +
                        "5bf17c0fd439231948355385",
         | 
| 1156 | 
            +
                        "5bf21799d43923194842c001",
         | 
| 1157 | 
            +
                        "5bfd0f32ec61ca1dd69dc77b",
         | 
| 1158 | 
            +
                        "5bfe5ae0fe0ea555e6a969ca",
         | 
| 1159 | 
            +
                        "5c0d13b795da9479e12e2ee9",
         | 
| 1160 | 
            +
                        "5c1af2e2bee9a723c963d019",
         | 
| 1161 | 
            +
                        "5c1b1500bee9a723c96c3e78",
         | 
| 1162 | 
            +
                        "5c1dbf200843bc542d8ef8c4",
         | 
| 1163 | 
            +
                        "5c20ca3a0843bc542d94e3e2",
         | 
| 1164 | 
            +
                        "5c062d84a96e33018ff6f0a6",
         | 
| 1165 | 
            +
                        "5c189f2326173c3a09ed7ef3",
         | 
| 1166 | 
            +
                        "5c1892f726173c3a09ea9aeb",
         | 
| 1167 | 
            +
                        "5c34300a73a8df509add216d",
         | 
| 1168 | 
            +
                        "000000000000000000000006",
         | 
| 1169 | 
            +
                        "000000000000000000000007",
         | 
| 1170 | 
            +
                        "000000000000000000000008",
         | 
| 1171 | 
            +
                        "000000000000000000000009",
         | 
| 1172 | 
            +
                        "000000000000000000000010",
         | 
| 1173 | 
            +
                        "000000000000000000000011",
         | 
| 1174 | 
            +
                        "000000000000000000000012",
         | 
| 1175 | 
            +
                        "000000000000000000000015",
         | 
| 1176 | 
            +
                        "000000000000000000000016",
         | 
| 1177 | 
            +
                        "000000000000000000000017",
         | 
| 1178 | 
            +
                        "000000000000000000000018",
         | 
| 1179 | 
            +
                        "000000000000000000000019",
         | 
| 1180 | 
            +
                        "56d73ba74bd29b8c35abade2",
         | 
| 1181 | 
            +
                        "56f34064e296120e10484dc4",
         | 
| 1182 | 
            +
                        "57a4a7bb6b9272286e26dc18",
         | 
| 1183 | 
            +
                        "57f8d9bbe73f6760f10e916a",
         | 
| 1184 | 
            +
                        "58a0a2f33d0b4542479a11b1",
         | 
| 1185 | 
            +
                        "58a0dd1a3d0b4542479a28f3",
         | 
| 1186 | 
            +
                        "58a1a7914a4d262a170b1101",
         | 
| 1187 | 
            +
                        "58a1bc804a4d262a170b2f01",
         | 
| 1188 | 
            +
                        "58a1d9d14a4d262a170b58fe",
         | 
| 1189 | 
            +
                        "58a01dea38486e3c98475871",
         | 
| 1190 | 
            +
                        "58a1f5d74a4d262a170b65fc",
         | 
| 1191 | 
            +
                        "58a2a09e156b87103d3d668c",
         | 
| 1192 | 
            +
                        "58a2d9c3156b87103d3da90f",
         | 
| 1193 | 
            +
                        "58a3ccb0156b87103d3e4332",
         | 
| 1194 | 
            +
                        "58a3f2f8156b87103d3e5838",
         | 
| 1195 | 
            +
                        "58a3f6c0156b87103d3e5971",
         | 
| 1196 | 
            +
                        "58a3fc95156b87103d3e5d9b",
         | 
| 1197 | 
            +
                        "58a07ce53d0b45424799fdde",
         | 
| 1198 | 
            +
                        "58a07f233d0b45424799ffe7",
         | 
| 1199 | 
            +
                        "58a44df2156b87103d3ee239",
         | 
| 1200 | 
            +
                        "58a164f73d0b4542479a7a8e",
         | 
| 1201 | 
            +
                        "58a0365e38486e3c984783eb",
         | 
| 1202 | 
            +
                        "58a439cf156b87103d3ec885",
         | 
| 1203 | 
            +
                        "58a464aa156b87103d3eec04",
         | 
| 1204 | 
            +
                        "58a4452f156b87103d3ed55b",
         | 
| 1205 | 
            +
                        "58a160983d0b4542479a7347",
         | 
| 1206 | 
            +
                        "58a285424a4d262a170baf3e",
         | 
| 1207 | 
            +
                        "58a41819156b87103d3e92a5",
         | 
| 1208 | 
            +
                        "58a47552156b87103d3f00a4",
         | 
| 1209 | 
            +
                        "58c4bb4f4a69c55606122be4",
         | 
| 1210 | 
            +
                        "58c6451e4a69c556061894f1",
         | 
| 1211 | 
            +
                        "58ca7014affdfd07c70a95ce",
         | 
| 1212 | 
            +
                        "58cf4771d0f5fb221defe6da",
         | 
| 1213 | 
            +
                        "58d36897f387231e6c929903",
         | 
| 1214 | 
            +
                        "58eaf1513353456af3a1682a",
         | 
| 1215 | 
            +
                        "58f73e7c9f5b56478738929f",
         | 
| 1216 | 
            +
                        "59a8f851597729752c31e7e0",
         | 
| 1217 | 
            +
                        "59a452bf9b460239aa5d1c72",
         | 
| 1218 | 
            +
                        "59a9619a825418241fb88191",
         | 
| 1219 | 
            +
                        "59bf97fe7e7b31545da34439",
         | 
| 1220 | 
            +
                        "59c1c3e2fd6e3d4ead9f1013",
         | 
| 1221 | 
            +
                        "59d2657f82ca7774b1ec081d",
         | 
| 1222 | 
            +
                        "59da1fb88a126011d0394ae9",
         | 
| 1223 | 
            +
                        "59e75a2ca9e91f2c5526005d",
         | 
| 1224 | 
            +
                        "59e864b2a9e91f2c5529325f",
         | 
| 1225 | 
            +
                        "59ecfd02e225f6492d20fcc9",
         | 
| 1226 | 
            +
                        "59f37f74b45be2233001ba18",
         | 
| 1227 | 
            +
                        "59f70ab1e5c5d366af29bf3e",
         | 
| 1228 | 
            +
                        "59f363a8b45be22330016cad",
         | 
| 1229 | 
            +
                        "564a27b26d07883f460d8ab0",
         | 
| 1230 | 
            +
                        "565fb1dead14d4154dae2b94",
         | 
| 1231 | 
            +
                        "569b92eb826bcba945ca002b",
         | 
| 1232 | 
            +
                        "576fefa017ce5a16397e87fd",
         | 
| 1233 | 
            +
                        "584a7333fe3cb463906c9fe6",
         | 
| 1234 | 
            +
                        "584aa8e9fe3cb463906cc7d0",
         | 
| 1235 | 
            +
                        "584af003fe3cb463906d0e9b",
         | 
| 1236 | 
            +
                        "584b9a747072670e72bfc49d",
         | 
| 1237 | 
            +
                        "584b671f7072670e72bfaaf8",
         | 
| 1238 | 
            +
                        "584b81747072670e72bfbbfd",
         | 
| 1239 | 
            +
                        "584ba35f7072670e72bfca4d",
         | 
| 1240 | 
            +
                        "584ba5977072670e72bfcc2d",
         | 
| 1241 | 
            +
                        "584bc53c7072670e72bfe85f",
         | 
| 1242 | 
            +
                        "584bc3997072670e72bfe58d",
         | 
| 1243 | 
            +
                        "584bc4407072670e72bfe665",
         | 
| 1244 | 
            +
                        "584bd5587072670e72bffe39",
         | 
| 1245 | 
            +
                        "584bdadf7072670e72c0005c",
         | 
| 1246 | 
            +
                        "584be5ed7072670e72c007b3",
         | 
| 1247 | 
            +
                        "584c9ad27072670e72c060c5",
         | 
| 1248 | 
            +
                        "584c9cc67072670e72c063a1",
         | 
| 1249 | 
            +
                        "584cea557072670e72c07fb4",
         | 
| 1250 | 
            +
                        "584d19d47072670e72c0c6c0",
         | 
| 1251 | 
            +
                        "584dfe467072670e72c1665a",
         | 
| 1252 | 
            +
                        "584e875c7072670e72c1ec94",
         | 
| 1253 | 
            +
                        "584e05667072670e72c17167",
         | 
| 1254 | 
            +
                        "584f94e87072670e72c2d3f7",
         | 
| 1255 | 
            +
                        "584fdffd7072670e72c32dc7",
         | 
| 1256 | 
            +
                        "584fe07f7072670e72c32e59",
         | 
| 1257 | 
            +
                        "585a2a71b338a62ad50138dc",
         | 
| 1258 | 
            +
                        "585a206ab338a62ad501298f",
         | 
| 1259 | 
            +
                        "585a217cb338a62ad5012b38",
         | 
| 1260 | 
            +
                        "585b34afb338a62ad501e836",
         | 
| 1261 | 
            +
                        "585bb25fc49c8507c3ce7812",
         | 
| 1262 | 
            +
                        "585bbe55c49c8507c3ce81cd",
         | 
| 1263 | 
            +
                        "585d6c8a2a57cc11d4920a1e",
         | 
| 1264 | 
            +
                        "585e54c72a57cc11d492f71a",
         | 
| 1265 | 
            +
                        "585e34302a57cc11d492be30",
         | 
| 1266 | 
            +
                        "585ee0632a57cc11d4933608",
         | 
| 1267 | 
            +
                        "585f9661712e2761468dabca",
         | 
| 1268 | 
            +
                        "585ffe9a712e2761468df643",
         | 
| 1269 | 
            +
                        "586a37ec9d1b5e34c28184fc",
         | 
| 1270 | 
            +
                        "586a515a9d1b5e34c281b431",
         | 
| 1271 | 
            +
                        "586a94939d1b5e34c2823b5d",
         | 
| 1272 | 
            +
                        "586abc689d1b5e34c2826360",
         | 
| 1273 | 
            +
                        "586b0e219d1b5e34c2828862",
         | 
| 1274 | 
            +
                        "586b3db89d1b5e34c282cd52",
         | 
| 1275 | 
            +
                        "586b4c459d1b5e34c282e66d",
         | 
| 1276 | 
            +
                        "586b7d7d9d1b5e34c283359e",
         | 
| 1277 | 
            +
                        "586b8f149d1b5e34c283497c",
         | 
| 1278 | 
            +
                        "586b8f629d1b5e34c28349d6",
         | 
| 1279 | 
            +
                        "586c4c4d9d1b5e34c28391a1",
         | 
| 1280 | 
            +
                        "586c5b5b9d1b5e34c2839a5b",
         | 
| 1281 | 
            +
                        "586c9fdf9d1b5e34c283b657",
         | 
| 1282 | 
            +
                        "586caab99d1b5e34c283c213",
         | 
| 1283 | 
            +
                        "586cd0779d1b5e34c28403a7",
         | 
| 1284 | 
            +
                        "586d6d249d1b5e34c284b80e",
         | 
| 1285 | 
            +
                        "586d8a029d1b5e34c284c948",
         | 
| 1286 | 
            +
                        "586d55af9d1b5e34c284a999",
         | 
| 1287 | 
            +
                        "586d07869d1b5e34c2842e5b",
         | 
| 1288 | 
            +
                        "586d27489d1b5e34c28453af",
         | 
| 1289 | 
            +
                        "586e279c9d1b5e34c2852180",
         | 
| 1290 | 
            +
                        "587bc5ec2366dd5d06e262c1",
         | 
| 1291 | 
            +
                        "587c1abf2366dd5d06e28901",
         | 
| 1292 | 
            +
                        "587c03f12366dd5d06e27722",
         | 
| 1293 | 
            +
                        "587c19da2366dd5d06e2877b",
         | 
| 1294 | 
            +
                        "587c31b92366dd5d06e2a9dc",
         | 
| 1295 | 
            +
                        "587c87d02366dd5d06e2f989",
         | 
| 1296 | 
            +
                        "587c97a52366dd5d06e30a96",
         | 
| 1297 | 
            +
                        "587c45192366dd5d06e2c0eb",
         | 
| 1298 | 
            +
                        "587cec702366dd5d06e37862",
         | 
| 1299 | 
            +
                        "587cef0a2366dd5d06e379e3",
         | 
| 1300 | 
            +
                        "587db5872366dd5d06e3e0af",
         | 
| 1301 | 
            +
                        "587e2b1d2366dd5d06e41af0",
         | 
| 1302 | 
            +
                        "587e2ea62366dd5d06e41f2e",
         | 
| 1303 | 
            +
                        "587e5cb52366dd5d06e4486e",
         | 
| 1304 | 
            +
                        "587eb1822366dd5d06e45f29",
         | 
| 1305 | 
            +
                        "587f365d2366dd5d06e4906e",
         | 
| 1306 | 
            +
                        "588a9c5fec4d5a1c088ec350",
         | 
| 1307 | 
            +
                        "588a34cfec4d5a1c088ea8d1",
         | 
| 1308 | 
            +
                        "588ab5bdec4d5a1c088ed60f",
         | 
| 1309 | 
            +
                        "588aff9d90414422fbe7885a",
         | 
| 1310 | 
            +
                        "588b20d290414422fbe79f40",
         | 
| 1311 | 
            +
                        "588c08d590414422fbe8200b",
         | 
| 1312 | 
            +
                        "588c203d90414422fbe8319e",
         | 
| 1313 | 
            +
                        "588c989a90414422fbe86d96",
         | 
| 1314 | 
            +
                        "588ca09d90414422fbe871a1",
         | 
| 1315 | 
            +
                        "588cce2190414422fbe88520",
         | 
| 1316 | 
            +
                        "588cd5ef90414422fbe8875c",
         | 
| 1317 | 
            +
                        "588cf0ad90414422fbe8a20f",
         | 
| 1318 | 
            +
                        "588e01c490414422fbe8ee2a",
         | 
| 1319 | 
            +
                        "588e35e690414422fbe90a53",
         | 
| 1320 | 
            +
                        "588f017e90414422fbe9b74b",
         | 
| 1321 | 
            +
                        "588f095190414422fbe9c1ee",
         | 
| 1322 | 
            +
                        "589aca717dc3d323d55671c4",
         | 
| 1323 | 
            +
                        "589af2c97dc3d323d55691e8",
         | 
| 1324 | 
            +
                        "589b49ea7dc3d323d556d9b4",
         | 
| 1325 | 
            +
                        "589b04287dc3d323d556a185",
         | 
| 1326 | 
            +
                        "589bf6a57dc3d323d55743ab",
         | 
| 1327 | 
            +
                        "589c3c497dc3d323d5578468",
         | 
| 1328 | 
            +
                        "589c3c577dc3d323d5578480",
         | 
| 1329 | 
            +
                        "589c24527dc3d323d5577126",
         | 
| 1330 | 
            +
                        "589c35457dc3d323d5577d8d",
         | 
| 1331 | 
            +
                        "589ca6a6b896147a1b73aff7",
         | 
| 1332 | 
            +
                        "589d1e1fb896147a1b73ee5b",
         | 
| 1333 | 
            +
                        "589d5c58b896147a1b742256",
         | 
| 1334 | 
            +
                        "589d95538fa2cf375df3317b",
         | 
| 1335 | 
            +
                        "589df0ffb504a864ad63521a",
         | 
| 1336 | 
            +
                        "589ea316b504a864ad639a2b",
         | 
| 1337 | 
            +
                        "589ec97cb504a864ad63adc3",
         | 
| 1338 | 
            +
                        "589f214338486e3c9846f123",
         | 
| 1339 | 
            +
                        "589fdfe738486e3c984736cf",
         | 
| 1340 | 
            +
                        "590c2d70336bb52a190be886",
         | 
| 1341 | 
            +
                        "591a467a6109e14d4f09b776",
         | 
| 1342 | 
            +
                        "591cf3033162411cf9047f37",
         | 
| 1343 | 
            +
                        "591ea44850991c70dc99a207",
         | 
| 1344 | 
            +
                        "599aa591d5b41f366fed0d58",
         | 
| 1345 | 
            +
                        "5643df56138263b51db1b5f3",
         | 
| 1346 | 
            +
                        "5644bdac138263b51db9f669",
         | 
| 1347 | 
            +
                        "5850d4f97072670e72c425d6",
         | 
| 1348 | 
            +
                        "5854c405804be105852330fe",
         | 
| 1349 | 
            +
                        "5855a4fc804be1058523bd75",
         | 
| 1350 | 
            +
                        "5856ac15804be105852419d8",
         | 
| 1351 | 
            +
                        "5856ae8b804be10585241bae",
         | 
| 1352 | 
            +
                        "5856b460804be10585242059",
         | 
| 1353 | 
            +
                        "5857aa5ab338a62ad5ff4dbe",
         | 
| 1354 | 
            +
                        "5857acf8b338a62ad5ff5107",
         | 
| 1355 | 
            +
                        "5858db6cb338a62ad500103b",
         | 
| 1356 | 
            +
                        "5858dbcab338a62ad5001081",
         | 
| 1357 | 
            +
                        "5859d84fb338a62ad500e5cf",
         | 
| 1358 | 
            +
                        "5861d8ea712e2761468f3cb3",
         | 
| 1359 | 
            +
                        "5863edf8712e27614690cce0",
         | 
| 1360 | 
            +
                        "5864b076712e27614691197e",
         | 
| 1361 | 
            +
                        "5864da88712e276146913d8b",
         | 
| 1362 | 
            +
                        "5865f4a8712e27614691e39b",
         | 
| 1363 | 
            +
                        "5867a434833dfe3f7b88edaf",
         | 
| 1364 | 
            +
                        "5868cd15833dfe3f7b89bfa3",
         | 
| 1365 | 
            +
                        "5880b3692366dd5d06e5d534",
         | 
| 1366 | 
            +
                        "5880e3422366dd5d06e5ff8e",
         | 
| 1367 | 
            +
                        "5880f0ef2366dd5d06e6166e",
         | 
| 1368 | 
            +
                        "5881d2bfb6844814c136a119",
         | 
| 1369 | 
            +
                        "5881f11d8ce2c2754d0714c3",
         | 
| 1370 | 
            +
                        "5881fee18ce2c2754d0723f8",
         | 
| 1371 | 
            +
                        "5882cda2b116682b4adebd25",
         | 
| 1372 | 
            +
                        "5882d58fb116682b4adec7db",
         | 
| 1373 | 
            +
                        "5884c256932ba84fbed70bf5",
         | 
| 1374 | 
            +
                        "5884cc13932ba84fbed71ec4",
         | 
| 1375 | 
            +
                        "5885bc5296fa095e0671a7f0",
         | 
| 1376 | 
            +
                        "5886d14cb791366d617a362c",
         | 
| 1377 | 
            +
                        "5888becfc02346100f4b0b21",
         | 
| 1378 | 
            +
                        "5888e408c02346100f4b1a29",
         | 
| 1379 | 
            +
                        "5889da66ec4d5a1c088e5187",
         | 
| 1380 | 
            +
                        "5889e754ec4d5a1c088e60ba",
         | 
| 1381 | 
            +
                        "5890c16b90414422fbeb0262",
         | 
| 1382 | 
            +
                        "5891d8ae9a8c0314c5cd30ab",
         | 
| 1383 | 
            +
                        "5891d0479a8c0314c5cd2abd",
         | 
| 1384 | 
            +
                        "5891ecf19a8c0314c5cd490a",
         | 
| 1385 | 
            +
                        "5892c0cd9a8c0314c5cdc977",
         | 
| 1386 | 
            +
                        "5894ab309a8c0314c5cee57d",
         | 
| 1387 | 
            +
                        "5895a6a89a8c0314c5cfca7c",
         | 
| 1388 | 
            +
                        "5895b8c29a8c0314c5cfd051",
         | 
| 1389 | 
            +
                        "5895d38f9a8c0314c5cfe50c",
         | 
| 1390 | 
            +
                        "5895f2329a8c0314c5d00117",
         | 
| 1391 | 
            +
                        "5896bb989a8c0314c5d086b6",
         | 
| 1392 | 
            +
                        "5896ebf39a8c0314c5d0a8c4",
         | 
| 1393 | 
            +
                        "5898b1bac9dccc22987b7f74",
         | 
| 1394 | 
            +
                        "5898b6ffc9dccc22987b8a03",
         | 
| 1395 | 
            +
                        "5898bbaac9dccc22987b8eba",
         | 
| 1396 | 
            +
                        "5899cfa6b76d7a3780a4cb64",
         | 
| 1397 | 
            +
                        "5899e5dcb76d7a3780a4ecc1",
         | 
| 1398 | 
            +
                        "57102be2877e1421026358af",
         | 
| 1399 | 
            +
                        "57153d4031bb9900425bde85",
         | 
| 1400 | 
            +
                        "57177cd7fb8d93461afc4527",
         | 
| 1401 | 
            +
                        "58497cdf97b73e0b090c4273",
         | 
| 1402 | 
            +
                        "58500b007072670e72c35588",
         | 
| 1403 | 
            +
                        "58510bf97072670e72c46ddf",
         | 
| 1404 | 
            +
                        "58522bd56789802282f2ecb3",
         | 
| 1405 | 
            +
                        "58524a2e0e7012308944bcf3",
         | 
| 1406 | 
            +
                        "58524a080e7012308944bcbf",
         | 
| 1407 | 
            +
                        "58524c1d0e7012308944bfda",
         | 
| 1408 | 
            +
                        "58524f170e7012308944c200",
         | 
| 1409 | 
            +
                        "58529a4e0e70123089454c6f",
         | 
| 1410 | 
            +
                        "58551bdf804be1058523556d",
         | 
| 1411 | 
            +
                        "58568c9a804be10585240b03",
         | 
| 1412 | 
            +
                        "58574b35804be105852455fd",
         | 
| 1413 | 
            +
                        "58577c60b338a62ad5ff1564",
         | 
| 1414 | 
            +
                        "58592d69b338a62ad5007a74",
         | 
| 1415 | 
            +
                        "58625f42712e2761468fb44c",
         | 
| 1416 | 
            +
                        "58651bcc712e2761469166dc",
         | 
| 1417 | 
            +
                        "58660e79712e27614691fe3d",
         | 
| 1418 | 
            +
                        "58669aad712e27614692834c",
         | 
| 1419 | 
            +
                        "58676c36833dfe3f7b88b7f2",
         | 
| 1420 | 
            +
                        "58678b2d833dfe3f7b88e244",
         | 
| 1421 | 
            +
                        "58800b0b2366dd5d06e5312d",
         | 
| 1422 | 
            +
                        "58805eac2366dd5d06e56460",
         | 
| 1423 | 
            +
                        "58806e422366dd5d06e57bb6",
         | 
| 1424 | 
            +
                        "58831d060db9bf59bf8ab98b",
         | 
| 1425 | 
            +
                        "58851ebb932ba84fbed7abad",
         | 
| 1426 | 
            +
                        "58871dc3b791366d617a55ff",
         | 
| 1427 | 
            +
                        "58873cabb791366d617a65a7",
         | 
| 1428 | 
            +
                        "58873d44b791366d617a65dd",
         | 
| 1429 | 
            +
                        "58888b3dc02346100f4af665",
         | 
| 1430 | 
            +
                        "58933bac9a8c0314c5ce3508",
         | 
| 1431 | 
            +
                        "58938e6d9a8c0314c5ce726f",
         | 
| 1432 | 
            +
                        "58951cb49a8c0314c5cf4d5e",
         | 
| 1433 | 
            +
                        "58970fd09a8c0314c5d0e383",
         | 
| 1434 | 
            +
                        "58977ef09a8c0314c5d17b26",
         | 
| 1435 | 
            +
                        "59056e6760bb961de55f3501",
         | 
| 1436 | 
            +
                        "59071f2e5a6dbd3af4130f98",
         | 
| 1437 | 
            +
                        "59102c811225725be9e64149",
         | 
| 1438 | 
            +
                        "59338e76772c3e6384afbb15",
         | 
| 1439 | 
            +
                        "59350ca084b7f26bf5ce6eb8",
         | 
| 1440 | 
            +
                        "59397e493a87372f2c9e882b",
         | 
| 1441 | 
            +
                        "59521e0b9096412211c2aa9d",
         | 
| 1442 | 
            +
                        "59817e4a1bd4b175e7038d19",
         | 
| 1443 | 
            +
                        "567884f58d2828b95e3c8eba",
         | 
| 1444 | 
            +
                        "585559d9804be10585238ddf",
         | 
| 1445 | 
            +
                        "585834cdb338a62ad5ffab4d",
         | 
| 1446 | 
            +
                        "586082d8712e2761468e2877",
         | 
| 1447 | 
            +
                        "586133c2712e2761468ecfe3",
         | 
| 1448 | 
            +
                        "586281d2712e2761468fcaa2",
         | 
| 1449 | 
            +
                        "586316e5712e276146903c4d",
         | 
| 1450 | 
            +
                        "586326ad712e276146904571",
         | 
| 1451 | 
            +
                        "586375c9712e276146907429",
         | 
| 1452 | 
            +
                        "586389c9712e276146908da6",
         | 
| 1453 | 
            +
                        "586496fa712e2761469108e7",
         | 
| 1454 | 
            +
                        "586669c6712e27614692597a",
         | 
| 1455 | 
            +
                        "586913a49d1b5e34c2808b02",
         | 
| 1456 | 
            +
                        "586922da9d1b5e34c2809ff3",
         | 
| 1457 | 
            +
                        "588185d8dfb7a15588a114a3",
         | 
| 1458 | 
            +
                        "588315c60db9bf59bf8aa928",
         | 
| 1459 | 
            +
                        "588332ee0db9bf59bf8ae9c3",
         | 
| 1460 | 
            +
                        "588519d5932ba84fbed7a04a",
         | 
| 1461 | 
            +
                        "588824d1b791366d617adeef",
         | 
| 1462 | 
            +
                        "588857f6c02346100f4ac09f",
         | 
| 1463 | 
            +
                        "589145ef90414422fbeb2e08",
         | 
| 1464 | 
            +
                        "589433fa9a8c0314c5ce9656",
         | 
| 1465 | 
            +
                        "589765d39a8c0314c5d16b12",
         | 
| 1466 | 
            +
                        "5851165f7072670e72c4860d",
         | 
| 1467 | 
            +
                        "5859341ab338a62ad500848d",
         | 
| 1468 | 
            +
                        "5863915b712e276146909135",
         | 
| 1469 | 
            +
                        "5866445b712e27614692383e",
         | 
| 1470 | 
            +
                        "5866500d712e2761469240fd",
         | 
| 1471 | 
            +
                        "5867785a833dfe3f7b88c764",
         | 
| 1472 | 
            +
                        "5867969c833dfe3f7b88e8bc",
         | 
| 1473 | 
            +
                        "5868040c833dfe3f7b8934f7",
         | 
| 1474 | 
            +
                        "5882372c8ce2c2754d076af0",
         | 
| 1475 | 
            +
                        "5883535e932ba84fbed5ad07",
         | 
| 1476 | 
            +
                        "5888358cb791366d617af69d",
         | 
| 1477 | 
            +
                        "5890330d90414422fbeaa0cb",
         | 
| 1478 | 
            +
                        "5897076e9a8c0314c5d0d31b",
         | 
| 1479 | 
            +
                        "5940564ec2d9527ab869f7e2",
         | 
| 1480 | 
            +
                        "5947719bf1b45630bd096665",
         | 
| 1481 | 
            +
                        "5948194ff1b45630bd0f47e3",
         | 
| 1482 | 
            +
                        "5950206a41b158666ac50506",
         | 
| 1483 | 
            +
                        "5983012d1bd4b175e70c985a",
         | 
| 1484 | 
            +
                        "58586810b338a62ad5ffc20c",
         | 
| 1485 | 
            +
                        "58592046b338a62ad5006b33",
         | 
| 1486 | 
            +
                        "58592854b338a62ad500750a",
         | 
| 1487 | 
            +
                        "58596531b338a62ad500aace",
         | 
| 1488 | 
            +
                        "58818685dfb7a15588a11626",
         | 
| 1489 | 
            +
                        "58829563f42b1d3ee3ec835f",
         | 
| 1490 | 
            +
                        "58894345c02346100f4b51ca",
         | 
| 1491 | 
            +
                        "585289980e7012308945276a",
         | 
| 1492 | 
            +
                        "585369770e7012308945c709",
         | 
| 1493 | 
            +
                        "585373640e7012308945cab9",
         | 
| 1494 | 
            +
                        "588230658ce2c2754d076728",
         | 
| 1495 | 
            +
                        "589388059a8c0314c5ce718b",
         | 
| 1496 | 
            +
                        "595979485ec6a95e86a58c8d",
         | 
| 1497 | 
            +
                        "5841206219d291325678ca90",
         | 
| 1498 | 
            +
                        "58563650804be1058523da55",
         | 
| 1499 | 
            +
                        "58564084804be1058523e116",
         | 
| 1500 | 
            +
                        "58636467712e27614690661f",
         | 
| 1501 | 
            +
                        "58647495712e27614690f36d",
         | 
| 1502 | 
            +
                        "58654563712e276146918643",
         | 
| 1503 | 
            +
                        "58664251712e276146923738",
         | 
| 1504 | 
            +
                        "588084032366dd5d06e59e82",
         | 
| 1505 | 
            +
                        "588159582366dd5d06e66877",
         | 
| 1506 | 
            +
                        "5890279190414422fbea9734",
         | 
| 1507 | 
            +
                        "5890641690414422fbeabbe7",
         | 
| 1508 | 
            +
                        "585203546789802282f2aaf5",
         | 
| 1509 | 
            +
                    ]
         | 
| 1510 | 
            +
             | 
| 1511 | 
            +
                    # Validation set sequences after filtering
         | 
| 1512 | 
            +
                    self.val_split_scenes = [
         | 
| 1513 | 
            +
                        "00000000000000000000000a",
         | 
| 1514 | 
            +
                        "5a4a38dad38c8a075495b5d2",
         | 
| 1515 | 
            +
                        "5a489fb1c7dab83a7d7b1070",
         | 
| 1516 | 
            +
                        "5a572fd9fc597b0478a81d14",
         | 
| 1517 | 
            +
                        "5a588a8193ac3d233f77fbca",
         | 
| 1518 | 
            +
                        "5aa0f478a9efce63548c1cb4",
         | 
| 1519 | 
            +
                        "5ae2e9c5fe405c5076abc6b2",
         | 
| 1520 | 
            +
                        "5b2c67b5e0878c381608b8d8",
         | 
| 1521 | 
            +
                        "5b21e18c58e2823a67a10dd8",
         | 
| 1522 | 
            +
                        "5b864d850d072a699b32f4ae",
         | 
| 1523 | 
            +
                        "5b4933abf2b5f44e95de482a",
         | 
| 1524 | 
            +
                        "5b37189a35304b6f75e7583e",
         | 
| 1525 | 
            +
                        "5bc5f0e896b66a2cd8f9bd36",
         | 
| 1526 | 
            +
                        "5bccd6beca24970bce448134",
         | 
| 1527 | 
            +
                        "5bf26cbbd43923194854b270",
         | 
| 1528 | 
            +
                        "5bf18642c50e6f7f8bdbd492",
         | 
| 1529 | 
            +
                        "5bfc9d5aec61ca1dd69132a2",
         | 
| 1530 | 
            +
                        "5bff3c5cfe0ea555e6bcbf3a",
         | 
| 1531 | 
            +
                        "5c1f33f1d33e1f2e4aa6dda4",
         | 
| 1532 | 
            +
                        "5c34529873a8df509ae57b58",
         | 
| 1533 | 
            +
                        "58a186444a4d262a170ae3ae",
         | 
| 1534 | 
            +
                        "58f7f7299f5b5647873cb110",
         | 
| 1535 | 
            +
                        "59acd2f4b891807f439c8992",
         | 
| 1536 | 
            +
                        "567a0fb0a825d2fb79ac9a20",
         | 
| 1537 | 
            +
                        "584ad76bfe3cb463906ce6dc",
         | 
| 1538 | 
            +
                        "584c58b77072670e72c03990",
         | 
| 1539 | 
            +
                        "586c48329d1b5e34c2838e80",
         | 
| 1540 | 
            +
                        "586df9849d1b5e34c28506de",
         | 
| 1541 | 
            +
                        "588e0d8c90414422fbe8f8b2",
         | 
| 1542 | 
            +
                        "589c300f7dc3d323d5577926",
         | 
| 1543 | 
            +
                        "590f91851225725be9e25d4e",
         | 
| 1544 | 
            +
                        "5889e344ec4d5a1c088e59be",
         | 
| 1545 | 
            +
                        "5898b31cc9dccc22987b82ec",
         | 
| 1546 | 
            +
                        "5947b62af1b45630bd0c2a02",
         | 
| 1547 | 
            +
                        "58598db2b338a62ad500bc38",
         | 
| 1548 | 
            +
                        "58669c02712e27614692851a",
         | 
| 1549 | 
            +
                        "58790c82ce911104a3467c88",
         | 
| 1550 | 
            +
                        "58897f62c02346100f4b8ee6",
         | 
| 1551 | 
            +
                        "588305ed0db9bf59bf8a8c80",
         | 
| 1552 | 
            +
                        "588457b8932ba84fbed69942",
         | 
| 1553 | 
            +
                        "5862388b712e2761468f84aa",
         | 
| 1554 | 
            +
                        "5880675a2366dd5d06e570ca",
         | 
| 1555 | 
            +
                        "5890523090414422fbeab3f0",
         | 
| 1556 | 
            +
                    ]
         | 
| 1557 | 
            +
             | 
| 1558 | 
            +
             | 
| 1559 | 
            +
            class TartanAirV2Splits:
         | 
| 1560 | 
            +
                """
         | 
| 1561 | 
            +
                This class contains the information about the splits of the TartanAir V2 dataset.
         | 
| 1562 | 
            +
                """
         | 
| 1563 | 
            +
             | 
| 1564 | 
            +
                def __init__(self):
         | 
| 1565 | 
            +
                    """
         | 
| 1566 | 
            +
                    Splits of environments with unique geometry selected based on TartanVO & UFM splits.
         | 
| 1567 | 
            +
                    """
         | 
| 1568 | 
            +
                    # Apart from the below 2 splits, all other TAv2 scenes are in the train split
         | 
| 1569 | 
            +
                    # Val split
         | 
| 1570 | 
            +
                    self.val_split_scenes = ["EndofTheWorld", "HongKong", "WesternDesertTown"]
         | 
| 1571 | 
            +
             | 
| 1572 | 
            +
                    # Test split
         | 
| 1573 | 
            +
                    self.test_split_scenes = [
         | 
| 1574 | 
            +
                        "DesertGasStation",
         | 
| 1575 | 
            +
                        "OldScandinavia",
         | 
| 1576 | 
            +
                        "PolarSciFi",
         | 
| 1577 | 
            +
                        "Sewerage",
         | 
| 1578 | 
            +
                        "Supermarket",
         | 
| 1579 | 
            +
                    ]
         | 
| 1580 | 
            +
             | 
| 1581 | 
            +
             | 
| 1582 | 
            +
            class MegaDepthSplits:
         | 
| 1583 | 
            +
                """
         | 
| 1584 | 
            +
                This class contains the information about the splits of the MegaDepth dataset.
         | 
| 1585 | 
            +
                """
         | 
| 1586 | 
            +
             | 
| 1587 | 
            +
                def __init__(self):
         | 
| 1588 | 
            +
                    """
         | 
| 1589 | 
            +
                    Validation split is based on scenes used in DUSt3R.
         | 
| 1590 | 
            +
                    """
         | 
| 1591 | 
            +
                    self.val_split_scenes = ["0015_0", "0015_1", "0022_0"]
         | 
| 1592 | 
            +
             | 
| 1593 | 
            +
             | 
| 1594 | 
            +
            class SpringSplits:
         | 
| 1595 | 
            +
                """
         | 
| 1596 | 
            +
                This class contains the information about the splits of the Spring dataset.
         | 
| 1597 | 
            +
                """
         | 
| 1598 | 
            +
             | 
| 1599 | 
            +
                def __init__(self):
         | 
| 1600 | 
            +
                    self.val_split_scenes = ["0013", "0023", "0037"]
         | 
| 1601 | 
            +
             | 
| 1602 | 
            +
             | 
| 1603 | 
            +
            class MPSDSplits:
         | 
| 1604 | 
            +
                """
         | 
| 1605 | 
            +
                This class contains the information about the splits of the MPSD dataset.
         | 
| 1606 | 
            +
                """
         | 
| 1607 | 
            +
             | 
| 1608 | 
            +
                def __init__(self):
         | 
| 1609 | 
            +
                    """
         | 
| 1610 | 
            +
                    Train & Validation split numpy files containing folder names are generated during preprocessing of MPSD dataset.
         | 
| 1611 | 
            +
                    Load the numpy files to get the list of scenes in the train & validation split.
         | 
| 1612 | 
            +
                    A 95% (Train) & 5% (Validation) split is used.
         | 
| 1613 | 
            +
                    """
         | 
| 1614 | 
            +
                    self.train_split_scenes = "load_numpy_file_with_train_scenes"
         | 
| 1615 | 
            +
                    self.val_split_scenes = "load_numpy_file_with_val_scenes"
         | 
| 1616 | 
            +
             | 
| 1617 | 
            +
             | 
| 1618 | 
            +
            class ScanNetPPSplits:
         | 
| 1619 | 
            +
                """
         | 
| 1620 | 
            +
                This class contains the information about the splits of the ScanNetPP dataset.
         | 
| 1621 | 
            +
                """
         | 
| 1622 | 
            +
             | 
| 1623 | 
            +
                def __init__(self):
         | 
| 1624 | 
            +
                    """
         | 
| 1625 | 
            +
                    Validation & Test split only contains scenes from ScanNet++V2 to prevent data leak with other methods such as DUSt3R during benchmarking.
         | 
| 1626 | 
            +
             | 
| 1627 | 
            +
                    Following logic was used to generate the splits:
         | 
| 1628 | 
            +
                    # Select 80%, 10%, 10% of the scenes for train, val, test respectively from ScanNet++ V2 (~300 scene subset; excluding V1 scenes)
         | 
| 1629 | 
            +
                    snpp_v2_test_scenes = np.random.choice(
         | 
| 1630 | 
            +
                        snpp_v2_processed_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False
         | 
| 1631 | 
            +
                    )
         | 
| 1632 | 
            +
                    remaining_scenes = [scene for scene in snpp_v2_processed_scenes if scene not in snpp_v2_test_scenes]
         | 
| 1633 | 
            +
                    snpp_v2_val_scenes = np.random.choice(
         | 
| 1634 | 
            +
                        remaining_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False
         | 
| 1635 | 
            +
                    )
         | 
| 1636 | 
            +
                    snpp_v2_train_scenes = [
         | 
| 1637 | 
            +
                        scene for scene in remaining_scenes if scene not in snpp_v2_val_scenes and scene not in snpp_v2_test_scenes
         | 
| 1638 | 
            +
                    ]
         | 
| 1639 | 
            +
                    """
         | 
| 1640 | 
            +
                    # Validation Scenes
         | 
| 1641 | 
            +
                    self.val_split_scenes = [
         | 
| 1642 | 
            +
                        "1c7a683c92",
         | 
| 1643 | 
            +
                        "2a1b555966",
         | 
| 1644 | 
            +
                        "3a43c7b8d2",
         | 
| 1645 | 
            +
                        "4aef651da7",
         | 
| 1646 | 
            +
                        "06bc6d1b24",
         | 
| 1647 | 
            +
                        "7f22d5ef1b",
         | 
| 1648 | 
            +
                        "7f77abce34",
         | 
| 1649 | 
            +
                        "8ea517a2fc",
         | 
| 1650 | 
            +
                        "29c7afafed",
         | 
| 1651 | 
            +
                        "41eb967018",
         | 
| 1652 | 
            +
                        "77b40ce601",
         | 
| 1653 | 
            +
                        "086f09d6e3",
         | 
| 1654 | 
            +
                        "307e3262f1",
         | 
| 1655 | 
            +
                        "639f2c4d5a",
         | 
| 1656 | 
            +
                        "894dbd41f1",
         | 
| 1657 | 
            +
                        "898a7dfd0c",
         | 
| 1658 | 
            +
                        "2779f8f9e2",
         | 
| 1659 | 
            +
                        "151178afd7",
         | 
| 1660 | 
            +
                        "182932a4f3",
         | 
| 1661 | 
            +
                        "635852d56e",
         | 
| 1662 | 
            +
                        "9906136b57",
         | 
| 1663 | 
            +
                        "af112b8903",
         | 
| 1664 | 
            +
                        "b0f057c684",
         | 
| 1665 | 
            +
                        "b37177e6c8",
         | 
| 1666 | 
            +
                        "b119249da7",
         | 
| 1667 | 
            +
                        "be8367fcbe",
         | 
| 1668 | 
            +
                        "c8fc01c453",
         | 
| 1669 | 
            +
                        "e1fb8626c8",
         | 
| 1670 | 
            +
                        "e2caaaf5b5",
         | 
| 1671 | 
            +
                        "fe3fc057a1",
         | 
| 1672 | 
            +
                    ]
         | 
| 1673 | 
            +
             | 
| 1674 | 
            +
                    # Test Scenes
         | 
| 1675 | 
            +
                    self.test_split_scenes = [
         | 
| 1676 | 
            +
                        "0e900bcc5c",
         | 
| 1677 | 
            +
                        "0eba3981c9",
         | 
| 1678 | 
            +
                        "1cbb105c6a",
         | 
| 1679 | 
            +
                        "3c8d535d49",
         | 
| 1680 | 
            +
                        "5d902f1593",
         | 
| 1681 | 
            +
                        "6bd39ac392",
         | 
| 1682 | 
            +
                        "6c14d5fd01",
         | 
| 1683 | 
            +
                        "7c31a42404",
         | 
| 1684 | 
            +
                        "9bfbc75700",
         | 
| 1685 | 
            +
                        "13b4efaf62",
         | 
| 1686 | 
            +
                        "062e5a23a6",
         | 
| 1687 | 
            +
                        "95b9971d01",
         | 
| 1688 | 
            +
                        "246fe09e98",
         | 
| 1689 | 
            +
                        "637a27d04b",
         | 
| 1690 | 
            +
                        "725b8f0cba",
         | 
| 1691 | 
            +
                        "413085a827",
         | 
| 1692 | 
            +
                        "696317583f",
         | 
| 1693 | 
            +
                        "a4c043ac48",
         | 
| 1694 | 
            +
                        "a9e4791c7e",
         | 
| 1695 | 
            +
                        "b0b004c40f",
         | 
| 1696 | 
            +
                        "c3bc5e82c5",
         | 
| 1697 | 
            +
                        "c31ebd4b22",
         | 
| 1698 | 
            +
                        "cba701332a",
         | 
| 1699 | 
            +
                        "cc5ea8026c",
         | 
| 1700 | 
            +
                        "cec8312f4e",
         | 
| 1701 | 
            +
                        "e3b3b0d0c7",
         | 
| 1702 | 
            +
                        "e667e09fe6",
         | 
| 1703 | 
            +
                        "eaa6c90310",
         | 
| 1704 | 
            +
                        "f9397af4cb",
         | 
| 1705 | 
            +
                        "fb893ffaf3",
         | 
| 1706 | 
            +
                    ]
         | 
| 1707 | 
            +
             | 
| 1708 | 
            +
             | 
| 1709 | 
            +
            class DL3DV10KSplits:
         | 
| 1710 | 
            +
                """
         | 
| 1711 | 
            +
                This class contains the information about the splits of the DL3DV-10K dataset.
         | 
| 1712 | 
            +
                We use the official benchmark split as the val split.
         | 
| 1713 | 
            +
                """
         | 
| 1714 | 
            +
             | 
| 1715 | 
            +
                def __init__(self):
         | 
| 1716 | 
            +
                    """
         | 
| 1717 | 
            +
                    Validation split is based on DL3DV-Benchmark.
         | 
| 1718 | 
            +
                    """
         | 
| 1719 | 
            +
                    self.val_split_scenes = [
         | 
| 1720 | 
            +
                        "load https://huggingface.co/datasets/DL3DV/DL3DV-Benchmark/raw/main/benchmark-meta.csv \
         | 
| 1721 | 
            +
                        & https://raw.githubusercontent.com/DL3DV-10K/Dataset/main/cache/DL3DV-valid.csv"
         | 
| 1722 | 
            +
                    ]
         | 
| 1723 | 
            +
             | 
| 1724 | 
            +
             | 
| 1725 | 
            +
            class ETH3DSplits:
         | 
| 1726 | 
            +
                """
         | 
| 1727 | 
            +
                This class contains the information about the splits of the ETH3D dataset.
         | 
| 1728 | 
            +
                """
         | 
| 1729 | 
            +
             | 
| 1730 | 
            +
                def __init__(self):
         | 
| 1731 | 
            +
                    """
         | 
| 1732 | 
            +
                    All scenes are in the test split.
         | 
| 1733 | 
            +
                    """
         | 
| 1734 | 
            +
                    self.test_split_scenes = "all"
         | 
    	
        mapanything/datasets/wai/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        mapanything/datasets/wai/ase.py
    ADDED
    
    | @@ -0,0 +1,294 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            ASE Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 15 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            class ASEWAI(BaseDataset):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                ASE dataset containing large diversity of synthetic indoor scenes.
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def __init__(
         | 
| 24 | 
            +
                    self,
         | 
| 25 | 
            +
                    *args,
         | 
| 26 | 
            +
                    ROOT,
         | 
| 27 | 
            +
                    dataset_metadata_dir,
         | 
| 28 | 
            +
                    split,
         | 
| 29 | 
            +
                    overfit_num_sets=None,
         | 
| 30 | 
            +
                    sample_specific_scene: bool = False,
         | 
| 31 | 
            +
                    specific_scene_name: str = None,
         | 
| 32 | 
            +
                    **kwargs,
         | 
| 33 | 
            +
                ):
         | 
| 34 | 
            +
                    """
         | 
| 35 | 
            +
                    Initialize the dataset attributes.
         | 
| 36 | 
            +
                    Args:
         | 
| 37 | 
            +
                        ROOT: Root directory of the dataset.
         | 
| 38 | 
            +
                        dataset_metadata_dir: Path to the dataset metadata directory.
         | 
| 39 | 
            +
                        split: Dataset split (train, val, test).
         | 
| 40 | 
            +
                        overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
         | 
| 41 | 
            +
                        sample_specific_scene: Whether to sample a specific scene from the dataset.
         | 
| 42 | 
            +
                        specific_scene_name: Name of the specific scene to sample.
         | 
| 43 | 
            +
                    """
         | 
| 44 | 
            +
                    # Initialize the dataset attributes
         | 
| 45 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 46 | 
            +
                    self.ROOT = ROOT
         | 
| 47 | 
            +
                    self.dataset_metadata_dir = dataset_metadata_dir
         | 
| 48 | 
            +
                    self.split = split
         | 
| 49 | 
            +
                    self.overfit_num_sets = overfit_num_sets
         | 
| 50 | 
            +
                    self.sample_specific_scene = sample_specific_scene
         | 
| 51 | 
            +
                    self.specific_scene_name = specific_scene_name
         | 
| 52 | 
            +
                    self._load_data()
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    # Define the dataset type flags
         | 
| 55 | 
            +
                    self.is_metric_scale = True
         | 
| 56 | 
            +
                    self.is_synthetic = True
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def _load_data(self):
         | 
| 59 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 60 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 61 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 62 | 
            +
                        self.dataset_metadata_dir,
         | 
| 63 | 
            +
                        self.split,
         | 
| 64 | 
            +
                        f"ase_scene_list_{self.split}.npy",
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # Get the list of all scenes
         | 
| 69 | 
            +
                    if not self.sample_specific_scene:
         | 
| 70 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 71 | 
            +
                    else:
         | 
| 72 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 73 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 76 | 
            +
                    # Get the scene name of the sampled index
         | 
| 77 | 
            +
                    scene_index = sampled_idx
         | 
| 78 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 81 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 82 | 
            +
                    scene_meta = load_data(
         | 
| 83 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 86 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 89 | 
            +
                    covisibility_version_key = "v0"
         | 
| 90 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 91 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    # Assumes only npy file in directory is covisibility map
         | 
| 94 | 
            +
                    covisibility_map_name = next(
         | 
| 95 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 98 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 103 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 104 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 108 | 
            +
                    views = []
         | 
| 109 | 
            +
                    for view_index in view_indices:
         | 
| 110 | 
            +
                        # Load the data corresponding to the view
         | 
| 111 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 112 | 
            +
                        view_data = load_frame(
         | 
| 113 | 
            +
                            scene_root,
         | 
| 114 | 
            +
                            view_file_name,
         | 
| 115 | 
            +
                            modalities=["image", "depth"],
         | 
| 116 | 
            +
                            scene_meta=scene_meta,
         | 
| 117 | 
            +
                        )
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        # Convert necessary data to numpy
         | 
| 120 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 121 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 122 | 
            +
                        depthmap = view_data["depth"].numpy().astype(np.float32)
         | 
| 123 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 124 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 127 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 130 | 
            +
                        image, depthmap, intrinsics = self._crop_resize_if_necessary(
         | 
| 131 | 
            +
                            image=image,
         | 
| 132 | 
            +
                            resolution=resolution,
         | 
| 133 | 
            +
                            depthmap=depthmap,
         | 
| 134 | 
            +
                            intrinsics=intrinsics,
         | 
| 135 | 
            +
                            additional_quantities=None,
         | 
| 136 | 
            +
                        )
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 139 | 
            +
                        views.append(
         | 
| 140 | 
            +
                            dict(
         | 
| 141 | 
            +
                                img=image,
         | 
| 142 | 
            +
                                depthmap=depthmap,
         | 
| 143 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 144 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 145 | 
            +
                                dataset="ASE",
         | 
| 146 | 
            +
                                label=scene_name,
         | 
| 147 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 148 | 
            +
                            )
         | 
| 149 | 
            +
                        )
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    return views
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            def get_parser():
         | 
| 155 | 
            +
                import argparse
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 158 | 
            +
                parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/ase", type=str)
         | 
| 159 | 
            +
                parser.add_argument(
         | 
| 160 | 
            +
                    "-dmd",
         | 
| 161 | 
            +
                    "--dataset_metadata_dir",
         | 
| 162 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 163 | 
            +
                    type=str,
         | 
| 164 | 
            +
                )
         | 
| 165 | 
            +
                parser.add_argument(
         | 
| 166 | 
            +
                    "-nv",
         | 
| 167 | 
            +
                    "--num_of_views",
         | 
| 168 | 
            +
                    default=2,
         | 
| 169 | 
            +
                    type=int,
         | 
| 170 | 
            +
                )
         | 
| 171 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                return parser
         | 
| 174 | 
            +
             | 
| 175 | 
            +
             | 
| 176 | 
            +
            if __name__ == "__main__":
         | 
| 177 | 
            +
                import rerun as rr
         | 
| 178 | 
            +
                from tqdm import tqdm
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                from mapanything.datasets.base.base_dataset import view_name
         | 
| 181 | 
            +
                from mapanything.utils.image import rgb
         | 
| 182 | 
            +
                from mapanything.utils.viz import script_add_rerun_args
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                parser = get_parser()
         | 
| 185 | 
            +
                script_add_rerun_args(
         | 
| 186 | 
            +
                    parser
         | 
| 187 | 
            +
                )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
         | 
| 188 | 
            +
                args = parser.parse_args()
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                dataset = ASEWAI(
         | 
| 191 | 
            +
                    num_views=args.num_of_views,
         | 
| 192 | 
            +
                    split="train",
         | 
| 193 | 
            +
                    covisibility_thres=0.25,
         | 
| 194 | 
            +
                    ROOT=args.root_dir,
         | 
| 195 | 
            +
                    dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 196 | 
            +
                    resolution=(518, 518),
         | 
| 197 | 
            +
                    aug_crop=16,
         | 
| 198 | 
            +
                    transform="colorjitter+grayscale+gaublur",
         | 
| 199 | 
            +
                    data_norm_type="dinov2",
         | 
| 200 | 
            +
                )
         | 
| 201 | 
            +
                # dataset = ASEWAI(
         | 
| 202 | 
            +
                #     num_views=args.num_of_views,
         | 
| 203 | 
            +
                #     split="val",
         | 
| 204 | 
            +
                #     covisibility_thres=0.25,
         | 
| 205 | 
            +
                #     ROOT=args.root_dir,
         | 
| 206 | 
            +
                #     dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 207 | 
            +
                #     resolution=(518, 518),
         | 
| 208 | 
            +
                #     seed=777,
         | 
| 209 | 
            +
                #     transform="imgnorm",
         | 
| 210 | 
            +
                #     data_norm_type="dinov2",
         | 
| 211 | 
            +
                # )
         | 
| 212 | 
            +
                print(dataset.get_stats())
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                if args.viz:
         | 
| 215 | 
            +
                    rr.script_setup(args, "ASE_Dataloader")
         | 
| 216 | 
            +
                    rr.set_time("stable_time", sequence=0)
         | 
| 217 | 
            +
                    rr.log("world", rr.ViewCoordinates.RDF, static=True)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                for num, idx in enumerate(tqdm(sampled_indices)):
         | 
| 222 | 
            +
                    views = dataset[idx]
         | 
| 223 | 
            +
                    assert len(views) == args.num_of_views
         | 
| 224 | 
            +
                    sample_name = f"{idx}"
         | 
| 225 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 226 | 
            +
                        sample_name += f" {view_name(views[view_idx])}"
         | 
| 227 | 
            +
                    print(sample_name)
         | 
| 228 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 229 | 
            +
                        image = rgb(
         | 
| 230 | 
            +
                            views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
         | 
| 231 | 
            +
                        )
         | 
| 232 | 
            +
                        depthmap = views[view_idx]["depthmap"]
         | 
| 233 | 
            +
                        pose = views[view_idx]["camera_pose"]
         | 
| 234 | 
            +
                        intrinsics = views[view_idx]["camera_intrinsics"]
         | 
| 235 | 
            +
                        pts3d = views[view_idx]["pts3d"]
         | 
| 236 | 
            +
                        valid_mask = views[view_idx]["valid_mask"]
         | 
| 237 | 
            +
                        if "non_ambiguous_mask" in views[view_idx]:
         | 
| 238 | 
            +
                            non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
         | 
| 239 | 
            +
                        else:
         | 
| 240 | 
            +
                            non_ambiguous_mask = None
         | 
| 241 | 
            +
                        if "prior_depth_along_ray" in views[view_idx]:
         | 
| 242 | 
            +
                            prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
         | 
| 243 | 
            +
                        else:
         | 
| 244 | 
            +
                            prior_depth_along_ray = None
         | 
| 245 | 
            +
                        if args.viz:
         | 
| 246 | 
            +
                            rr.set_time("stable_time", sequence=num)
         | 
| 247 | 
            +
                            base_name = f"world/view_{view_idx}"
         | 
| 248 | 
            +
                            pts_name = f"world/view_{view_idx}_pointcloud"
         | 
| 249 | 
            +
                            # Log camera info and loaded data
         | 
| 250 | 
            +
                            height, width = image.shape[0], image.shape[1]
         | 
| 251 | 
            +
                            rr.log(
         | 
| 252 | 
            +
                                base_name,
         | 
| 253 | 
            +
                                rr.Transform3D(
         | 
| 254 | 
            +
                                    translation=pose[:3, 3],
         | 
| 255 | 
            +
                                    mat3x3=pose[:3, :3],
         | 
| 256 | 
            +
                                ),
         | 
| 257 | 
            +
                            )
         | 
| 258 | 
            +
                            rr.log(
         | 
| 259 | 
            +
                                f"{base_name}/pinhole",
         | 
| 260 | 
            +
                                rr.Pinhole(
         | 
| 261 | 
            +
                                    image_from_camera=intrinsics,
         | 
| 262 | 
            +
                                    height=height,
         | 
| 263 | 
            +
                                    width=width,
         | 
| 264 | 
            +
                                    camera_xyz=rr.ViewCoordinates.RDF,
         | 
| 265 | 
            +
                                ),
         | 
| 266 | 
            +
                            )
         | 
| 267 | 
            +
                            rr.log(
         | 
| 268 | 
            +
                                f"{base_name}/pinhole/rgb",
         | 
| 269 | 
            +
                                rr.Image(image),
         | 
| 270 | 
            +
                            )
         | 
| 271 | 
            +
                            rr.log(
         | 
| 272 | 
            +
                                f"{base_name}/pinhole/depth",
         | 
| 273 | 
            +
                                rr.DepthImage(depthmap),
         | 
| 274 | 
            +
                            )
         | 
| 275 | 
            +
                            if prior_depth_along_ray is not None:
         | 
| 276 | 
            +
                                rr.log(
         | 
| 277 | 
            +
                                    f"prior_depth_along_ray_{view_idx}",
         | 
| 278 | 
            +
                                    rr.DepthImage(prior_depth_along_ray),
         | 
| 279 | 
            +
                                )
         | 
| 280 | 
            +
                            if non_ambiguous_mask is not None:
         | 
| 281 | 
            +
                                rr.log(
         | 
| 282 | 
            +
                                    f"{base_name}/pinhole/non_ambiguous_mask",
         | 
| 283 | 
            +
                                    rr.SegmentationImage(non_ambiguous_mask.astype(int)),
         | 
| 284 | 
            +
                                )
         | 
| 285 | 
            +
                            # Log points in 3D
         | 
| 286 | 
            +
                            filtered_pts = pts3d[valid_mask]
         | 
| 287 | 
            +
                            filtered_pts_col = image[valid_mask]
         | 
| 288 | 
            +
                            rr.log(
         | 
| 289 | 
            +
                                pts_name,
         | 
| 290 | 
            +
                                rr.Points3D(
         | 
| 291 | 
            +
                                    positions=filtered_pts.reshape(-1, 3),
         | 
| 292 | 
            +
                                    colors=filtered_pts_col.reshape(-1, 3),
         | 
| 293 | 
            +
                                ),
         | 
| 294 | 
            +
                            )
         | 
    	
        mapanything/datasets/wai/blendedmvs.py
    ADDED
    
    | @@ -0,0 +1,313 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            BlendedMVS Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import cv2
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 16 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class BlendedMVSWAI(BaseDataset):
         | 
| 20 | 
            +
                """
         | 
| 21 | 
            +
                BlendedMVS dataset containing object-centric and birds-eye-view scenes.
         | 
| 22 | 
            +
                """
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def __init__(
         | 
| 25 | 
            +
                    self,
         | 
| 26 | 
            +
                    *args,
         | 
| 27 | 
            +
                    ROOT,
         | 
| 28 | 
            +
                    dataset_metadata_dir,
         | 
| 29 | 
            +
                    split,
         | 
| 30 | 
            +
                    overfit_num_sets=None,
         | 
| 31 | 
            +
                    sample_specific_scene: bool = False,
         | 
| 32 | 
            +
                    specific_scene_name: str = None,
         | 
| 33 | 
            +
                    **kwargs,
         | 
| 34 | 
            +
                ):
         | 
| 35 | 
            +
                    """
         | 
| 36 | 
            +
                    Initialize the dataset attributes.
         | 
| 37 | 
            +
                    Args:
         | 
| 38 | 
            +
                        ROOT: Root directory of the dataset.
         | 
| 39 | 
            +
                        dataset_metadata_dir: Path to the dataset metadata directory.
         | 
| 40 | 
            +
                        split: Dataset split (train, val, test).
         | 
| 41 | 
            +
                        overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
         | 
| 42 | 
            +
                        sample_specific_scene: Whether to sample a specific scene from the dataset.
         | 
| 43 | 
            +
                        specific_scene_name: Name of the specific scene to sample.
         | 
| 44 | 
            +
                    """
         | 
| 45 | 
            +
                    # Initialize the dataset attributes
         | 
| 46 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 47 | 
            +
                    self.ROOT = ROOT
         | 
| 48 | 
            +
                    self.dataset_metadata_dir = dataset_metadata_dir
         | 
| 49 | 
            +
                    self.split = split
         | 
| 50 | 
            +
                    self.overfit_num_sets = overfit_num_sets
         | 
| 51 | 
            +
                    self.sample_specific_scene = sample_specific_scene
         | 
| 52 | 
            +
                    self.specific_scene_name = specific_scene_name
         | 
| 53 | 
            +
                    self._load_data()
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    # Define the dataset type flags
         | 
| 56 | 
            +
                    self.is_metric_scale = False
         | 
| 57 | 
            +
                    self.is_synthetic = False
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def _load_data(self):
         | 
| 60 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 61 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 62 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 63 | 
            +
                        self.dataset_metadata_dir,
         | 
| 64 | 
            +
                        self.split,
         | 
| 65 | 
            +
                        f"blendedmvs_scene_list_{self.split}.npy",
         | 
| 66 | 
            +
                    )
         | 
| 67 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # Get the list of all scenes
         | 
| 70 | 
            +
                    if not self.sample_specific_scene:
         | 
| 71 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 72 | 
            +
                    else:
         | 
| 73 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 74 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 77 | 
            +
                    # Get the scene name of the sampled index
         | 
| 78 | 
            +
                    scene_index = sampled_idx
         | 
| 79 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 82 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 83 | 
            +
                    scene_meta = load_data(
         | 
| 84 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 85 | 
            +
                    )
         | 
| 86 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 87 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 90 | 
            +
                    covisibility_version_key = "v0"
         | 
| 91 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 92 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 93 | 
            +
                    )
         | 
| 94 | 
            +
                    # Assumes only npy file in directory is covisibility map
         | 
| 95 | 
            +
                    covisibility_map_name = next(
         | 
| 96 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 99 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 100 | 
            +
                    )
         | 
| 101 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 104 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 105 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 106 | 
            +
                    )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 109 | 
            +
                    views = []
         | 
| 110 | 
            +
                    for view_index in view_indices:
         | 
| 111 | 
            +
                        # Load the data corresponding to the view
         | 
| 112 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 113 | 
            +
                        view_data = load_frame(
         | 
| 114 | 
            +
                            scene_root,
         | 
| 115 | 
            +
                            view_file_name,
         | 
| 116 | 
            +
                            modalities=["image", "depth", "pred_mask/moge2"],
         | 
| 117 | 
            +
                            scene_meta=scene_meta,
         | 
| 118 | 
            +
                        )
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                        # Convert necessary data to numpy
         | 
| 121 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 122 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 123 | 
            +
                        depthmap = view_data["depth"].numpy().astype(np.float32)
         | 
| 124 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 125 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 128 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                        # Get the non_ambiguous_mask and ensure it matches image resolution
         | 
| 131 | 
            +
                        non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
         | 
| 132 | 
            +
                        non_ambiguous_mask = cv2.resize(
         | 
| 133 | 
            +
                            non_ambiguous_mask,
         | 
| 134 | 
            +
                            (image.shape[1], image.shape[0]),
         | 
| 135 | 
            +
                            interpolation=cv2.INTER_NEAREST,
         | 
| 136 | 
            +
                        )
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                        # Mask out the GT depth using the non_ambiguous_mask
         | 
| 139 | 
            +
                        depthmap = np.where(non_ambiguous_mask, depthmap, 0)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 142 | 
            +
                        additional_quantities_to_resize = [non_ambiguous_mask]
         | 
| 143 | 
            +
                        image, depthmap, intrinsics, additional_quantities_to_resize = (
         | 
| 144 | 
            +
                            self._crop_resize_if_necessary(
         | 
| 145 | 
            +
                                image=image,
         | 
| 146 | 
            +
                                resolution=resolution,
         | 
| 147 | 
            +
                                depthmap=depthmap,
         | 
| 148 | 
            +
                                intrinsics=intrinsics,
         | 
| 149 | 
            +
                                additional_quantities=additional_quantities_to_resize,
         | 
| 150 | 
            +
                            )
         | 
| 151 | 
            +
                        )
         | 
| 152 | 
            +
                        non_ambiguous_mask = additional_quantities_to_resize[0]
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 155 | 
            +
                        views.append(
         | 
| 156 | 
            +
                            dict(
         | 
| 157 | 
            +
                                img=image,
         | 
| 158 | 
            +
                                depthmap=depthmap,
         | 
| 159 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 160 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 161 | 
            +
                                non_ambiguous_mask=non_ambiguous_mask,
         | 
| 162 | 
            +
                                dataset="BlendedMVS",
         | 
| 163 | 
            +
                                label=scene_name,
         | 
| 164 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 165 | 
            +
                            )
         | 
| 166 | 
            +
                        )
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    return views
         | 
| 169 | 
            +
             | 
| 170 | 
            +
             | 
| 171 | 
            +
            def get_parser():
         | 
| 172 | 
            +
                import argparse
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 175 | 
            +
                parser.add_argument(
         | 
| 176 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/blendedmvs", type=str
         | 
| 177 | 
            +
                )
         | 
| 178 | 
            +
                parser.add_argument(
         | 
| 179 | 
            +
                    "-dmd",
         | 
| 180 | 
            +
                    "--dataset_metadata_dir",
         | 
| 181 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 182 | 
            +
                    type=str,
         | 
| 183 | 
            +
                )
         | 
| 184 | 
            +
                parser.add_argument(
         | 
| 185 | 
            +
                    "-nv",
         | 
| 186 | 
            +
                    "--num_of_views",
         | 
| 187 | 
            +
                    default=2,
         | 
| 188 | 
            +
                    type=int,
         | 
| 189 | 
            +
                )
         | 
| 190 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                return parser
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            if __name__ == "__main__":
         | 
| 196 | 
            +
                import rerun as rr
         | 
| 197 | 
            +
                from tqdm import tqdm
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                from mapanything.datasets.base.base_dataset import view_name
         | 
| 200 | 
            +
                from mapanything.utils.image import rgb
         | 
| 201 | 
            +
                from mapanything.utils.viz import script_add_rerun_args
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                parser = get_parser()
         | 
| 204 | 
            +
                script_add_rerun_args(
         | 
| 205 | 
            +
                    parser
         | 
| 206 | 
            +
                )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
         | 
| 207 | 
            +
                args = parser.parse_args()
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                dataset = BlendedMVSWAI(
         | 
| 210 | 
            +
                    num_views=args.num_of_views,
         | 
| 211 | 
            +
                    split="train",
         | 
| 212 | 
            +
                    covisibility_thres=0.25,
         | 
| 213 | 
            +
                    ROOT=args.root_dir,
         | 
| 214 | 
            +
                    dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 215 | 
            +
                    resolution=(518, 392),
         | 
| 216 | 
            +
                    aug_crop=16,
         | 
| 217 | 
            +
                    transform="colorjitter+grayscale+gaublur",
         | 
| 218 | 
            +
                    data_norm_type="dinov2",
         | 
| 219 | 
            +
                )
         | 
| 220 | 
            +
                # dataset = BlendedMVSWAI(
         | 
| 221 | 
            +
                #     num_views=args.num_of_views,
         | 
| 222 | 
            +
                #     split="val",
         | 
| 223 | 
            +
                #     covisibility_thres=0.25,
         | 
| 224 | 
            +
                #     ROOT=args.root_dir,
         | 
| 225 | 
            +
                #     dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 226 | 
            +
                #     resolution=(518, 392),
         | 
| 227 | 
            +
                #     seed=777,
         | 
| 228 | 
            +
                #     transform="imgnorm",
         | 
| 229 | 
            +
                #     data_norm_type="dinov2",
         | 
| 230 | 
            +
                # )
         | 
| 231 | 
            +
                print(dataset.get_stats())
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                if args.viz:
         | 
| 234 | 
            +
                    rr.script_setup(args, "BlendedMVS_Dataloader")
         | 
| 235 | 
            +
                    rr.set_time("stable_time", sequence=0)
         | 
| 236 | 
            +
                    rr.log("world", rr.ViewCoordinates.RDF, static=True)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                for num, idx in enumerate(tqdm(sampled_indices)):
         | 
| 241 | 
            +
                    views = dataset[idx]
         | 
| 242 | 
            +
                    assert len(views) == args.num_of_views
         | 
| 243 | 
            +
                    sample_name = f"{idx}"
         | 
| 244 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 245 | 
            +
                        sample_name += f" {view_name(views[view_idx])}"
         | 
| 246 | 
            +
                    print(sample_name)
         | 
| 247 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 248 | 
            +
                        image = rgb(
         | 
| 249 | 
            +
                            views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
         | 
| 250 | 
            +
                        )
         | 
| 251 | 
            +
                        depthmap = views[view_idx]["depthmap"]
         | 
| 252 | 
            +
                        pose = views[view_idx]["camera_pose"]
         | 
| 253 | 
            +
                        intrinsics = views[view_idx]["camera_intrinsics"]
         | 
| 254 | 
            +
                        pts3d = views[view_idx]["pts3d"]
         | 
| 255 | 
            +
                        valid_mask = views[view_idx]["valid_mask"]
         | 
| 256 | 
            +
                        if "non_ambiguous_mask" in views[view_idx]:
         | 
| 257 | 
            +
                            non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
         | 
| 258 | 
            +
                        else:
         | 
| 259 | 
            +
                            non_ambiguous_mask = None
         | 
| 260 | 
            +
                        if "prior_depth_along_ray" in views[view_idx]:
         | 
| 261 | 
            +
                            prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
         | 
| 262 | 
            +
                        else:
         | 
| 263 | 
            +
                            prior_depth_along_ray = None
         | 
| 264 | 
            +
                        if args.viz:
         | 
| 265 | 
            +
                            rr.set_time("stable_time", sequence=num)
         | 
| 266 | 
            +
                            base_name = f"world/view_{view_idx}"
         | 
| 267 | 
            +
                            pts_name = f"world/view_{view_idx}_pointcloud"
         | 
| 268 | 
            +
                            # Log camera info and loaded data
         | 
| 269 | 
            +
                            height, width = image.shape[0], image.shape[1]
         | 
| 270 | 
            +
                            rr.log(
         | 
| 271 | 
            +
                                base_name,
         | 
| 272 | 
            +
                                rr.Transform3D(
         | 
| 273 | 
            +
                                    translation=pose[:3, 3],
         | 
| 274 | 
            +
                                    mat3x3=pose[:3, :3],
         | 
| 275 | 
            +
                                ),
         | 
| 276 | 
            +
                            )
         | 
| 277 | 
            +
                            rr.log(
         | 
| 278 | 
            +
                                f"{base_name}/pinhole",
         | 
| 279 | 
            +
                                rr.Pinhole(
         | 
| 280 | 
            +
                                    image_from_camera=intrinsics,
         | 
| 281 | 
            +
                                    height=height,
         | 
| 282 | 
            +
                                    width=width,
         | 
| 283 | 
            +
                                    camera_xyz=rr.ViewCoordinates.RDF,
         | 
| 284 | 
            +
                                ),
         | 
| 285 | 
            +
                            )
         | 
| 286 | 
            +
                            rr.log(
         | 
| 287 | 
            +
                                f"{base_name}/pinhole/rgb",
         | 
| 288 | 
            +
                                rr.Image(image),
         | 
| 289 | 
            +
                            )
         | 
| 290 | 
            +
                            rr.log(
         | 
| 291 | 
            +
                                f"{base_name}/pinhole/depth",
         | 
| 292 | 
            +
                                rr.DepthImage(depthmap),
         | 
| 293 | 
            +
                            )
         | 
| 294 | 
            +
                            if prior_depth_along_ray is not None:
         | 
| 295 | 
            +
                                rr.log(
         | 
| 296 | 
            +
                                    f"prior_depth_along_ray_{view_idx}",
         | 
| 297 | 
            +
                                    rr.DepthImage(prior_depth_along_ray),
         | 
| 298 | 
            +
                                )
         | 
| 299 | 
            +
                            if non_ambiguous_mask is not None:
         | 
| 300 | 
            +
                                rr.log(
         | 
| 301 | 
            +
                                    f"{base_name}/pinhole/non_ambiguous_mask",
         | 
| 302 | 
            +
                                    rr.SegmentationImage(non_ambiguous_mask.astype(int)),
         | 
| 303 | 
            +
                                )
         | 
| 304 | 
            +
                            # Log points in 3D
         | 
| 305 | 
            +
                            filtered_pts = pts3d[valid_mask]
         | 
| 306 | 
            +
                            filtered_pts_col = image[valid_mask]
         | 
| 307 | 
            +
                            rr.log(
         | 
| 308 | 
            +
                                pts_name,
         | 
| 309 | 
            +
                                rr.Points3D(
         | 
| 310 | 
            +
                                    positions=filtered_pts.reshape(-1, 3),
         | 
| 311 | 
            +
                                    colors=filtered_pts_col.reshape(-1, 3),
         | 
| 312 | 
            +
                                ),
         | 
| 313 | 
            +
                            )
         | 
    	
        mapanything/datasets/wai/dl3dv.py
    ADDED
    
    | @@ -0,0 +1,356 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            DL3DV Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import cv2
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 16 | 
            +
            from mapanything.utils.cropping import (
         | 
| 17 | 
            +
                rescale_image_and_other_optional_info,
         | 
| 18 | 
            +
                resize_with_nearest_interpolation_to_match_aspect_ratio,
         | 
| 19 | 
            +
            )
         | 
| 20 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
class DL3DVWAI(BaseDataset):
    """
    DL3DV dataset containing over 10k in-the-wild and indoor scenes.

    Loads WAI-format scenes from ROOT. Per view it returns the RGB image,
    MVSAnywhere predicted depth (filtered by its confidence map and by the
    MoGe2 non-ambiguous mask), camera intrinsics, and cam2world pose.
    Depth is not metric scale (``is_metric_scale = False``) and the data is
    real captures (``is_synthetic = False``).
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        mvs_confidence_filter_thres: float = 0.25,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
            mvs_confidence_filter_thres: Confidence threshold to filter MVS depth. Defaults to 0.25.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self.mvs_confidence_filter_thres = mvs_confidence_filter_thres
        # Populates self.scenes and self.num_of_scenes from the metadata dir
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = False
        self.is_synthetic = False

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        # (a pickled numpy array of scene names named dl3dv_scene_list_<split>.npy)
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"dl3dv_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            # Restrict to the single requested scene
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        """
        Sample and load the views of one scene.

        Args:
            sampled_idx: Index into ``self.scenes`` selecting the scene.
            num_views_to_sample: Number of views to sample from the scene.
            resolution: Target resolution forwarded to
                ``self._crop_resize_if_necessary`` (presumably (width, height),
                e.g. (518, 294) in the demo below — confirm against BaseDataset).

        Returns:
            List of view dicts with keys: img, depthmap, camera_pose (cam2world),
            camera_intrinsics, non_ambiguous_mask, dataset, label, instance.
        """
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0_mvsa_based"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        # Memory-mapped load; avoids reading the full pairwise matrix into RAM
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        # (sampling strategy is implemented by BaseDataset._sample_view_indices)
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=[
                    "image",
                    "pred_depth/mvsanywhere",
                    "pred_mask/moge2",
                    "depth_confidence/mvsanywhere",
                ],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            # image arrives as a (C, H, W) tensor; convert to HWC uint8
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["pred_depth/mvsanywhere"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            # (NaN/±inf become 0, which downstream treats as invalid depth)
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the dimensions of the original image
            img_h, img_w = image.shape[:2]

            # Resize depth to match image aspect ratio while ensuring that depth resolution doesn't increase
            depthmap, target_depth_h, target_depth_w = (
                resize_with_nearest_interpolation_to_match_aspect_ratio(
                    input_data=depthmap, img_h=img_h, img_w=img_w
                )
            )

            # Now resize the image and update intrinsics to match the resized depth
            image, _, intrinsics, _ = rescale_image_and_other_optional_info(
                image=image,
                output_resolution=(target_depth_w, target_depth_h),
                depthmap=None,
                camera_intrinsics=intrinsics,
            )
            image = np.array(image)

            # Get the depth confidence map and mask out the MVS depth
            # below self.mvs_confidence_filter_thres (set to 0 = invalid)
            confidence_map = view_data["depth_confidence/mvsanywhere"].numpy()
            confidence_mask = (
                confidence_map > self.mvs_confidence_filter_thres
            ).astype(int)
            # Nearest-neighbor resize keeps the mask binary
            confidence_mask = cv2.resize(
                confidence_mask,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            depthmap = np.where(confidence_mask, depthmap, 0)

            # Get the non_ambiguous_mask and ensure it matches image resolution
            non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
            non_ambiguous_mask = cv2.resize(
                non_ambiguous_mask,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

            # Mask out the GT depth using the non_ambiguous_mask
            depthmap = np.where(non_ambiguous_mask, depthmap, 0)

            # Resize the data to match the desired resolution
            # (mask travels through additional_quantities so it is cropped/resized
            # consistently with the image/depth/intrinsics)
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="DL3DV",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views
| 212 | 
            +
             | 
| 213 | 
            +
             | 
def get_parser():
    """Build the CLI argument parser for the DL3DV dataloader demo script."""
    import argparse

    parser = argparse.ArgumentParser()
    # Dataset locations
    parser.add_argument(
        "-rd", "--root_dir", type=str, default="/fsx/xrtech/data/dl3dv"
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        type=str,
        default="/fsx/nkeetha/mapanything_dataset_metadata",
    )
    # How many views to draw per sample
    parser.add_argument("-nv", "--num_of_views", type=int, default=2)
    # Enable rerun visualization of the loaded views
    parser.add_argument("--viz", action="store_true")
    return parser
| 234 | 
            +
             | 
| 235 | 
            +
             | 
if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    # Adds rerun CLI flags: --headless, --connect, --serve, --addr, --save, --stdout
    script_add_rerun_args(parser)
    args = parser.parse_args()

    dataset = DL3DVWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        mvs_confidence_filter_thres=0.25,
        resolution=(518, 294),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "DL3DV_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    # Visit every sample exactly once, in random order
    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for step, sample_idx in enumerate(tqdm(sampled_indices)):
        views = dataset[sample_idx]
        assert len(views) == args.num_of_views
        # Print "<idx> <view_name_0> <view_name_1> ..."
        print(
            " ".join(
                [f"{sample_idx}"]
                + [f"{view_name(views[v])}" for v in range(args.num_of_views)]
            )
        )
        for v in range(args.num_of_views):
            view = views[v]
            image = rgb(view["img"], norm_type=view["data_norm_type"])
            depthmap = view["depthmap"]
            pose = view["camera_pose"]
            intrinsics = view["camera_intrinsics"]
            pts3d = view["pts3d"]
            valid_mask = view["valid_mask"]
            # Optional modalities: absent keys yield None
            non_ambiguous_mask = view.get("non_ambiguous_mask")
            prior_depth_along_ray = view.get("prior_depth_along_ray")
            if not args.viz:
                continue
            rr.set_time("stable_time", sequence=step)
            base_name = f"world/view_{v}"
            pts_name = f"world/view_{v}_pointcloud"
            # Log camera pose (cam2world), pinhole model, and image-space data
            height, width = image.shape[:2]
            rr.log(
                base_name,
                rr.Transform3D(
                    translation=pose[:3, 3],
                    mat3x3=pose[:3, :3],
                ),
            )
            rr.log(
                f"{base_name}/pinhole",
                rr.Pinhole(
                    image_from_camera=intrinsics,
                    height=height,
                    width=width,
                    camera_xyz=rr.ViewCoordinates.RDF,
                ),
            )
            rr.log(f"{base_name}/pinhole/rgb", rr.Image(image))
            rr.log(f"{base_name}/pinhole/depth", rr.DepthImage(depthmap))
            if prior_depth_along_ray is not None:
                rr.log(
                    f"prior_depth_along_ray_{v}",
                    rr.DepthImage(prior_depth_along_ray),
                )
            if non_ambiguous_mask is not None:
                rr.log(
                    f"{base_name}/pinhole/non_ambiguous_mask",
                    rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                )
            # Log the valid 3D points colored by the RGB image
            filtered_pts = pts3d[valid_mask]
            filtered_pts_col = image[valid_mask]
            rr.log(
                pts_name,
                rr.Points3D(
                    positions=filtered_pts.reshape(-1, 3),
                    colors=filtered_pts_col.reshape(-1, 3),
                ),
            )
    	
        mapanything/datasets/wai/dynamicreplica.py
    ADDED
    
    | @@ -0,0 +1,297 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Dynamic Replica Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 15 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
class DynamicReplicaWAI(BaseDataset):
    """
    Dynamic Replica dataset containing synthetic scenes with humans and animals.

    Loads WAI-format per-scene data (RGB image, depth, intrinsics, cam2world
    extrinsics) and samples sets of covisible views for multi-view training.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.

        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be
                truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample. Required
                when sample_specific_scene is True.

        Raises:
            ValueError: If sample_specific_scene is True but specific_scene_name
                is None.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        if sample_specific_scene and specific_scene_name is None:
            # Fail fast: otherwise _load_data builds a [None] scene list and the
            # failure only surfaces later inside os.path.join in _get_views.
            raise ValueError(
                "specific_scene_name must be provided when sample_specific_scene is True"
            )
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = True

    def _load_data(self):
        """Load the precomputed dataset metadata and build the scene list."""
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"dynamicreplica_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes (or the single requested scene)
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        """
        Sample a set of covisible views from the scene at the given index.

        Args:
            sampled_idx: Index into self.scenes of the scene to sample from.
            num_views_to_sample: Number of views to sample from the scene.
            resolution: Target resolution forwarded to _crop_resize_if_necessary.

        Returns:
            List of view dicts with keys: img, depthmap, camera_pose (cam2world),
            camera_intrinsics, dataset, label, instance.
        """
        # Get the scene name of the sampled index
        scene_name = self.scenes[sampled_idx]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        # Reuse the directory path computed above instead of rebuilding it
        covisibility_map_path = os.path.join(covisibility_map_dir, covisibility_map_name)
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy (CHW torch tensor -> HWC array)
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = image[:, :, :3]  # RGBA to RGB
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Resize the data to match the desired resolution
            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image=image,
                resolution=resolution,
                depthmap=depthmap,
                intrinsics=intrinsics,
                additional_quantities=None,
            )

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="DynamicReplica",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views
| 153 | 
            +
             | 
| 154 | 
            +
             | 
def get_parser():
    """Build the command-line argument parser for the dataloader demo script."""
    import argparse

    arg_parser = argparse.ArgumentParser()
    # Path-style options: (short flag, long flag, default value)
    path_options = [
        ("-rd", "--root_dir", "/fsx/xrtech/data/dynamicreplica"),
        ("-dmd", "--dataset_metadata_dir", "/fsx/nkeetha/mapanything_dataset_metadata"),
    ]
    for short_flag, long_flag, default_path in path_options:
        arg_parser.add_argument(short_flag, long_flag, default=default_path, type=str)
    arg_parser.add_argument("-nv", "--num_of_views", default=2, type=int)
    # Opt-in flag for rerun visualization of the sampled views
    arg_parser.add_argument("--viz", action="store_true")
    return arg_parser
| 177 | 
            +
             | 
| 178 | 
            +
             | 
if __name__ == "__main__":
    # Smoke-test / visualization entry point: iterates the whole dataset in a
    # random order, prints each sampled multi-view set, and optionally logs
    # cameras, images, depths, and point clouds to rerun when --viz is passed.
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    # Training-style configuration with augmentations enabled.
    dataset = DynamicReplicaWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 294),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # Alternate validation-style configuration (kept for quick switching).
    # dataset = DynamicReplicaWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 294),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        # One-time rerun setup; the world frame uses RDF camera coordinates.
        rr.script_setup(args, "DynamicReplica_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    # Visit every dataset index exactly once, in random order.
    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        # Build a printable identifier for this sample from its view names.
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            # Undo the dataloader normalization to recover a displayable RGB image.
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            # Optional quantities: only present for some dataset configurations.
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                # Each sample gets its own timeline step; views live under
                # world/view_<i> with image-space data nested under .../pinhole.
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D, keeping only points marked valid and
                # coloring each by its corresponding image pixel.
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
    	
        mapanything/datasets/wai/eth3d.py
    ADDED
    
    | @@ -0,0 +1,277 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            ETH3D Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 15 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
class ETH3DWAI(BaseDataset):
    """
    ETH3D dataset containing high-quality outdoor and indoor scans of the ETH Zurich campus.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.

        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        super().__init__(*args, **kwargs)

        # Store the dataset configuration
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = "test"  # ETH3D is used for evaluation only
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Dataset type flags: real-world captures with metric-scale geometry
        self.is_metric_scale = True
        self.is_synthetic = False

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Precomputed scene list for the configured split
        metadata_file = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"eth3d_scene_list_{self.split}.npy",
        )
        scene_list = np.load(metadata_file, allow_pickle=True)

        # Either take every scene in the split or restrict to one requested scene
        self.scenes = (
            [self.specific_scene_name]
            if self.sample_specific_scene
            else list(scene_list)
        )
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Resolve the scene for the sampled index
        scene_name = self.scenes[sampled_idx]
        scene_root = os.path.join(self.ROOT, scene_name)

        # Load the WAI scene metadata and enumerate its frames
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        frame_keys = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(frame_keys)

        # Memory-map the scene's pairwise covisibility matrix
        covis_dir = os.path.join(scene_root, "covisibility", "v0")
        # Assumes only npy file in directory is covisibility map
        covis_file = next(f for f in os.listdir(covis_dir) if f.endswith(".npy"))
        pairwise_covisibility = load_data(
            os.path.join(covis_dir, covis_file), "mmap"
        )

        # Choose the N view indices to return for this sample
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Assemble one view dict per selected index
        views = []
        for idx in view_indices:
            frame_key = frame_keys[idx]
            frame = load_frame(
                scene_root,
                frame_key,
                modalities=["image", "depth"],
                scene_meta=scene_meta,
            )

            # Convert the loaded tensors to numpy arrays
            rgb_image = (frame["image"].permute(1, 2, 0).numpy() * 255).astype(
                np.uint8
            )
            depthmap = frame["depth"].numpy().astype(np.float32)
            intrinsics = frame["intrinsics"].numpy().astype(np.float32)
            c2w_pose = frame["extrinsics"].numpy().astype(np.float32)

            # Crop/resize everything to the requested resolution
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image=rgb_image,
                resolution=resolution,
                depthmap=depthmap,
                intrinsics=intrinsics,
                additional_quantities=None,
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="ETH3D",
                    label=scene_name,
                    instance=os.path.join("images", str(frame_key)),
                )
            )

        return views
def get_parser():
    """Build the command-line argument parser for the ETH3D dataloader demo."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", type=str, default="/fsx/xrtech/data/eth3d"
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        type=str,
        default="/fsx/nkeetha/mapanything_dataset_metadata",
    )
    parser.add_argument("-nv", "--num_of_views", type=int, default=2)
    parser.add_argument("--viz", action="store_true")
    return parser
            +
            if __name__ == "__main__":
         | 
| 172 | 
            +
                import rerun as rr
         | 
| 173 | 
            +
                from tqdm import tqdm
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                from mapanything.datasets.base.base_dataset import view_name
         | 
| 176 | 
            +
                from mapanything.utils.image import rgb
         | 
| 177 | 
            +
                from mapanything.utils.viz import script_add_rerun_args
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                parser = get_parser()
         | 
| 180 | 
            +
                script_add_rerun_args(
         | 
| 181 | 
            +
                    parser
         | 
| 182 | 
            +
                )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
         | 
| 183 | 
            +
                args = parser.parse_args()
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                dataset = ETH3DWAI(
         | 
| 186 | 
            +
                    num_views=args.num_of_views,
         | 
| 187 | 
            +
                    covisibility_thres=0.025,
         | 
| 188 | 
            +
                    ROOT=args.root_dir,
         | 
| 189 | 
            +
                    dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 190 | 
            +
                    resolution=(518, 336),
         | 
| 191 | 
            +
                    seed=777,
         | 
| 192 | 
            +
                    transform="imgnorm",
         | 
| 193 | 
            +
                    data_norm_type="dinov2",
         | 
| 194 | 
            +
                )
         | 
| 195 | 
            +
                print(dataset.get_stats())
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                if args.viz:
         | 
| 198 | 
            +
                    rr.script_setup(args, "ETH3D_Dataloader")
         | 
| 199 | 
            +
                    rr.set_time("stable_time", sequence=0)
         | 
| 200 | 
            +
                    rr.log("world", rr.ViewCoordinates.RDF, static=True)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                for num, idx in enumerate(tqdm(sampled_indices)):
         | 
| 205 | 
            +
                    views = dataset[idx]
         | 
| 206 | 
            +
                    assert len(views) == args.num_of_views
         | 
| 207 | 
            +
                    sample_name = f"{idx}"
         | 
| 208 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 209 | 
            +
                        sample_name += f" {view_name(views[view_idx])}"
         | 
| 210 | 
            +
                    print(sample_name)
         | 
| 211 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 212 | 
            +
                        image = rgb(
         | 
| 213 | 
            +
                            views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
         | 
| 214 | 
            +
                        )
         | 
| 215 | 
            +
                        depthmap = views[view_idx]["depthmap"]
         | 
| 216 | 
            +
                        pose = views[view_idx]["camera_pose"]
         | 
| 217 | 
            +
                        intrinsics = views[view_idx]["camera_intrinsics"]
         | 
| 218 | 
            +
                        pts3d = views[view_idx]["pts3d"]
         | 
| 219 | 
            +
                        valid_mask = views[view_idx]["valid_mask"]
         | 
| 220 | 
            +
                        if "non_ambiguous_mask" in views[view_idx]:
         | 
| 221 | 
            +
                            non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
         | 
| 222 | 
            +
                        else:
         | 
| 223 | 
            +
                            non_ambiguous_mask = None
         | 
| 224 | 
            +
                        if "prior_depth_along_ray" in views[view_idx]:
         | 
| 225 | 
            +
                            prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
         | 
| 226 | 
            +
                        else:
         | 
| 227 | 
            +
                            prior_depth_along_ray = None
         | 
| 228 | 
            +
                        if args.viz:
         | 
| 229 | 
            +
                            rr.set_time("stable_time", sequence=num)
         | 
| 230 | 
            +
                            base_name = f"world/view_{view_idx}"
         | 
| 231 | 
            +
                            pts_name = f"world/view_{view_idx}_pointcloud"
         | 
| 232 | 
            +
                            # Log camera info and loaded data
         | 
| 233 | 
            +
                            height, width = image.shape[0], image.shape[1]
         | 
| 234 | 
            +
                            rr.log(
         | 
| 235 | 
            +
                                base_name,
         | 
| 236 | 
            +
                                rr.Transform3D(
         | 
| 237 | 
            +
                                    translation=pose[:3, 3],
         | 
| 238 | 
            +
                                    mat3x3=pose[:3, :3],
         | 
| 239 | 
            +
                                ),
         | 
| 240 | 
            +
                            )
         | 
| 241 | 
            +
                            rr.log(
         | 
| 242 | 
            +
                                f"{base_name}/pinhole",
         | 
| 243 | 
            +
                                rr.Pinhole(
         | 
| 244 | 
            +
                                    image_from_camera=intrinsics,
         | 
| 245 | 
            +
                                    height=height,
         | 
| 246 | 
            +
                                    width=width,
         | 
| 247 | 
            +
                                    camera_xyz=rr.ViewCoordinates.RDF,
         | 
| 248 | 
            +
                                ),
         | 
| 249 | 
            +
                            )
         | 
| 250 | 
            +
                            rr.log(
         | 
| 251 | 
            +
                                f"{base_name}/pinhole/rgb",
         | 
| 252 | 
            +
                                rr.Image(image),
         | 
| 253 | 
            +
                            )
         | 
| 254 | 
            +
                            rr.log(
         | 
| 255 | 
            +
                                f"{base_name}/pinhole/depth",
         | 
| 256 | 
            +
                                rr.DepthImage(depthmap),
         | 
| 257 | 
            +
                            )
         | 
| 258 | 
            +
                            if prior_depth_along_ray is not None:
         | 
| 259 | 
            +
                                rr.log(
         | 
| 260 | 
            +
                                    f"prior_depth_along_ray_{view_idx}",
         | 
| 261 | 
            +
                                    rr.DepthImage(prior_depth_along_ray),
         | 
| 262 | 
            +
                                )
         | 
| 263 | 
            +
                            if non_ambiguous_mask is not None:
         | 
| 264 | 
            +
                                rr.log(
         | 
| 265 | 
            +
                                    f"{base_name}/pinhole/non_ambiguous_mask",
         | 
| 266 | 
            +
                                    rr.SegmentationImage(non_ambiguous_mask.astype(int)),
         | 
| 267 | 
            +
                                )
         | 
| 268 | 
            +
                            # Log points in 3D
         | 
| 269 | 
            +
                            filtered_pts = pts3d[valid_mask]
         | 
| 270 | 
            +
                            filtered_pts_col = image[valid_mask]
         | 
| 271 | 
            +
                            rr.log(
         | 
| 272 | 
            +
                                pts_name,
         | 
| 273 | 
            +
                                rr.Points3D(
         | 
| 274 | 
            +
                                    positions=filtered_pts.reshape(-1, 3),
         | 
| 275 | 
            +
                                    colors=filtered_pts_col.reshape(-1, 3),
         | 
| 276 | 
            +
                                ),
         | 
| 277 | 
            +
                            )
         | 
    	
        mapanything/datasets/wai/megadepth.py
    ADDED
    
    | @@ -0,0 +1,314 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            MegaDepth Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import cv2
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 16 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class MegaDepthWAI(BaseDataset):
         | 
| 20 | 
            +
                """
         | 
| 21 | 
            +
                MegaDepth dataset containing outdoor phototourism and in-the-wild scenes.
         | 
| 22 | 
            +
                Also includes Tanks & Temples scenes.
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def __init__(
         | 
| 26 | 
            +
                    self,
         | 
| 27 | 
            +
                    *args,
         | 
| 28 | 
            +
                    ROOT,
         | 
| 29 | 
            +
                    dataset_metadata_dir,
         | 
| 30 | 
            +
                    split,
         | 
| 31 | 
            +
                    overfit_num_sets=None,
         | 
| 32 | 
            +
                    sample_specific_scene: bool = False,
         | 
| 33 | 
            +
                    specific_scene_name: str = None,
         | 
| 34 | 
            +
                    **kwargs,
         | 
| 35 | 
            +
                ):
         | 
| 36 | 
            +
                    """
         | 
| 37 | 
            +
                    Initialize the dataset attributes.
         | 
| 38 | 
            +
                    Args:
         | 
| 39 | 
            +
                        ROOT: Root directory of the dataset.
         | 
| 40 | 
            +
                        dataset_metadata_dir: Path to the dataset metadata directory.
         | 
| 41 | 
            +
                        split: Dataset split (train, val, test).
         | 
| 42 | 
            +
                        overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
         | 
| 43 | 
            +
                        sample_specific_scene: Whether to sample a specific scene from the dataset.
         | 
| 44 | 
            +
                        specific_scene_name: Name of the specific scene to sample.
         | 
| 45 | 
            +
                    """
         | 
| 46 | 
            +
                    # Initialize the dataset attributes
         | 
| 47 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 48 | 
            +
                    self.ROOT = ROOT
         | 
| 49 | 
            +
                    self.dataset_metadata_dir = dataset_metadata_dir
         | 
| 50 | 
            +
                    self.split = split
         | 
| 51 | 
            +
                    self.overfit_num_sets = overfit_num_sets
         | 
| 52 | 
            +
                    self.sample_specific_scene = sample_specific_scene
         | 
| 53 | 
            +
                    self.specific_scene_name = specific_scene_name
         | 
| 54 | 
            +
                    self._load_data()
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    # Define the dataset type flags
         | 
| 57 | 
            +
                    self.is_metric_scale = False
         | 
| 58 | 
            +
                    self.is_synthetic = False
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def _load_data(self):
         | 
| 61 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 62 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 63 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 64 | 
            +
                        self.dataset_metadata_dir,
         | 
| 65 | 
            +
                        self.split,
         | 
| 66 | 
            +
                        f"megadepth_scene_list_{self.split}.npy",
         | 
| 67 | 
            +
                    )
         | 
| 68 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    # Get the list of all scenes
         | 
| 71 | 
            +
                    if not self.sample_specific_scene:
         | 
| 72 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 73 | 
            +
                    else:
         | 
| 74 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 75 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 78 | 
            +
                    # Get the scene name of the sampled index
         | 
| 79 | 
            +
                    scene_index = sampled_idx
         | 
| 80 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 83 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 84 | 
            +
                    scene_meta = load_data(
         | 
| 85 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 86 | 
            +
                    )
         | 
| 87 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 88 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 91 | 
            +
                    covisibility_version_key = "v0"
         | 
| 92 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 93 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 94 | 
            +
                    )
         | 
| 95 | 
            +
                    # Assumes only npy file in directory is covisibility map
         | 
| 96 | 
            +
                    covisibility_map_name = next(
         | 
| 97 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 98 | 
            +
                    )
         | 
| 99 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 100 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 101 | 
            +
                    )
         | 
| 102 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 105 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 106 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 107 | 
            +
                    )
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 110 | 
            +
                    views = []
         | 
| 111 | 
            +
                    for view_index in view_indices:
         | 
| 112 | 
            +
                        # Load the data corresponding to the view
         | 
| 113 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 114 | 
            +
                        view_data = load_frame(
         | 
| 115 | 
            +
                            scene_root,
         | 
| 116 | 
            +
                            view_file_name,
         | 
| 117 | 
            +
                            modalities=["image", "depth", "pred_mask/moge2"],
         | 
| 118 | 
            +
                            scene_meta=scene_meta,
         | 
| 119 | 
            +
                        )
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                        # Convert necessary data to numpy
         | 
| 122 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 123 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 124 | 
            +
                        depthmap = view_data["depth"].numpy().astype(np.float32)
         | 
| 125 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 126 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 129 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                        # Get the non_ambiguous_mask and ensure it matches image resolution
         | 
| 132 | 
            +
                        non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
         | 
| 133 | 
            +
                        non_ambiguous_mask = cv2.resize(
         | 
| 134 | 
            +
                            non_ambiguous_mask,
         | 
| 135 | 
            +
                            (image.shape[1], image.shape[0]),
         | 
| 136 | 
            +
                            interpolation=cv2.INTER_NEAREST,
         | 
| 137 | 
            +
                        )
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                        # Mask out the GT depth using the non_ambiguous_mask
         | 
| 140 | 
            +
                        depthmap = np.where(non_ambiguous_mask, depthmap, 0)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 143 | 
            +
                        additional_quantities_to_resize = [non_ambiguous_mask]
         | 
| 144 | 
            +
                        image, depthmap, intrinsics, additional_quantities_to_resize = (
         | 
| 145 | 
            +
                            self._crop_resize_if_necessary(
         | 
| 146 | 
            +
                                image=image,
         | 
| 147 | 
            +
                                resolution=resolution,
         | 
| 148 | 
            +
                                depthmap=depthmap,
         | 
| 149 | 
            +
                                intrinsics=intrinsics,
         | 
| 150 | 
            +
                                additional_quantities=additional_quantities_to_resize,
         | 
| 151 | 
            +
                            )
         | 
| 152 | 
            +
                        )
         | 
| 153 | 
            +
                        non_ambiguous_mask = additional_quantities_to_resize[0]
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 156 | 
            +
                        views.append(
         | 
| 157 | 
            +
                            dict(
         | 
| 158 | 
            +
                                img=image,
         | 
| 159 | 
            +
                                depthmap=depthmap,
         | 
| 160 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 161 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 162 | 
            +
                                non_ambiguous_mask=non_ambiguous_mask,
         | 
| 163 | 
            +
                                dataset="MegaDepth",
         | 
| 164 | 
            +
                                label=scene_name,
         | 
| 165 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 166 | 
            +
                            )
         | 
| 167 | 
            +
                        )
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    return views
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
            def get_parser():
         | 
| 173 | 
            +
                import argparse
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 176 | 
            +
                parser.add_argument(
         | 
| 177 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/megadepth", type=str
         | 
| 178 | 
            +
                )
         | 
| 179 | 
            +
                parser.add_argument(
         | 
| 180 | 
            +
                    "-dmd",
         | 
| 181 | 
            +
                    "--dataset_metadata_dir",
         | 
| 182 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 183 | 
            +
                    type=str,
         | 
| 184 | 
            +
                )
         | 
| 185 | 
            +
                parser.add_argument(
         | 
| 186 | 
            +
                    "-nv",
         | 
| 187 | 
            +
                    "--num_of_views",
         | 
| 188 | 
            +
                    default=2,
         | 
| 189 | 
            +
                    type=int,
         | 
| 190 | 
            +
                )
         | 
| 191 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                return parser
         | 
| 194 | 
            +
             | 
| 195 | 
            +
             | 
| 196 | 
            +
if __name__ == "__main__":
    # Demo / sanity-check script: iterate the MegaDepth WAI dataset and
    # optionally stream each sampled multi-view set to a rerun viewer.
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    # Training-split configuration with augmentation enabled (aug_crop,
    # photometric transform). A validation-split variant is kept below,
    # commented out, for quick switching.
    dataset = MegaDepthWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 336),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = MegaDepthWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 336),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        # One-time rerun setup; world frame uses RDF (right-down-forward) axes.
        rr.script_setup(args, "MegaDepth_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    # Visit every dataset index exactly once, in random order.
    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        # Human-readable identifier: sample index followed by each view's name.
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            # Undo the dataloader's normalization so the image can be displayed.
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            # Optional per-view quantities; not every dataset config provides them.
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                # Advance the rerun timeline once per sample so all views of a
                # sample share the same time step.
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
         | 
    	
        mapanything/datasets/wai/mpsd.py
    ADDED
    
    | @@ -0,0 +1,311 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            MPSD Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import cv2
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 16 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
class MPSDWAI(BaseDataset):
    """
    MPSD dataset containing outdoor planet scale metric reconstructions.

    Loads WAI-format scenes from disk (scene_meta.json, per-frame image/depth/
    predicted mask, and a precomputed pairwise covisibility map) and samples
    covisible multi-view sets for training/evaluation.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        # MPSD reconstructions are metric-scale, real-world captures.
        self.is_metric_scale = True
        self.is_synthetic = False

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        # (a numpy object array of scene names, one file per split).
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"mpsd_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            # Restrict the dataset to the single requested scene.
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        """
        Sample and load a covisible multi-view set from one scene.

        Args:
            sampled_idx: Index into self.scenes selecting the scene.
            num_views_to_sample: Number of views to sample from the scene.
            resolution: Target resolution passed to _crop_resize_if_necessary
                (presumably (width, height) — TODO confirm against the base class).

        Returns:
            List of per-view dicts with keys: img, depthmap, camera_pose
            (cam2world), camera_intrinsics, non_ambiguous_mask, dataset,
            label, instance.
        """
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        # (next() raises StopIteration if no .npy file is present).
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth", "pred_mask/moge2"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            # Image tensor is CHW in [0, 1]; convert to HWC uint8.
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non_ambiguous_mask and ensure it matches image resolution
            # NOTE(review): astype(int) yields int64 on most platforms; cv2.resize
            # support for 64-bit integer inputs is limited — confirm the mask dtype
            # loaded by load_frame is accepted here.
            non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
            # Nearest-neighbor keeps the mask binary (no interpolated values).
            non_ambiguous_mask = cv2.resize(
                non_ambiguous_mask,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

            # Mask out the GT depth using the non_ambiguous_mask
            # (ambiguous pixels get depth 0, i.e. invalid).
            depthmap = np.where(non_ambiguous_mask, depthmap, 0)

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="MPSD",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views
         | 
| 169 | 
            +
             | 
| 170 | 
            +
             | 
| 171 | 
            +
            def get_parser():
         | 
| 172 | 
            +
                import argparse
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 175 | 
            +
                parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/mpsd", type=str)
         | 
| 176 | 
            +
                parser.add_argument(
         | 
| 177 | 
            +
                    "-dmd",
         | 
| 178 | 
            +
                    "--dataset_metadata_dir",
         | 
| 179 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 180 | 
            +
                    type=str,
         | 
| 181 | 
            +
                )
         | 
| 182 | 
            +
                parser.add_argument(
         | 
| 183 | 
            +
                    "-nv",
         | 
| 184 | 
            +
                    "--num_of_views",
         | 
| 185 | 
            +
                    default=2,
         | 
| 186 | 
            +
                    type=int,
         | 
| 187 | 
            +
                )
         | 
| 188 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                return parser
         | 
| 191 | 
            +
             | 
| 192 | 
            +
             | 
| 193 | 
            +
            if __name__ == "__main__":
         | 
| 194 | 
            +
                import rerun as rr
         | 
| 195 | 
            +
                from tqdm import tqdm
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                from mapanything.datasets.base.base_dataset import view_name
         | 
| 198 | 
            +
                from mapanything.utils.image import rgb
         | 
| 199 | 
            +
                from mapanything.utils.viz import script_add_rerun_args
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                parser = get_parser()
         | 
| 202 | 
            +
                script_add_rerun_args(
         | 
| 203 | 
            +
                    parser
         | 
| 204 | 
            +
                )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
         | 
| 205 | 
            +
                args = parser.parse_args()
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                dataset = MPSDWAI(
         | 
| 208 | 
            +
                    num_views=args.num_of_views,
         | 
| 209 | 
            +
                    split="train",
         | 
| 210 | 
            +
                    covisibility_thres=0.15,
         | 
| 211 | 
            +
                    ROOT=args.root_dir,
         | 
| 212 | 
            +
                    dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 213 | 
            +
                    resolution=(518, 392),
         | 
| 214 | 
            +
                    aug_crop=16,
         | 
| 215 | 
            +
                    transform="colorjitter+grayscale+gaublur",
         | 
| 216 | 
            +
                    data_norm_type="dinov2",
         | 
| 217 | 
            +
                )
         | 
| 218 | 
            +
                # dataset = MPSDWAI(
         | 
| 219 | 
            +
                #     num_views=args.num_of_views,
         | 
| 220 | 
            +
                #     split="val",
         | 
| 221 | 
            +
                #     covisibility_thres=0.15,
         | 
| 222 | 
            +
                #     ROOT=args.root_dir,
         | 
| 223 | 
            +
                #     dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 224 | 
            +
                #     resolution=(518, 392),
         | 
| 225 | 
            +
                #     seed=777,
         | 
| 226 | 
            +
                #     transform="imgnorm",
         | 
| 227 | 
            +
                #     data_norm_type="dinov2",
         | 
| 228 | 
            +
                # )
         | 
| 229 | 
            +
                print(dataset.get_stats())
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                if args.viz:
         | 
| 232 | 
            +
                    rr.script_setup(args, "MPSD_Dataloader")
         | 
| 233 | 
            +
                    rr.set_time("stable_time", sequence=0)
         | 
| 234 | 
            +
                    rr.log("world", rr.ViewCoordinates.RDF, static=True)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                for num, idx in enumerate(tqdm(sampled_indices)):
         | 
| 239 | 
            +
                    views = dataset[idx]
         | 
| 240 | 
            +
                    assert len(views) == args.num_of_views
         | 
| 241 | 
            +
                    sample_name = f"{idx}"
         | 
| 242 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 243 | 
            +
                        sample_name += f" {view_name(views[view_idx])}"
         | 
| 244 | 
            +
                    print(sample_name)
         | 
| 245 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 246 | 
            +
                        image = rgb(
         | 
| 247 | 
            +
                            views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
         | 
| 248 | 
            +
                        )
         | 
| 249 | 
            +
                        depthmap = views[view_idx]["depthmap"]
         | 
| 250 | 
            +
                        pose = views[view_idx]["camera_pose"]
         | 
| 251 | 
            +
                        intrinsics = views[view_idx]["camera_intrinsics"]
         | 
| 252 | 
            +
                        pts3d = views[view_idx]["pts3d"]
         | 
| 253 | 
            +
                        valid_mask = views[view_idx]["valid_mask"]
         | 
| 254 | 
            +
                        if "non_ambiguous_mask" in views[view_idx]:
         | 
| 255 | 
            +
                            non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
         | 
| 256 | 
            +
                        else:
         | 
| 257 | 
            +
                            non_ambiguous_mask = None
         | 
| 258 | 
            +
                        if "prior_depth_along_ray" in views[view_idx]:
         | 
| 259 | 
            +
                            prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
         | 
| 260 | 
            +
                        else:
         | 
| 261 | 
            +
                            prior_depth_along_ray = None
         | 
| 262 | 
            +
                        if args.viz:
         | 
| 263 | 
            +
                            rr.set_time("stable_time", sequence=num)
         | 
| 264 | 
            +
                            base_name = f"world/view_{view_idx}"
         | 
| 265 | 
            +
                            pts_name = f"world/view_{view_idx}_pointcloud"
         | 
| 266 | 
            +
                            # Log camera info and loaded data
         | 
| 267 | 
            +
                            height, width = image.shape[0], image.shape[1]
         | 
| 268 | 
            +
                            rr.log(
         | 
| 269 | 
            +
                                base_name,
         | 
| 270 | 
            +
                                rr.Transform3D(
         | 
| 271 | 
            +
                                    translation=pose[:3, 3],
         | 
| 272 | 
            +
                                    mat3x3=pose[:3, :3],
         | 
| 273 | 
            +
                                ),
         | 
| 274 | 
            +
                            )
         | 
| 275 | 
            +
                            rr.log(
         | 
| 276 | 
            +
                                f"{base_name}/pinhole",
         | 
| 277 | 
            +
                                rr.Pinhole(
         | 
| 278 | 
            +
                                    image_from_camera=intrinsics,
         | 
| 279 | 
            +
                                    height=height,
         | 
| 280 | 
            +
                                    width=width,
         | 
| 281 | 
            +
                                    camera_xyz=rr.ViewCoordinates.RDF,
         | 
| 282 | 
            +
                                ),
         | 
| 283 | 
            +
                            )
         | 
| 284 | 
            +
                            rr.log(
         | 
| 285 | 
            +
                                f"{base_name}/pinhole/rgb",
         | 
| 286 | 
            +
                                rr.Image(image),
         | 
| 287 | 
            +
                            )
         | 
| 288 | 
            +
                            rr.log(
         | 
| 289 | 
            +
                                f"{base_name}/pinhole/depth",
         | 
| 290 | 
            +
                                rr.DepthImage(depthmap),
         | 
| 291 | 
            +
                            )
         | 
| 292 | 
            +
                            if prior_depth_along_ray is not None:
         | 
| 293 | 
            +
                                rr.log(
         | 
| 294 | 
            +
                                    f"prior_depth_along_ray_{view_idx}",
         | 
| 295 | 
            +
                                    rr.DepthImage(prior_depth_along_ray),
         | 
| 296 | 
            +
                                )
         | 
| 297 | 
            +
                            if non_ambiguous_mask is not None:
         | 
| 298 | 
            +
                                rr.log(
         | 
| 299 | 
            +
                                    f"{base_name}/pinhole/non_ambiguous_mask",
         | 
| 300 | 
            +
                                    rr.SegmentationImage(non_ambiguous_mask.astype(int)),
         | 
| 301 | 
            +
                                )
         | 
| 302 | 
            +
                            # Log points in 3D
         | 
| 303 | 
            +
                            filtered_pts = pts3d[valid_mask]
         | 
| 304 | 
            +
                            filtered_pts_col = image[valid_mask]
         | 
| 305 | 
            +
                            rr.log(
         | 
| 306 | 
            +
                                pts_name,
         | 
| 307 | 
            +
                                rr.Points3D(
         | 
| 308 | 
            +
                                    positions=filtered_pts.reshape(-1, 3),
         | 
| 309 | 
            +
                                    colors=filtered_pts_col.reshape(-1, 3),
         | 
| 310 | 
            +
                                ),
         | 
| 311 | 
            +
                            )
         | 
    	
        mapanything/datasets/wai/mvs_synth.py
    ADDED
    
    | @@ -0,0 +1,308 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            MVS Synth Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 15 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            class MVSSynthWAI(BaseDataset):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                MVS Synth dataset containing large diversity of synthetic in-the-wild scenes.
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def __init__(
         | 
| 24 | 
            +
                    self,
         | 
| 25 | 
            +
                    *args,
         | 
| 26 | 
            +
                    ROOT,
         | 
| 27 | 
            +
                    dataset_metadata_dir,
         | 
| 28 | 
            +
                    split,
         | 
| 29 | 
            +
                    overfit_num_sets=None,
         | 
| 30 | 
            +
                    sample_specific_scene: bool = False,
         | 
| 31 | 
            +
                    specific_scene_name: str = None,
         | 
| 32 | 
            +
                    **kwargs,
         | 
| 33 | 
            +
                ):
         | 
| 34 | 
            +
                    """
         | 
| 35 | 
            +
                    Initialize the dataset attributes.
         | 
| 36 | 
            +
                    Args:
         | 
| 37 | 
            +
                        ROOT: Root directory of the dataset.
         | 
| 38 | 
            +
                        dataset_metadata_dir: Path to the dataset metadata directory.
         | 
| 39 | 
            +
                        split: Dataset split (train, val, test).
         | 
| 40 | 
            +
                        overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
         | 
| 41 | 
            +
                        sample_specific_scene: Whether to sample a specific scene from the dataset.
         | 
| 42 | 
            +
                        specific_scene_name: Name of the specific scene to sample.
         | 
| 43 | 
            +
                    """
         | 
| 44 | 
            +
                    # Initialize the dataset attributes
         | 
| 45 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 46 | 
            +
                    self.ROOT = ROOT
         | 
| 47 | 
            +
                    self.dataset_metadata_dir = dataset_metadata_dir
         | 
| 48 | 
            +
                    self.split = split
         | 
| 49 | 
            +
                    self.overfit_num_sets = overfit_num_sets
         | 
| 50 | 
            +
                    self.sample_specific_scene = sample_specific_scene
         | 
| 51 | 
            +
                    self.specific_scene_name = specific_scene_name
         | 
| 52 | 
            +
                    self._load_data()
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    # Define the dataset type flags
         | 
| 55 | 
            +
                    self.is_metric_scale = True
         | 
| 56 | 
            +
                    self.is_synthetic = True
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def _load_data(self):
         | 
| 59 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 60 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 61 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 62 | 
            +
                        self.dataset_metadata_dir,
         | 
| 63 | 
            +
                        self.split,
         | 
| 64 | 
            +
                        f"mvs_synth_scene_list_{self.split}.npy",
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # Get the list of all scenes
         | 
| 69 | 
            +
                    if not self.sample_specific_scene:
         | 
| 70 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 71 | 
            +
                    else:
         | 
| 72 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 73 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 76 | 
            +
                    # Get the scene name of the sampled index
         | 
| 77 | 
            +
                    scene_index = sampled_idx
         | 
| 78 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 81 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 82 | 
            +
                    scene_meta = load_data(
         | 
| 83 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 86 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 89 | 
            +
                    covisibility_version_key = "v0"
         | 
| 90 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 91 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    # Assumes only npy file in directory is covisibility map
         | 
| 94 | 
            +
                    covisibility_map_name = next(
         | 
| 95 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 98 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 103 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 104 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 108 | 
            +
                    views = []
         | 
| 109 | 
            +
                    for view_index in view_indices:
         | 
| 110 | 
            +
                        # Load the data corresponding to the view
         | 
| 111 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 112 | 
            +
                        view_data = load_frame(
         | 
| 113 | 
            +
                            scene_root,
         | 
| 114 | 
            +
                            view_file_name,
         | 
| 115 | 
            +
                            modalities=["image", "depth"],
         | 
| 116 | 
            +
                            scene_meta=scene_meta,
         | 
| 117 | 
            +
                        )
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        # Convert necessary data to numpy
         | 
| 120 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 121 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 122 | 
            +
                        depthmap = view_data["depth"].numpy().astype(np.float32)
         | 
| 123 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 124 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 127 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                        # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
         | 
| 130 | 
            +
                        non_ambiguous_mask = (depthmap > 0).astype(int)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                        # Mask out the outlier depth (horizon depth)
         | 
| 133 | 
            +
                        percentile_depth = np.percentile(depthmap, 95)
         | 
| 134 | 
            +
                        depthmap[depthmap > percentile_depth] = 0
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 137 | 
            +
                        additional_quantities_to_resize = [non_ambiguous_mask]
         | 
| 138 | 
            +
                        image, depthmap, intrinsics, additional_quantities_to_resize = (
         | 
| 139 | 
            +
                            self._crop_resize_if_necessary(
         | 
| 140 | 
            +
                                image=image,
         | 
| 141 | 
            +
                                resolution=resolution,
         | 
| 142 | 
            +
                                depthmap=depthmap,
         | 
| 143 | 
            +
                                intrinsics=intrinsics,
         | 
| 144 | 
            +
                                additional_quantities=additional_quantities_to_resize,
         | 
| 145 | 
            +
                            )
         | 
| 146 | 
            +
                        )
         | 
| 147 | 
            +
                        non_ambiguous_mask = additional_quantities_to_resize[0]
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 150 | 
            +
                        views.append(
         | 
| 151 | 
            +
                            dict(
         | 
| 152 | 
            +
                                img=image,
         | 
| 153 | 
            +
                                depthmap=depthmap,
         | 
| 154 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 155 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 156 | 
            +
                                non_ambiguous_mask=non_ambiguous_mask,
         | 
| 157 | 
            +
                                dataset="MVSSynth",
         | 
| 158 | 
            +
                                label=scene_name,
         | 
| 159 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 160 | 
            +
                            )
         | 
| 161 | 
            +
                        )
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    return views
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            def get_parser():
         | 
| 167 | 
            +
                import argparse
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 170 | 
            +
                parser.add_argument(
         | 
| 171 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/mvs_synth", type=str
         | 
| 172 | 
            +
                )
         | 
| 173 | 
            +
                parser.add_argument(
         | 
| 174 | 
            +
                    "-dmd",
         | 
| 175 | 
            +
                    "--dataset_metadata_dir",
         | 
| 176 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 177 | 
            +
                    type=str,
         | 
| 178 | 
            +
                )
         | 
| 179 | 
            +
                parser.add_argument(
         | 
| 180 | 
            +
                    "-nv",
         | 
| 181 | 
            +
                    "--num_of_views",
         | 
| 182 | 
            +
                    default=2,
         | 
| 183 | 
            +
                    type=int,
         | 
| 184 | 
            +
                )
         | 
| 185 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                return parser
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
if __name__ == "__main__":
    # Standalone smoke-test / visualization driver for the MVSSynth dataloader.
    # Iterates over sampled multi-view tuples and (optionally) streams them to rerun.
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = MVSSynthWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 294),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # Alternate configuration for inspecting the validation split (kept for convenience).
    # dataset = MVSSynthWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 294),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        # One-time rerun setup: app id, timeline, and world coordinate convention (RDF).
        rr.script_setup(args, "MVSSynth_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    # Visit every dataset index exactly once, in random order.
    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        # Build a printable identifier for the sampled multi-view tuple.
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            # De-normalize the image back to displayable RGB.
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            # Optional per-view quantities; not every dataset/view provides them.
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
         | 
    	
        mapanything/datasets/wai/paralleldomain4d.py
    ADDED
    
    | @@ -0,0 +1,309 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Parallel Domain 4D Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 15 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
class ParallelDomain4DWAI(BaseDataset):
    """
    Parallel Domain 4D dataset containing large diversity of synthetic AV scenes.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        # Parallel Domain 4D is rendered synthetic data with metric-scale geometry.
        self.is_metric_scale = True
        self.is_synthetic = True

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"paralleldomain4d_scene_list_{self.split}.npy",
        )
        # allow_pickle is required because the scene list is stored as an object array
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            # Restrict the dataset to the single requested scene.
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        """
        Sample and load a multi-view tuple from one scene.

        Args:
            sampled_idx: Index into self.scenes selecting which scene to draw from.
            num_views_to_sample: Number of views to sample from the scene.
            resolution: Target (width, height) passed to _crop_resize_if_necessary.
                        # NOTE(review): exact (W, H) vs (H, W) convention is defined by
                        # BaseDataset._crop_resize_if_necessary — confirm there.

        Returns:
            List of per-view dicts with keys: img, depthmap, camera_pose (cam2world),
            camera_intrinsics, non_ambiguous_mask, dataset, label, instance.
        """
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        # (raises StopIteration if the directory contains no .npy file)
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            # image comes in CHW layout; convert to HWC for downstream processing
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = image[:, :, :3]  # RGBA to RGB
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
            non_ambiguous_mask = (depthmap > 0).astype(int)

            # Mask out the outlier depth (horizon depth)
            # NOTE(review): the 95th percentile is computed over the full map,
            # including zero-depth (sky) pixels — confirm this is intended.
            percentile_depth = np.percentile(depthmap, 95)
            depthmap[depthmap > percentile_depth] = 0

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="ParallelDomain4D",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            def get_parser():
         | 
| 168 | 
            +
                import argparse
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 171 | 
            +
                parser.add_argument(
         | 
| 172 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/paralleldomain4d", type=str
         | 
| 173 | 
            +
                )
         | 
| 174 | 
            +
                parser.add_argument(
         | 
| 175 | 
            +
                    "-dmd",
         | 
| 176 | 
            +
                    "--dataset_metadata_dir",
         | 
| 177 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 178 | 
            +
                    type=str,
         | 
| 179 | 
            +
                )
         | 
| 180 | 
            +
                parser.add_argument(
         | 
| 181 | 
            +
                    "-nv",
         | 
| 182 | 
            +
                    "--num_of_views",
         | 
| 183 | 
            +
                    default=2,
         | 
| 184 | 
            +
                    type=int,
         | 
| 185 | 
            +
                )
         | 
| 186 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                return parser
         | 
| 189 | 
            +
             | 
| 190 | 
            +
             | 
| 191 | 
            +
            if __name__ == "__main__":
         | 
| 192 | 
            +
                import rerun as rr
         | 
| 193 | 
            +
                from tqdm import tqdm
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                from mapanything.datasets.base.base_dataset import view_name
         | 
| 196 | 
            +
                from mapanything.utils.image import rgb
         | 
| 197 | 
            +
                from mapanything.utils.viz import script_add_rerun_args
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                parser = get_parser()
         | 
| 200 | 
            +
                script_add_rerun_args(
         | 
| 201 | 
            +
                    parser
         | 
| 202 | 
            +
                )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
         | 
| 203 | 
            +
                args = parser.parse_args()
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                dataset = ParallelDomain4DWAI(
         | 
| 206 | 
            +
                    num_views=args.num_of_views,
         | 
| 207 | 
            +
                    split="train",
         | 
| 208 | 
            +
                    covisibility_thres=0.25,
         | 
| 209 | 
            +
                    ROOT=args.root_dir,
         | 
| 210 | 
            +
                    dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 211 | 
            +
                    resolution=(518, 392),
         | 
| 212 | 
            +
                    aug_crop=16,
         | 
| 213 | 
            +
                    transform="colorjitter+grayscale+gaublur",
         | 
| 214 | 
            +
                    data_norm_type="dinov2",
         | 
| 215 | 
            +
                )
         | 
| 216 | 
            +
                # dataset = ParallelDomain4DWAI(
         | 
| 217 | 
            +
                #     num_views=args.num_of_views,
         | 
| 218 | 
            +
                #     split="val",
         | 
| 219 | 
            +
                #     covisibility_thres=0.25,
         | 
| 220 | 
            +
                #     ROOT=args.root_dir,
         | 
| 221 | 
            +
                #     dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 222 | 
            +
                #     resolution=(518, 392),
         | 
| 223 | 
            +
                #     seed=777,
         | 
| 224 | 
            +
                #     transform="imgnorm",
         | 
| 225 | 
            +
                #     data_norm_type="dinov2",
         | 
| 226 | 
            +
                # )
         | 
| 227 | 
            +
                print(dataset.get_stats())
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                if args.viz:
         | 
| 230 | 
            +
                    rr.script_setup(args, "ParallelDomain4D_Dataloader")
         | 
| 231 | 
            +
                    rr.set_time("stable_time", sequence=0)
         | 
| 232 | 
            +
                    rr.log("world", rr.ViewCoordinates.RDF, static=True)
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                for num, idx in enumerate(tqdm(sampled_indices)):
         | 
| 237 | 
            +
                    views = dataset[idx]
         | 
| 238 | 
            +
                    assert len(views) == args.num_of_views
         | 
| 239 | 
            +
                    sample_name = f"{idx}"
         | 
| 240 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 241 | 
            +
                        sample_name += f" {view_name(views[view_idx])}"
         | 
| 242 | 
            +
                    print(sample_name)
         | 
| 243 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 244 | 
            +
                        image = rgb(
         | 
| 245 | 
            +
                            views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
         | 
| 246 | 
            +
                        )
         | 
| 247 | 
            +
                        depthmap = views[view_idx]["depthmap"]
         | 
| 248 | 
            +
                        pose = views[view_idx]["camera_pose"]
         | 
| 249 | 
            +
                        intrinsics = views[view_idx]["camera_intrinsics"]
         | 
| 250 | 
            +
                        pts3d = views[view_idx]["pts3d"]
         | 
| 251 | 
            +
                        valid_mask = views[view_idx]["valid_mask"]
         | 
| 252 | 
            +
                        if "non_ambiguous_mask" in views[view_idx]:
         | 
| 253 | 
            +
                            non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
         | 
| 254 | 
            +
                        else:
         | 
| 255 | 
            +
                            non_ambiguous_mask = None
         | 
| 256 | 
            +
                        if "prior_depth_along_ray" in views[view_idx]:
         | 
| 257 | 
            +
                            prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
         | 
| 258 | 
            +
                        else:
         | 
| 259 | 
            +
                            prior_depth_along_ray = None
         | 
| 260 | 
            +
                        if args.viz:
         | 
| 261 | 
            +
                            rr.set_time("stable_time", sequence=num)
         | 
| 262 | 
            +
                            base_name = f"world/view_{view_idx}"
         | 
| 263 | 
            +
                            pts_name = f"world/view_{view_idx}_pointcloud"
         | 
| 264 | 
            +
                            # Log camera info and loaded data
         | 
| 265 | 
            +
                            height, width = image.shape[0], image.shape[1]
         | 
| 266 | 
            +
                            rr.log(
         | 
| 267 | 
            +
                                base_name,
         | 
| 268 | 
            +
                                rr.Transform3D(
         | 
| 269 | 
            +
                                    translation=pose[:3, 3],
         | 
| 270 | 
            +
                                    mat3x3=pose[:3, :3],
         | 
| 271 | 
            +
                                ),
         | 
| 272 | 
            +
                            )
         | 
| 273 | 
            +
                            rr.log(
         | 
| 274 | 
            +
                                f"{base_name}/pinhole",
         | 
| 275 | 
            +
                                rr.Pinhole(
         | 
| 276 | 
            +
                                    image_from_camera=intrinsics,
         | 
| 277 | 
            +
                                    height=height,
         | 
| 278 | 
            +
                                    width=width,
         | 
| 279 | 
            +
                                    camera_xyz=rr.ViewCoordinates.RDF,
         | 
| 280 | 
            +
                                ),
         | 
| 281 | 
            +
                            )
         | 
| 282 | 
            +
                            rr.log(
         | 
| 283 | 
            +
                                f"{base_name}/pinhole/rgb",
         | 
| 284 | 
            +
                                rr.Image(image),
         | 
| 285 | 
            +
                            )
         | 
| 286 | 
            +
                            rr.log(
         | 
| 287 | 
            +
                                f"{base_name}/pinhole/depth",
         | 
| 288 | 
            +
                                rr.DepthImage(depthmap),
         | 
| 289 | 
            +
                            )
         | 
| 290 | 
            +
                            if prior_depth_along_ray is not None:
         | 
| 291 | 
            +
                                rr.log(
         | 
| 292 | 
            +
                                    f"prior_depth_along_ray_{view_idx}",
         | 
| 293 | 
            +
                                    rr.DepthImage(prior_depth_along_ray),
         | 
| 294 | 
            +
                                )
         | 
| 295 | 
            +
                            if non_ambiguous_mask is not None:
         | 
| 296 | 
            +
                                rr.log(
         | 
| 297 | 
            +
                                    f"{base_name}/pinhole/non_ambiguous_mask",
         | 
| 298 | 
            +
                                    rr.SegmentationImage(non_ambiguous_mask.astype(int)),
         | 
| 299 | 
            +
                                )
         | 
| 300 | 
            +
                            # Log points in 3D
         | 
| 301 | 
            +
                            filtered_pts = pts3d[valid_mask]
         | 
| 302 | 
            +
                            filtered_pts_col = image[valid_mask]
         | 
| 303 | 
            +
                            rr.log(
         | 
| 304 | 
            +
                                pts_name,
         | 
| 305 | 
            +
                                rr.Points3D(
         | 
| 306 | 
            +
                                    positions=filtered_pts.reshape(-1, 3),
         | 
| 307 | 
            +
                                    colors=filtered_pts_col.reshape(-1, 3),
         | 
| 308 | 
            +
                                ),
         | 
| 309 | 
            +
                            )
         | 
    	
        mapanything/datasets/wai/sailvos3d.py
    ADDED
    
    | @@ -0,0 +1,308 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            SAIL-VOS 3D Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 15 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            class SAILVOS3DWAI(BaseDataset):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                SAIL-VOS 3D dataset containing large diversity of synthetic in-the-wild cut scenes from GTA.
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def __init__(
         | 
| 24 | 
            +
                    self,
         | 
| 25 | 
            +
                    *args,
         | 
| 26 | 
            +
                    ROOT,
         | 
| 27 | 
            +
                    dataset_metadata_dir,
         | 
| 28 | 
            +
                    split,
         | 
| 29 | 
            +
                    overfit_num_sets=None,
         | 
| 30 | 
            +
                    sample_specific_scene: bool = False,
         | 
| 31 | 
            +
                    specific_scene_name: str = None,
         | 
| 32 | 
            +
                    **kwargs,
         | 
| 33 | 
            +
                ):
         | 
| 34 | 
            +
                    """
         | 
| 35 | 
            +
                    Initialize the dataset attributes.
         | 
| 36 | 
            +
                    Args:
         | 
| 37 | 
            +
                        ROOT: Root directory of the dataset.
         | 
| 38 | 
            +
                        dataset_metadata_dir: Path to the dataset metadata directory.
         | 
| 39 | 
            +
                        split: Dataset split (train, val, test).
         | 
| 40 | 
            +
                        overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
         | 
| 41 | 
            +
                        sample_specific_scene: Whether to sample a specific scene from the dataset.
         | 
| 42 | 
            +
                        specific_scene_name: Name of the specific scene to sample.
         | 
| 43 | 
            +
                    """
         | 
| 44 | 
            +
                    # Initialize the dataset attributes
         | 
| 45 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 46 | 
            +
                    self.ROOT = ROOT
         | 
| 47 | 
            +
                    self.dataset_metadata_dir = dataset_metadata_dir
         | 
| 48 | 
            +
                    self.split = split
         | 
| 49 | 
            +
                    self.overfit_num_sets = overfit_num_sets
         | 
| 50 | 
            +
                    self.sample_specific_scene = sample_specific_scene
         | 
| 51 | 
            +
                    self.specific_scene_name = specific_scene_name
         | 
| 52 | 
            +
                    self._load_data()
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    # Define the dataset type flags
         | 
| 55 | 
            +
                    self.is_metric_scale = True
         | 
| 56 | 
            +
                    self.is_synthetic = True
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def _load_data(self):
         | 
| 59 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 60 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 61 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 62 | 
            +
                        self.dataset_metadata_dir,
         | 
| 63 | 
            +
                        self.split,
         | 
| 64 | 
            +
                        f"sailvos3d_scene_list_{self.split}.npy",
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # Get the list of all scenes
         | 
| 69 | 
            +
                    if not self.sample_specific_scene:
         | 
| 70 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 71 | 
            +
                    else:
         | 
| 72 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 73 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 76 | 
            +
                    # Get the scene name of the sampled index
         | 
| 77 | 
            +
                    scene_index = sampled_idx
         | 
| 78 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 81 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 82 | 
            +
                    scene_meta = load_data(
         | 
| 83 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 86 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 89 | 
            +
                    covisibility_version_key = "v0"
         | 
| 90 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 91 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    # Assumes only npy file in directory is covisibility map
         | 
| 94 | 
            +
                    covisibility_map_name = next(
         | 
| 95 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 98 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 103 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 104 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 108 | 
            +
                    views = []
         | 
| 109 | 
            +
                    for view_index in view_indices:
         | 
| 110 | 
            +
                        # Load the data corresponding to the view
         | 
| 111 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 112 | 
            +
                        view_data = load_frame(
         | 
| 113 | 
            +
                            scene_root,
         | 
| 114 | 
            +
                            view_file_name,
         | 
| 115 | 
            +
                            modalities=["image", "depth"],
         | 
| 116 | 
            +
                            scene_meta=scene_meta,
         | 
| 117 | 
            +
                        )
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        # Convert necessary data to numpy
         | 
| 120 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 121 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 122 | 
            +
                        depthmap = view_data["depth"].numpy().astype(np.float32)
         | 
| 123 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 124 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 127 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                        # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
         | 
| 130 | 
            +
                        non_ambiguous_mask = (depthmap > 0).astype(int)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                        # Mask out the outlier depth (horizon depth)
         | 
| 133 | 
            +
                        percentile_depth = np.percentile(depthmap, 95)
         | 
| 134 | 
            +
                        depthmap[depthmap > percentile_depth] = 0
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 137 | 
            +
                        additional_quantities_to_resize = [non_ambiguous_mask]
         | 
| 138 | 
            +
                        image, depthmap, intrinsics, additional_quantities_to_resize = (
         | 
| 139 | 
            +
                            self._crop_resize_if_necessary(
         | 
| 140 | 
            +
                                image=image,
         | 
| 141 | 
            +
                                resolution=resolution,
         | 
| 142 | 
            +
                                depthmap=depthmap,
         | 
| 143 | 
            +
                                intrinsics=intrinsics,
         | 
| 144 | 
            +
                                additional_quantities=additional_quantities_to_resize,
         | 
| 145 | 
            +
                            )
         | 
| 146 | 
            +
                        )
         | 
| 147 | 
            +
                        non_ambiguous_mask = additional_quantities_to_resize[0]
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 150 | 
            +
                        views.append(
         | 
| 151 | 
            +
                            dict(
         | 
| 152 | 
            +
                                img=image,
         | 
| 153 | 
            +
                                depthmap=depthmap,
         | 
| 154 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 155 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 156 | 
            +
                                non_ambiguous_mask=non_ambiguous_mask,
         | 
| 157 | 
            +
                                dataset="SAILVOS3D",
         | 
| 158 | 
            +
                                label=scene_name,
         | 
| 159 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 160 | 
            +
                            )
         | 
| 161 | 
            +
                        )
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    return views
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            def get_parser():
         | 
| 167 | 
            +
                import argparse
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 170 | 
            +
                parser.add_argument(
         | 
| 171 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/sailvos3d", type=str
         | 
| 172 | 
            +
                )
         | 
| 173 | 
            +
                parser.add_argument(
         | 
| 174 | 
            +
                    "-dmd",
         | 
| 175 | 
            +
                    "--dataset_metadata_dir",
         | 
| 176 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 177 | 
            +
                    type=str,
         | 
| 178 | 
            +
                )
         | 
| 179 | 
            +
                parser.add_argument(
         | 
| 180 | 
            +
                    "-nv",
         | 
| 181 | 
            +
                    "--num_of_views",
         | 
| 182 | 
            +
                    default=2,
         | 
| 183 | 
            +
                    type=int,
         | 
| 184 | 
            +
                )
         | 
| 185 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                return parser
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
if __name__ == "__main__":
    # Demo / sanity-check script: iterate the dataset in random order, print the
    # sampled view names, and optionally log images, depth, cameras, and point
    # clouds to rerun when --viz is passed.
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    # Train-split configuration with photometric augmentation
    dataset = SAILVOS3DWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 336),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # Alternative val-split configuration (kept for convenience)
    # dataset = SAILVOS3DWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 336),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "SAILVOS3D_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    # Visit every dataset index exactly once, in random order
    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            # Un-normalize the image back to displayable RGB
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            # Optional per-view quantities; absent keys fall back to None
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D (only pixels marked valid)
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
    	
        mapanything/datasets/wai/scannetpp.py
    ADDED
    
    | @@ -0,0 +1,307 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            ScanNet++V2 Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 15 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            class ScanNetPPWAI(BaseDataset):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                ScanNet++V2 dataset containing large diversity of indoor scenes.
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def __init__(
         | 
| 24 | 
            +
                    self,
         | 
| 25 | 
            +
                    *args,
         | 
| 26 | 
            +
                    ROOT,
         | 
| 27 | 
            +
                    dataset_metadata_dir,
         | 
| 28 | 
            +
                    split,
         | 
| 29 | 
            +
                    overfit_num_sets=None,
         | 
| 30 | 
            +
                    sample_specific_scene: bool = False,
         | 
| 31 | 
            +
                    specific_scene_name: str = None,
         | 
| 32 | 
            +
                    **kwargs,
         | 
| 33 | 
            +
                ):
         | 
| 34 | 
            +
                    """
         | 
| 35 | 
            +
                    Initialize the dataset attributes.
         | 
| 36 | 
            +
                    Args:
         | 
| 37 | 
            +
                        ROOT: Root directory of the dataset.
         | 
| 38 | 
            +
                        dataset_metadata_dir: Path to the dataset metadata directory.
         | 
| 39 | 
            +
                        split: Dataset split (train, val, test).
         | 
| 40 | 
            +
                        overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
         | 
| 41 | 
            +
                        sample_specific_scene: Whether to sample a specific scene from the dataset.
         | 
| 42 | 
            +
                        specific_scene_name: Name of the specific scene to sample.
         | 
| 43 | 
            +
                    """
         | 
| 44 | 
            +
                    # Initialize the dataset attributes
         | 
| 45 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 46 | 
            +
                    self.ROOT = ROOT
         | 
| 47 | 
            +
                    self.dataset_metadata_dir = dataset_metadata_dir
         | 
| 48 | 
            +
                    self.split = split
         | 
| 49 | 
            +
                    self.overfit_num_sets = overfit_num_sets
         | 
| 50 | 
            +
                    self.sample_specific_scene = sample_specific_scene
         | 
| 51 | 
            +
                    self.specific_scene_name = specific_scene_name
         | 
| 52 | 
            +
                    self._load_data()
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    # Define the dataset type flags
         | 
| 55 | 
            +
                    self.is_metric_scale = True
         | 
| 56 | 
            +
                    self.is_synthetic = False
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def _load_data(self):
         | 
| 59 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 60 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 61 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 62 | 
            +
                        self.dataset_metadata_dir,
         | 
| 63 | 
            +
                        self.split,
         | 
| 64 | 
            +
                        f"scannetppv2_scene_list_{self.split}.npy",
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # Get the list of all scenes
         | 
| 69 | 
            +
                    if not self.sample_specific_scene:
         | 
| 70 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 71 | 
            +
                    else:
         | 
| 72 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 73 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 76 | 
            +
                    # Get the scene name of the sampled index
         | 
| 77 | 
            +
                    scene_index = sampled_idx
         | 
| 78 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 81 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 82 | 
            +
                    scene_meta = load_data(
         | 
| 83 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 86 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 89 | 
            +
                    covisibility_version_key = "v0"
         | 
| 90 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 91 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    # Assumes only npy file in directory is covisibility map
         | 
| 94 | 
            +
                    covisibility_map_name = next(
         | 
| 95 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 98 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 103 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 104 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 108 | 
            +
                    views = []
         | 
| 109 | 
            +
                    for view_index in view_indices:
         | 
| 110 | 
            +
                        # Load the data corresponding to the view
         | 
| 111 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 112 | 
            +
                        view_data = load_frame(
         | 
| 113 | 
            +
                            scene_root,
         | 
| 114 | 
            +
                            view_file_name,
         | 
| 115 | 
            +
                            modalities=["image", "rendered_depth"],
         | 
| 116 | 
            +
                            scene_meta=scene_meta,
         | 
| 117 | 
            +
                        )
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        # Convert necessary data to numpy
         | 
| 120 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 121 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 122 | 
            +
                        depthmap = view_data["rendered_depth"].numpy().astype(np.float32)
         | 
| 123 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 124 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 127 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 130 | 
            +
                        image, depthmap, intrinsics = self._crop_resize_if_necessary(
         | 
| 131 | 
            +
                            image=image,
         | 
| 132 | 
            +
                            resolution=resolution,
         | 
| 133 | 
            +
                            depthmap=depthmap,
         | 
| 134 | 
            +
                            intrinsics=intrinsics,
         | 
| 135 | 
            +
                            additional_quantities=None,
         | 
| 136 | 
            +
                        )
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 139 | 
            +
                        views.append(
         | 
| 140 | 
            +
                            dict(
         | 
| 141 | 
            +
                                img=image,
         | 
| 142 | 
            +
                                depthmap=depthmap,
         | 
| 143 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 144 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 145 | 
            +
                                dataset="ScanNetPP",
         | 
| 146 | 
            +
                                label=scene_name,
         | 
| 147 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 148 | 
            +
                            )
         | 
| 149 | 
            +
                        )
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    return views
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            def get_parser():
         | 
| 155 | 
            +
                import argparse
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 158 | 
            +
                parser.add_argument(
         | 
| 159 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/scannetppv2", type=str
         | 
| 160 | 
            +
                )
         | 
| 161 | 
            +
                parser.add_argument(
         | 
| 162 | 
            +
                    "-dmd",
         | 
| 163 | 
            +
                    "--dataset_metadata_dir",
         | 
| 164 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 165 | 
            +
                    type=str,
         | 
| 166 | 
            +
                )
         | 
| 167 | 
            +
                parser.add_argument(
         | 
| 168 | 
            +
                    "-nv",
         | 
| 169 | 
            +
                    "--num_of_views",
         | 
| 170 | 
            +
                    default=2,
         | 
| 171 | 
            +
                    type=int,
         | 
| 172 | 
            +
                )
         | 
| 173 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                return parser
         | 
| 176 | 
            +
             | 
| 177 | 
            +
             | 
| 178 | 
            +
if __name__ == "__main__":
    # Demo entry point: sample a few multi-view sets from the dataset and
    # optionally visualize cameras, images, depth, and point clouds with rerun.
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    # Combine the script's own CLI options with rerun's standard flags
    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    # Training-split configuration with photometric augmentation
    dataset = ScanNetPPWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 336),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # Alternative val/test configurations (fixed seed, no augmentation):
    # dataset = ScanNetPPWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 336),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    # dataset = ScanNetPPWAI(
    #     num_views=args.num_of_views,
    #     split="test",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 336),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        # One-time rerun setup; world uses Right-Down-Forward coordinates
        rr.script_setup(args, "ScanNetPP_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    # Inspect 10 randomly chosen samples
    sampled_indices = np.random.choice(len(dataset), size=10, replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        # Print which frames made up this sample
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            # Undo the data normalization to recover a displayable RGB image
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            # Optional quantities are only present for some configurations
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                # Each sample gets its own timestep on the rerun timeline
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    # Logged at top level (not under the pinhole) on purpose
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D, colored by the corresponding image pixels
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
         | 
    	
        mapanything/datasets/wai/spring.py
    ADDED
    
    | @@ -0,0 +1,316 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Spring Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import cv2
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 16 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class SpringWAI(BaseDataset):
         | 
| 20 | 
            +
                """
         | 
| 21 | 
            +
                Spring dataset containing high-quality large-scale in-the-wild scenes with unique animated objects.
         | 
| 22 | 
            +
                """
         | 
| 23 | 
            +
             | 
| 24 | 
            +
    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.

        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        # Populate self.scenes / self.num_of_scenes from the precomputed metadata
        self._load_data()

        # Define the dataset type flags
        # Spring is rendered data, hence synthetic; depth is metric scale
        self.is_metric_scale = True
        self.is_synthetic = True
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def _load_data(self):
         | 
| 60 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 61 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 62 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 63 | 
            +
                        self.dataset_metadata_dir,
         | 
| 64 | 
            +
                        self.split,
         | 
| 65 | 
            +
                        f"spring_scene_list_{self.split}.npy",
         | 
| 66 | 
            +
                    )
         | 
| 67 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # Get the list of all scenes
         | 
| 70 | 
            +
                    if not self.sample_specific_scene:
         | 
| 71 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 72 | 
            +
                    else:
         | 
| 73 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 74 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 77 | 
            +
                    # Get the scene name of the sampled index
         | 
| 78 | 
            +
                    scene_index = sampled_idx
         | 
| 79 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 82 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 83 | 
            +
                    scene_meta = load_data(
         | 
| 84 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 85 | 
            +
                    )
         | 
| 86 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 87 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 90 | 
            +
                    covisibility_version_key = "v0"
         | 
| 91 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 92 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 93 | 
            +
                    )
         | 
| 94 | 
            +
                    covisibility_map_name = next(
         | 
| 95 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 96 | 
            +
                    )  # Assumes only npy file in directory is covisibility map
         | 
| 97 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 98 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 103 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 104 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 108 | 
            +
                    views = []
         | 
| 109 | 
            +
                    for view_index in view_indices:
         | 
| 110 | 
            +
                        # Load the data corresponding to the view
         | 
| 111 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 112 | 
            +
                        view_data = load_frame(
         | 
| 113 | 
            +
                            scene_root,
         | 
| 114 | 
            +
                            view_file_name,
         | 
| 115 | 
            +
                            modalities=["image", "depth", "skymask", "pred_mask/moge2"],
         | 
| 116 | 
            +
                            scene_meta=scene_meta,
         | 
| 117 | 
            +
                        )
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        # Convert necessary data to numpy
         | 
| 120 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 121 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 122 | 
            +
                        depthmap = view_data["depth"].numpy().astype(np.float32)
         | 
| 123 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 124 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                        # Get the sky mask and mask out GT depth
         | 
| 127 | 
            +
                        sky_mask = view_data["skymask"].numpy().astype(int)
         | 
| 128 | 
            +
                        depthmap = np.where(sky_mask, 0, depthmap)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 131 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                        # Get the non_ambiguous_mask and ensure it matches image resolution
         | 
| 134 | 
            +
                        non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
         | 
| 135 | 
            +
                        non_ambiguous_mask = cv2.resize(
         | 
| 136 | 
            +
                            non_ambiguous_mask,
         | 
| 137 | 
            +
                            (image.shape[1], image.shape[0]),
         | 
| 138 | 
            +
                            interpolation=cv2.INTER_NEAREST,
         | 
| 139 | 
            +
                        )
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                        # Mask out the GT depth using the non_ambiguous_mask
         | 
| 142 | 
            +
                        depthmap = np.where(non_ambiguous_mask, depthmap, 0)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 145 | 
            +
                        additional_quantities_to_resize = [non_ambiguous_mask]
         | 
| 146 | 
            +
                        image, depthmap, intrinsics, additional_quantities_to_resize = (
         | 
| 147 | 
            +
                            self._crop_resize_if_necessary(
         | 
| 148 | 
            +
                                image=image,
         | 
| 149 | 
            +
                                resolution=resolution,
         | 
| 150 | 
            +
                                depthmap=depthmap,
         | 
| 151 | 
            +
                                intrinsics=intrinsics,
         | 
| 152 | 
            +
                                additional_quantities=additional_quantities_to_resize,
         | 
| 153 | 
            +
                            )
         | 
| 154 | 
            +
                        )
         | 
| 155 | 
            +
                        non_ambiguous_mask = additional_quantities_to_resize[0]
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 158 | 
            +
                        views.append(
         | 
| 159 | 
            +
                            dict(
         | 
| 160 | 
            +
                                img=image,
         | 
| 161 | 
            +
                                depthmap=depthmap,
         | 
| 162 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 163 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 164 | 
            +
                                non_ambiguous_mask=non_ambiguous_mask,
         | 
| 165 | 
            +
                                dataset="Spring",
         | 
| 166 | 
            +
                                label=scene_name,
         | 
| 167 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 168 | 
            +
                            )
         | 
| 169 | 
            +
                        )
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    return views
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
| 174 | 
            +
            def get_parser():
         | 
| 175 | 
            +
                import argparse
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 178 | 
            +
                parser.add_argument(
         | 
| 179 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/spring", type=str
         | 
| 180 | 
            +
                )
         | 
| 181 | 
            +
                parser.add_argument(
         | 
| 182 | 
            +
                    "-dmd",
         | 
| 183 | 
            +
                    "--dataset_metadata_dir",
         | 
| 184 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 185 | 
            +
                    type=str,
         | 
| 186 | 
            +
                )
         | 
| 187 | 
            +
                parser.add_argument(
         | 
| 188 | 
            +
                    "-nv",
         | 
| 189 | 
            +
                    "--num_of_views",
         | 
| 190 | 
            +
                    default=2,
         | 
| 191 | 
            +
                    type=int,
         | 
| 192 | 
            +
                )
         | 
| 193 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                return parser
         | 
| 196 | 
            +
             | 
| 197 | 
            +
             | 
| 198 | 
            +
            if __name__ == "__main__":
         | 
| 199 | 
            +
                import rerun as rr
         | 
| 200 | 
            +
                from tqdm import tqdm
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                from mapanything.datasets.base.base_dataset import view_name
         | 
| 203 | 
            +
                from mapanything.utils.image import rgb
         | 
| 204 | 
            +
                from mapanything.utils.viz import script_add_rerun_args
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                parser = get_parser()
         | 
| 207 | 
            +
                script_add_rerun_args(
         | 
| 208 | 
            +
                    parser
         | 
| 209 | 
            +
                )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
         | 
| 210 | 
            +
                args = parser.parse_args()
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                dataset = SpringWAI(
         | 
| 213 | 
            +
                    num_views=args.num_of_views,
         | 
| 214 | 
            +
                    split="train",
         | 
| 215 | 
            +
                    covisibility_thres=0.25,
         | 
| 216 | 
            +
                    ROOT=args.root_dir,
         | 
| 217 | 
            +
                    dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 218 | 
            +
                    resolution=(518, 294),
         | 
| 219 | 
            +
                    aug_crop=16,
         | 
| 220 | 
            +
                    transform="colorjitter+grayscale+gaublur",
         | 
| 221 | 
            +
                    data_norm_type="dinov2",
         | 
| 222 | 
            +
                )
         | 
| 223 | 
            +
                # dataset = SpringWAI(
         | 
| 224 | 
            +
                #     num_views=args.num_of_views,
         | 
| 225 | 
            +
                #     split="val",
         | 
| 226 | 
            +
                #     covisibility_thres=0.25,
         | 
| 227 | 
            +
                #     ROOT=args.root_dir,
         | 
| 228 | 
            +
                #     dataset_metadata_dir=args.dataset_metadata_dir,
         | 
| 229 | 
            +
                #     resolution=(518, 294),
         | 
| 230 | 
            +
                #     seed=777,
         | 
| 231 | 
            +
                #     transform="imgnorm",
         | 
| 232 | 
            +
                #     data_norm_type="dinov2",
         | 
| 233 | 
            +
                # )
         | 
| 234 | 
            +
                print(dataset.get_stats())
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                if args.viz:
         | 
| 237 | 
            +
                    rr.script_setup(args, "Spring_Dataloader")
         | 
| 238 | 
            +
                    rr.set_time("stable_time", sequence=0)
         | 
| 239 | 
            +
                    rr.log("world", rr.ViewCoordinates.RDF, static=True)
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                for num, idx in enumerate(tqdm(sampled_indices)):
         | 
| 244 | 
            +
                    views = dataset[idx]
         | 
| 245 | 
            +
                    assert len(views) == args.num_of_views
         | 
| 246 | 
            +
                    sample_name = f"{idx}"
         | 
| 247 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 248 | 
            +
                        sample_name += f" {view_name(views[view_idx])}"
         | 
| 249 | 
            +
                    print(sample_name)
         | 
| 250 | 
            +
                    for view_idx in range(args.num_of_views):
         | 
| 251 | 
            +
                        image = rgb(
         | 
| 252 | 
            +
                            views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
         | 
| 253 | 
            +
                        )
         | 
| 254 | 
            +
                        depthmap = views[view_idx]["depthmap"]
         | 
| 255 | 
            +
                        pose = views[view_idx]["camera_pose"]
         | 
| 256 | 
            +
                        intrinsics = views[view_idx]["camera_intrinsics"]
         | 
| 257 | 
            +
                        pts3d = views[view_idx]["pts3d"]
         | 
| 258 | 
            +
                        valid_mask = views[view_idx]["valid_mask"]
         | 
| 259 | 
            +
                        if "non_ambiguous_mask" in views[view_idx]:
         | 
| 260 | 
            +
                            non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
         | 
| 261 | 
            +
                        else:
         | 
| 262 | 
            +
                            non_ambiguous_mask = None
         | 
| 263 | 
            +
                        if "prior_depth_along_ray" in views[view_idx]:
         | 
| 264 | 
            +
                            prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
         | 
| 265 | 
            +
                        else:
         | 
| 266 | 
            +
                            prior_depth_along_ray = None
         | 
| 267 | 
            +
                        if args.viz:
         | 
| 268 | 
            +
                            rr.set_time("stable_time", sequence=num)
         | 
| 269 | 
            +
                            base_name = f"world/view_{view_idx}"
         | 
| 270 | 
            +
                            pts_name = f"world/view_{view_idx}_pointcloud"
         | 
| 271 | 
            +
                            # Log camera info and loaded data
         | 
| 272 | 
            +
                            height, width = image.shape[0], image.shape[1]
         | 
| 273 | 
            +
                            rr.log(
         | 
| 274 | 
            +
                                base_name,
         | 
| 275 | 
            +
                                rr.Transform3D(
         | 
| 276 | 
            +
                                    translation=pose[:3, 3],
         | 
| 277 | 
            +
                                    mat3x3=pose[:3, :3],
         | 
| 278 | 
            +
                                ),
         | 
| 279 | 
            +
                            )
         | 
| 280 | 
            +
                            rr.log(
         | 
| 281 | 
            +
                                f"{base_name}/pinhole",
         | 
| 282 | 
            +
                                rr.Pinhole(
         | 
| 283 | 
            +
                                    image_from_camera=intrinsics,
         | 
| 284 | 
            +
                                    height=height,
         | 
| 285 | 
            +
                                    width=width,
         | 
| 286 | 
            +
                                    camera_xyz=rr.ViewCoordinates.RDF,
         | 
| 287 | 
            +
                                ),
         | 
| 288 | 
            +
                            )
         | 
| 289 | 
            +
                            rr.log(
         | 
| 290 | 
            +
                                f"{base_name}/pinhole/rgb",
         | 
| 291 | 
            +
                                rr.Image(image),
         | 
| 292 | 
            +
                            )
         | 
| 293 | 
            +
                            rr.log(
         | 
| 294 | 
            +
                                f"{base_name}/pinhole/depth",
         | 
| 295 | 
            +
                                rr.DepthImage(depthmap),
         | 
| 296 | 
            +
                            )
         | 
| 297 | 
            +
                            if prior_depth_along_ray is not None:
         | 
| 298 | 
            +
                                rr.log(
         | 
| 299 | 
            +
                                    f"prior_depth_along_ray_{view_idx}",
         | 
| 300 | 
            +
                                    rr.DepthImage(prior_depth_along_ray),
         | 
| 301 | 
            +
                                )
         | 
| 302 | 
            +
                            if non_ambiguous_mask is not None:
         | 
| 303 | 
            +
                                rr.log(
         | 
| 304 | 
            +
                                    f"{base_name}/pinhole/non_ambiguous_mask",
         | 
| 305 | 
            +
                                    rr.SegmentationImage(non_ambiguous_mask.astype(int)),
         | 
| 306 | 
            +
                                )
         | 
| 307 | 
            +
                            # Log points in 3D
         | 
| 308 | 
            +
                            filtered_pts = pts3d[valid_mask]
         | 
| 309 | 
            +
                            filtered_pts_col = image[valid_mask]
         | 
| 310 | 
            +
                            rr.log(
         | 
| 311 | 
            +
                                pts_name,
         | 
| 312 | 
            +
                                rr.Points3D(
         | 
| 313 | 
            +
                                    positions=filtered_pts.reshape(-1, 3),
         | 
| 314 | 
            +
                                    colors=filtered_pts_col.reshape(-1, 3),
         | 
| 315 | 
            +
                                ),
         | 
| 316 | 
            +
                            )
         | 
    	
        mapanything/datasets/wai/tav2_wb.py
    ADDED
    
    | @@ -0,0 +1,328 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            TartanAirV2-WB Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import cv2
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 16 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class TartanAirV2WBWAI(BaseDataset):
         | 
| 20 | 
            +
                """
         | 
| 21 | 
            +
                TartanAirV2-WB dataset containing vastly-sized in-the-wild synthetic scenes.
         | 
| 22 | 
            +
                """
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def __init__(
         | 
| 25 | 
            +
                    self,
         | 
| 26 | 
            +
                    *args,
         | 
| 27 | 
            +
                    ROOT,
         | 
| 28 | 
            +
                    dataset_metadata_dir,
         | 
| 29 | 
            +
                    split,
         | 
| 30 | 
            +
                    overfit_num_sets=None,
         | 
| 31 | 
            +
                    sample_specific_scene: bool = False,
         | 
| 32 | 
            +
                    specific_scene_name: str = None,
         | 
| 33 | 
            +
                    **kwargs,
         | 
| 34 | 
            +
                ):
         | 
| 35 | 
            +
                    """
         | 
| 36 | 
            +
                    Initialize the dataset attributes.
         | 
| 37 | 
            +
                    Args:
         | 
| 38 | 
            +
                        ROOT: Root directory of the dataset.
         | 
| 39 | 
            +
                        dataset_metadata_dir: Path to the dataset metadata directory.
         | 
| 40 | 
            +
                        split: Dataset split (train, val, test).
         | 
| 41 | 
            +
                        overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
         | 
| 42 | 
            +
                        sample_specific_scene: Whether to sample a specific scene from the dataset.
         | 
| 43 | 
            +
                        specific_scene_name: Name of the specific scene to sample.
         | 
| 44 | 
            +
                    """
         | 
| 45 | 
            +
                    # Initialize the dataset attributes
         | 
| 46 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 47 | 
            +
                    self.ROOT = ROOT
         | 
| 48 | 
            +
                    self.dataset_metadata_dir = dataset_metadata_dir
         | 
| 49 | 
            +
                    self.split = split
         | 
| 50 | 
            +
                    self.overfit_num_sets = overfit_num_sets
         | 
| 51 | 
            +
                    self.sample_specific_scene = sample_specific_scene
         | 
| 52 | 
            +
                    self.specific_scene_name = specific_scene_name
         | 
| 53 | 
            +
                    self._load_data()
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    # Define the dataset type flags
         | 
| 56 | 
            +
                    self.is_metric_scale = True
         | 
| 57 | 
            +
                    self.is_synthetic = True
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def _load_data(self):
         | 
| 60 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 61 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 62 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 63 | 
            +
                        self.dataset_metadata_dir,
         | 
| 64 | 
            +
                        self.split,
         | 
| 65 | 
            +
                        f"tav2_wb_scene_list_{self.split}.npy",
         | 
| 66 | 
            +
                    )
         | 
| 67 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # Get the list of all scenes
         | 
| 70 | 
            +
                    if not self.sample_specific_scene:
         | 
| 71 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 72 | 
            +
                    else:
         | 
| 73 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 74 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 77 | 
            +
                    # Get the scene name of the sampled index
         | 
| 78 | 
            +
                    scene_index = sampled_idx
         | 
| 79 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 82 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 83 | 
            +
                    scene_meta = load_data(
         | 
| 84 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 85 | 
            +
                    )
         | 
| 86 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 87 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 90 | 
            +
                    covisibility_version_key = "v0"
         | 
| 91 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 92 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 93 | 
            +
                    )
         | 
| 94 | 
            +
                    # Assumes only npy file in directory is covisibility map
         | 
| 95 | 
            +
                    covisibility_map_name = next(
         | 
| 96 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 99 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 100 | 
            +
                    )
         | 
| 101 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 104 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 105 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 106 | 
            +
                    )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 109 | 
            +
                    views = []
         | 
| 110 | 
            +
                    for view_index in view_indices:
         | 
| 111 | 
            +
                        # Load the data corresponding to the view
         | 
| 112 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 113 | 
            +
                        view_data = load_frame(
         | 
| 114 | 
            +
                            scene_root,
         | 
| 115 | 
            +
                            view_file_name,
         | 
| 116 | 
            +
                            modalities=["image", "depth", "pred_mask/moge2"],
         | 
| 117 | 
            +
                            scene_meta=scene_meta,
         | 
| 118 | 
            +
                        )
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                        # Convert necessary data to numpy
         | 
| 121 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 122 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 123 | 
            +
                        depthmap = view_data["depth"].numpy().astype(np.float32)
         | 
| 124 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 125 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 128 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                        # Mask out the outlier depth caused due to transparent windows in TartanAirV2
         | 
| 131 | 
            +
                        percentile_depth = np.percentile(depthmap, 95)
         | 
| 132 | 
            +
                        depthmap[depthmap > percentile_depth] = 0
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                        # Get the non_ambiguous_mask and ensure it matches image resolution
         | 
| 135 | 
            +
                        non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
         | 
| 136 | 
            +
                        non_ambiguous_mask = cv2.resize(
         | 
| 137 | 
            +
                            non_ambiguous_mask,
         | 
| 138 | 
            +
                            (image.shape[1], image.shape[0]),
         | 
| 139 | 
            +
                            interpolation=cv2.INTER_NEAREST,
         | 
| 140 | 
            +
                        )
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                        # Mask out the GT depth using the non_ambiguous_mask
         | 
| 143 | 
            +
                        depthmap = np.where(non_ambiguous_mask, depthmap, 0)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 146 | 
            +
                        additional_quantities_to_resize = [non_ambiguous_mask]
         | 
| 147 | 
            +
                        image, depthmap, intrinsics, additional_quantities_to_resize = (
         | 
| 148 | 
            +
                            self._crop_resize_if_necessary(
         | 
| 149 | 
            +
                                image=image,
         | 
| 150 | 
            +
                                resolution=resolution,
         | 
| 151 | 
            +
                                depthmap=depthmap,
         | 
| 152 | 
            +
                                intrinsics=intrinsics,
         | 
| 153 | 
            +
                                additional_quantities=additional_quantities_to_resize,
         | 
| 154 | 
            +
                            )
         | 
| 155 | 
            +
                        )
         | 
| 156 | 
            +
                        non_ambiguous_mask = additional_quantities_to_resize[0]
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 159 | 
            +
                        views.append(
         | 
| 160 | 
            +
                            dict(
         | 
| 161 | 
            +
                                img=image,
         | 
| 162 | 
            +
                                depthmap=depthmap,
         | 
| 163 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 164 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 165 | 
            +
                                non_ambiguous_mask=non_ambiguous_mask,
         | 
| 166 | 
            +
                                dataset="TartanAirV2WB",
         | 
| 167 | 
            +
                                label=scene_name,
         | 
| 168 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 169 | 
            +
                            )
         | 
| 170 | 
            +
                        )
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    return views
         | 
| 173 | 
            +
             | 
| 174 | 
            +
             | 
| 175 | 
            +
            def get_parser():
         | 
| 176 | 
            +
                import argparse
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 179 | 
            +
                parser.add_argument(
         | 
| 180 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/tav2_wb", type=str
         | 
| 181 | 
            +
                )
         | 
| 182 | 
            +
                parser.add_argument(
         | 
| 183 | 
            +
                    "-dmd",
         | 
| 184 | 
            +
                    "--dataset_metadata_dir",
         | 
| 185 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 186 | 
            +
                    type=str,
         | 
| 187 | 
            +
                )
         | 
| 188 | 
            +
                parser.add_argument(
         | 
| 189 | 
            +
                    "-nv",
         | 
| 190 | 
            +
                    "--num_of_views",
         | 
| 191 | 
            +
                    default=2,
         | 
| 192 | 
            +
                    type=int,
         | 
| 193 | 
            +
                )
         | 
| 194 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                return parser
         | 
| 197 | 
            +
             | 
| 198 | 
            +
             | 
| 199 | 
            +
if __name__ == "__main__":
    # Standalone entry point: iterates the dataset once in random order and
    # optionally streams each sample (images, depth, poses, points) to rerun.
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    # Build the CLI and extend it with rerun's standard connection flags
    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    # Train-split configuration with augmentation enabled
    dataset = TartanAirV2WBWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 518),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # Alternative val/test configurations kept for quick switching while debugging:
    # dataset = TartanAirV2WBWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 518),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    # dataset = TartanAirV2WBWAI(
    #     num_views=args.num_of_views,
    #     split="test",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 518),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        # One-time rerun setup; the world frame uses the RDF camera convention
        rr.script_setup(args, "TartanAirV2WB_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    # Visit every sample exactly once, in random order
    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            # Un-normalize the image back to displayable RGB
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            # Optional quantities: only present for some dataset configurations
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D (only pixels flagged valid)
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
    	
        mapanything/datasets/wai/unrealstereo4k.py
    ADDED
    
    | @@ -0,0 +1,309 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            UnrealStereo4K Dataset using WAI format data.
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from mapanything.datasets.base.base_dataset import BaseDataset
         | 
| 15 | 
            +
            from mapanything.utils.wai.core import load_data, load_frame
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            class UnrealStereo4KWAI(BaseDataset):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                UnrealStereo4K dataset containing synthetic in-the-wild scenes.
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def __init__(
         | 
| 24 | 
            +
                    self,
         | 
| 25 | 
            +
                    *args,
         | 
| 26 | 
            +
                    ROOT,
         | 
| 27 | 
            +
                    dataset_metadata_dir,
         | 
| 28 | 
            +
                    split,
         | 
| 29 | 
            +
                    overfit_num_sets=None,
         | 
| 30 | 
            +
                    sample_specific_scene: bool = False,
         | 
| 31 | 
            +
                    specific_scene_name: str = None,
         | 
| 32 | 
            +
                    **kwargs,
         | 
| 33 | 
            +
                ):
         | 
| 34 | 
            +
                    """
         | 
| 35 | 
            +
                    Initialize the dataset attributes.
         | 
| 36 | 
            +
                    Args:
         | 
| 37 | 
            +
                        ROOT: Root directory of the dataset.
         | 
| 38 | 
            +
                        dataset_metadata_dir: Path to the dataset metadata directory.
         | 
| 39 | 
            +
                        split: Dataset split (train, val, test).
         | 
| 40 | 
            +
                        overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
         | 
| 41 | 
            +
                        sample_specific_scene: Whether to sample a specific scene from the dataset.
         | 
| 42 | 
            +
                        specific_scene_name: Name of the specific scene to sample.
         | 
| 43 | 
            +
                    """
         | 
| 44 | 
            +
                    # Initialize the dataset attributes
         | 
| 45 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 46 | 
            +
                    self.ROOT = ROOT
         | 
| 47 | 
            +
                    self.dataset_metadata_dir = dataset_metadata_dir
         | 
| 48 | 
            +
                    self.split = split
         | 
| 49 | 
            +
                    self.overfit_num_sets = overfit_num_sets
         | 
| 50 | 
            +
                    self.sample_specific_scene = sample_specific_scene
         | 
| 51 | 
            +
                    self.specific_scene_name = specific_scene_name
         | 
| 52 | 
            +
                    self._load_data()
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    # Define the dataset type flags
         | 
| 55 | 
            +
                    self.is_metric_scale = True
         | 
| 56 | 
            +
                    self.is_synthetic = True
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def _load_data(self):
         | 
| 59 | 
            +
                    "Load the precomputed dataset metadata"
         | 
| 60 | 
            +
                    # Load the dataset metadata corresponding to the split
         | 
| 61 | 
            +
                    split_metadata_path = os.path.join(
         | 
| 62 | 
            +
                        self.dataset_metadata_dir,
         | 
| 63 | 
            +
                        self.split,
         | 
| 64 | 
            +
                        f"unrealstereo4k_scene_list_{self.split}.npy",
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
                    split_scene_list = np.load(split_metadata_path, allow_pickle=True)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # Get the list of all scenes
         | 
| 69 | 
            +
                    if not self.sample_specific_scene:
         | 
| 70 | 
            +
                        self.scenes = list(split_scene_list)
         | 
| 71 | 
            +
                    else:
         | 
| 72 | 
            +
                        self.scenes = [self.specific_scene_name]
         | 
| 73 | 
            +
                    self.num_of_scenes = len(self.scenes)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def _get_views(self, sampled_idx, num_views_to_sample, resolution):
         | 
| 76 | 
            +
                    # Get the scene name of the sampled index
         | 
| 77 | 
            +
                    scene_index = sampled_idx
         | 
| 78 | 
            +
                    scene_name = self.scenes[scene_index]
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    # Get the metadata corresponding to the scene
         | 
| 81 | 
            +
                    scene_root = os.path.join(self.ROOT, scene_name)
         | 
| 82 | 
            +
                    scene_meta = load_data(
         | 
| 83 | 
            +
                        os.path.join(scene_root, "scene_meta.json"), "scene_meta"
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    scene_file_names = list(scene_meta["frame_names"].keys())
         | 
| 86 | 
            +
                    num_views_in_scene = len(scene_file_names)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    # Load the scene pairwise covisibility mmap
         | 
| 89 | 
            +
                    covisibility_version_key = "v0"
         | 
| 90 | 
            +
                    covisibility_map_dir = os.path.join(
         | 
| 91 | 
            +
                        scene_root, "covisibility", covisibility_version_key
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    # Assumes only npy file in directory is covisibility map
         | 
| 94 | 
            +
                    covisibility_map_name = next(
         | 
| 95 | 
            +
                        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
                    covisibility_map_path = os.path.join(
         | 
| 98 | 
            +
                        scene_root, "covisibility", covisibility_version_key, covisibility_map_name
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
                    pairwise_covisibility = load_data(covisibility_map_path, "mmap")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    # Get the indices of the N views in the scene
         | 
| 103 | 
            +
                    view_indices = self._sample_view_indices(
         | 
| 104 | 
            +
                        num_views_to_sample, num_views_in_scene, pairwise_covisibility
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    # Get the views corresponding to the selected view indices
         | 
| 108 | 
            +
                    views = []
         | 
| 109 | 
            +
                    for view_index in view_indices:
         | 
| 110 | 
            +
                        # Load the data corresponding to the view
         | 
| 111 | 
            +
                        view_file_name = scene_file_names[view_index]
         | 
| 112 | 
            +
                        view_data = load_frame(
         | 
| 113 | 
            +
                            scene_root,
         | 
| 114 | 
            +
                            view_file_name,
         | 
| 115 | 
            +
                            modalities=["image", "depth"],
         | 
| 116 | 
            +
                            scene_meta=scene_meta,
         | 
| 117 | 
            +
                        )
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        # Convert necessary data to numpy
         | 
| 120 | 
            +
                        image = view_data["image"].permute(1, 2, 0).numpy()
         | 
| 121 | 
            +
                        image = image[:, :, :3]  # RGBA to RGB
         | 
| 122 | 
            +
                        image = (image * 255).astype(np.uint8)
         | 
| 123 | 
            +
                        depthmap = view_data["depth"].numpy().astype(np.float32)
         | 
| 124 | 
            +
                        intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
         | 
| 125 | 
            +
                        c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        # Ensure that the depthmap has all valid values
         | 
| 128 | 
            +
                        depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                        # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
         | 
| 131 | 
            +
                        non_ambiguous_mask = (depthmap > 0).astype(int)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                        # Mask out the outlier depth (horizon depth)
         | 
| 134 | 
            +
                        percentile_depth = np.percentile(depthmap, 95)
         | 
| 135 | 
            +
                        depthmap[depthmap > percentile_depth] = 0
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                        # Resize the data to match the desired resolution
         | 
| 138 | 
            +
                        additional_quantities_to_resize = [non_ambiguous_mask]
         | 
| 139 | 
            +
                        image, depthmap, intrinsics, additional_quantities_to_resize = (
         | 
| 140 | 
            +
                            self._crop_resize_if_necessary(
         | 
| 141 | 
            +
                                image=image,
         | 
| 142 | 
            +
                                resolution=resolution,
         | 
| 143 | 
            +
                                depthmap=depthmap,
         | 
| 144 | 
            +
                                intrinsics=intrinsics,
         | 
| 145 | 
            +
                                additional_quantities=additional_quantities_to_resize,
         | 
| 146 | 
            +
                            )
         | 
| 147 | 
            +
                        )
         | 
| 148 | 
            +
                        non_ambiguous_mask = additional_quantities_to_resize[0]
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                        # Append the view dictionary to the list of views
         | 
| 151 | 
            +
                        views.append(
         | 
| 152 | 
            +
                            dict(
         | 
| 153 | 
            +
                                img=image,
         | 
| 154 | 
            +
                                depthmap=depthmap,
         | 
| 155 | 
            +
                                camera_pose=c2w_pose,  # cam2world
         | 
| 156 | 
            +
                                camera_intrinsics=intrinsics,
         | 
| 157 | 
            +
                                non_ambiguous_mask=non_ambiguous_mask,
         | 
| 158 | 
            +
                                dataset="UnrealStereo4K",
         | 
| 159 | 
            +
                                label=scene_name,
         | 
| 160 | 
            +
                                instance=os.path.join("images", str(view_file_name)),
         | 
| 161 | 
            +
                            )
         | 
| 162 | 
            +
                        )
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    return views
         | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            def get_parser():
         | 
| 168 | 
            +
                import argparse
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 171 | 
            +
                parser.add_argument(
         | 
| 172 | 
            +
                    "-rd", "--root_dir", default="/fsx/xrtech/data/unrealstereo4k", type=str
         | 
| 173 | 
            +
                )
         | 
| 174 | 
            +
                parser.add_argument(
         | 
| 175 | 
            +
                    "-dmd",
         | 
| 176 | 
            +
                    "--dataset_metadata_dir",
         | 
| 177 | 
            +
                    default="/fsx/nkeetha/mapanything_dataset_metadata",
         | 
| 178 | 
            +
                    type=str,
         | 
| 179 | 
            +
                )
         | 
| 180 | 
            +
                parser.add_argument(
         | 
| 181 | 
            +
                    "-nv",
         | 
| 182 | 
            +
                    "--num_of_views",
         | 
| 183 | 
            +
                    default=2,
         | 
| 184 | 
            +
                    type=int,
         | 
| 185 | 
            +
                )
         | 
| 186 | 
            +
                parser.add_argument("--viz", action="store_true")
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                return parser
         | 
| 189 | 
            +
             | 
| 190 | 
            +
             | 
| 191 | 
            +
if __name__ == "__main__":
    # Demo entry point: iterates the UnrealStereo4K dataloader and (optionally)
    # streams each sampled multi-view tuple to a rerun viewer.
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    # Training-split dataset with augmentation enabled.
    # NOTE(review): resolution (518, 294) and aug_crop=16 appear to be tuned
    # for the DINOv2 patch size — confirm against the dataset class docs.
    dataset = UnrealStereo4KWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 294),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = UnrealStereo4KWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 294),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        # One-time rerun setup: world frame uses Right-Down-Forward coordinates.
        rr.script_setup(args, "UnrealStereo4K_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    # Visit every dataset index exactly once, in random order.
    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        # Build a human-readable label from the per-view names for logging.
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            # Undo the data normalization so the image can be displayed as RGB.
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            # Optional per-view quantities; not all dataset configs provide them.
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                # Each sample gets its own timeline step so the viewer can scrub.
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                # Camera-to-world transform for this view.
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                # Only valid (mask-true) points are sent; colors come from the image.
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
         | 
    	
        mapanything/models/__init__.py
    ADDED
    
    | @@ -0,0 +1,190 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Model Factory for MapAnything
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import importlib.util
         | 
| 11 | 
            +
            import logging
         | 
| 12 | 
            +
            import warnings
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import numpy as np
         | 
| 15 | 
            +
            from omegaconf import DictConfig, OmegaConf
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # Core models that are always available
         | 
| 18 | 
            +
            from mapanything.models.mapanything import (
         | 
| 19 | 
            +
                MapAnything,
         | 
| 20 | 
            +
                MapAnythingAblations,
         | 
| 21 | 
            +
                ModularDUSt3R,
         | 
| 22 | 
            +
            )
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            # Suppress DINOv2 warnings
         | 
| 25 | 
            +
            logging.getLogger("dinov2").setLevel(logging.WARNING)
         | 
| 26 | 
            +
            warnings.filterwarnings("ignore", message="xFormers is available", category=UserWarning)
         | 
| 27 | 
            +
            warnings.filterwarnings(
         | 
| 28 | 
            +
                "ignore", message="xFormers is not available", category=UserWarning
         | 
| 29 | 
            +
            )
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            def resolve_special_float(value):
         | 
| 33 | 
            +
                if value == "inf":
         | 
| 34 | 
            +
                    return np.inf
         | 
| 35 | 
            +
                elif value == "-inf":
         | 
| 36 | 
            +
                    return -np.inf
         | 
| 37 | 
            +
                else:
         | 
| 38 | 
            +
                    raise ValueError(f"Unknown special float value: {value}")
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def init_model(
         | 
| 42 | 
            +
                model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False
         | 
| 43 | 
            +
            ):
         | 
| 44 | 
            +
                """
         | 
| 45 | 
            +
                Initialize a model using OmegaConf configuration.
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                Args:
         | 
| 48 | 
            +
                    model_str (str): Name of the model class to create.
         | 
| 49 | 
            +
                    model_config (DictConfig): OmegaConf model configuration.
         | 
| 50 | 
            +
                    torch_hub_force_reload (bool): Whether to force reload relevant parts of the model from torch hub.
         | 
| 51 | 
            +
                """
         | 
| 52 | 
            +
                if not OmegaConf.has_resolver("special_float"):
         | 
| 53 | 
            +
                    OmegaConf.register_new_resolver("special_float", resolve_special_float)
         | 
| 54 | 
            +
                model_dict = OmegaConf.to_container(model_config, resolve=True)
         | 
| 55 | 
            +
                model = model_factory(
         | 
| 56 | 
            +
                    model_str, torch_hub_force_reload=torch_hub_force_reload, **model_dict
         | 
| 57 | 
            +
                )
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                return model
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
# Define model configurations with import paths
# Registry consumed by model_factory. Two entry shapes:
#   - {"class": <cls>}: core model, class imported eagerly at module load.
#   - {"module": <path>, "class_name": <name>}: external model, imported
#     lazily by model_factory so missing optional dependencies don't break
#     the package import.
MODEL_CONFIGS = {
    # Core models
    "mapanything": {
        "class": MapAnything,
    },
    "mapanything_ablations": {
        "class": MapAnythingAblations,
    },
    "modular_dust3r": {
        "class": ModularDUSt3R,
    },
    # External models
    "anycalib": {
        "module": "mapanything.models.external.anycalib",
        "class_name": "AnyCalibWrapper",
    },
    "dust3r": {
        "module": "mapanything.models.external.dust3r",
        "class_name": "DUSt3RBAWrapper",
    },
    "mast3r": {
        "module": "mapanything.models.external.mast3r",
        "class_name": "MASt3RSGAWrapper",
    },
    "moge": {
        "module": "mapanything.models.external.moge",
        "class_name": "MoGeWrapper",
    },
    "must3r": {
        "module": "mapanything.models.external.must3r",
        "class_name": "MUSt3RWrapper",
    },
    "pi3": {
        "module": "mapanything.models.external.pi3",
        "class_name": "Pi3Wrapper",
    },
    "pow3r": {
        "module": "mapanything.models.external.pow3r",
        "class_name": "Pow3RWrapper",
    },
    "pow3r_ba": {
        "module": "mapanything.models.external.pow3r",
        "class_name": "Pow3RBAWrapper",
    },
    "vggt": {
        "module": "mapanything.models.external.vggt",
        "class_name": "VGGTWrapper",
    },
    # Add other model classes here
}
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def check_module_exists(module_path):
         | 
| 116 | 
            +
                """
         | 
| 117 | 
            +
                Check if a module can be imported without actually importing it.
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                Args:
         | 
| 120 | 
            +
                    module_path (str): The path to the module to check.
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                Returns:
         | 
| 123 | 
            +
                    bool: True if the module can be imported, False otherwise.
         | 
| 124 | 
            +
                """
         | 
| 125 | 
            +
                return importlib.util.find_spec(module_path) is not None
         | 
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            def model_factory(model_str: str, **kwargs):
         | 
| 129 | 
            +
                """
         | 
| 130 | 
            +
                Model factory for MapAnything.
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                Args:
         | 
| 133 | 
            +
                    model_str (str): Name of the model to create.
         | 
| 134 | 
            +
                    **kwargs: Additional keyword arguments to pass to the model constructor.
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                Returns:
         | 
| 137 | 
            +
                   nn.Module: An instance of the specified model.
         | 
| 138 | 
            +
                """
         | 
| 139 | 
            +
                if model_str not in MODEL_CONFIGS:
         | 
| 140 | 
            +
                    raise ValueError(
         | 
| 141 | 
            +
                        f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}"
         | 
| 142 | 
            +
                    )
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                model_config = MODEL_CONFIGS[model_str]
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                # Handle core models directly
         | 
| 147 | 
            +
                if "class" in model_config:
         | 
| 148 | 
            +
                    model_class = model_config["class"]
         | 
| 149 | 
            +
                # Handle external models with dynamic imports
         | 
| 150 | 
            +
                elif "module" in model_config:
         | 
| 151 | 
            +
                    module_path = model_config["module"]
         | 
| 152 | 
            +
                    class_name = model_config["class_name"]
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    # Check if the module can be imported
         | 
| 155 | 
            +
                    if not check_module_exists(module_path):
         | 
| 156 | 
            +
                        raise ImportError(
         | 
| 157 | 
            +
                            f"Model '{model_str}' requires module '{module_path}' which is not installed. "
         | 
| 158 | 
            +
                            f"Please install the corresponding submodule or package."
         | 
| 159 | 
            +
                        )
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    # Dynamically import the module and get the class
         | 
| 162 | 
            +
                    try:
         | 
| 163 | 
            +
                        module = importlib.import_module(module_path)
         | 
| 164 | 
            +
                        model_class = getattr(module, class_name)
         | 
| 165 | 
            +
                    except (ImportError, AttributeError) as e:
         | 
| 166 | 
            +
                        raise ImportError(
         | 
| 167 | 
            +
                            f"Failed to import {class_name} from {module_path}: {str(e)}"
         | 
| 168 | 
            +
                        )
         | 
| 169 | 
            +
                else:
         | 
| 170 | 
            +
                    raise ValueError(f"Invalid model configuration for {model_str}")
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                print(f"Initializing {model_class} with kwargs: {kwargs}")
         | 
| 173 | 
            +
                if model_str != "org_dust3r":
         | 
| 174 | 
            +
                    return model_class(**kwargs)
         | 
| 175 | 
            +
                else:
         | 
| 176 | 
            +
                    eval_str = kwargs.get("model_eval_str", None)
         | 
| 177 | 
            +
                    return eval(eval_str)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
             | 
| 180 | 
            +
            def get_available_models() -> list:
         | 
| 181 | 
            +
                """
         | 
| 182 | 
            +
                Get a list of available models in MapAnything.
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                Returns:
         | 
| 185 | 
            +
                    list: A list of available model names.
         | 
| 186 | 
            +
                """
         | 
| 187 | 
            +
                return list(MODEL_CONFIGS.keys())
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
            __all__ = ["model_factory", "get_available_models"]
         | 
    	
        mapanything/models/__pycache__/__init__.cpython-312.pyc
    ADDED
    
    | Binary file (5.63 kB). View file | 
|  | 
    	
        mapanything/models/external/README.md
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # External Model Code for Benchmarking & Re-Training
         | 
| 2 | 
            +
             | 
| 3 | 
            +
This directory contains external model code that we use to train and benchmark external models fairly. These libraries are not part of the core MapAnything codebase and are included only for benchmarking purposes. The code in this directory is licensed under the same license as the source code from which it was derived, unless otherwise specified.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            The open-source Apache 2.0 License of MapAnything does not apply to these libraries.
         | 
    	
        mapanything/models/external/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        mapanything/models/external/anycalib/__init__.py
    ADDED
    
    | @@ -0,0 +1,100 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            Inference wrapper for AnyCalib
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from anycalib import AnyCalib
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from mapanything.utils.geometry import get_rays_in_camera_frame
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            class AnyCalibWrapper(torch.nn.Module):
         | 
| 17 | 
            +
                def __init__(
         | 
| 18 | 
            +
                    self,
         | 
| 19 | 
            +
                    name,
         | 
| 20 | 
            +
                    model_id="anycalib_pinhole",
         | 
| 21 | 
            +
                    **kwargs,
         | 
| 22 | 
            +
                ):
         | 
| 23 | 
            +
                    super().__init__()
         | 
| 24 | 
            +
                    self.name = name
         | 
| 25 | 
            +
                    self.model_id = model_id
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    # Initialize the model
         | 
| 28 | 
            +
                    self.model = AnyCalib(model_id=self.model_id)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def forward(self, views):
         | 
| 31 | 
            +
                    """
         | 
| 32 | 
            +
                    Forward pass wrapper for AnyCalib.
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    Assumption:
         | 
| 35 | 
            +
                    - The number of input views is 1.
         | 
| 36 | 
            +
                    - The output camera model is pinhole (fx, fy, cx, cy).
         | 
| 37 | 
            +
                      This can be relaxed by not hardcoding the cam_id.
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    Args:
         | 
| 40 | 
            +
                        views (List[dict]): List of dictionaries containing the input views' images and instance information.
         | 
| 41 | 
            +
                                            Length of the list should be 1.
         | 
| 42 | 
            +
                                            Each dictionary should contain the following keys:
         | 
| 43 | 
            +
                                                "img" (tensor): Image tensor of shape (B, C, H, W).
         | 
| 44 | 
            +
                                                "data_norm_type" (list): ["identity"]
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    Returns:
         | 
| 47 | 
            +
                        List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
         | 
| 48 | 
            +
                    """
         | 
| 49 | 
            +
                    # Check that the number of input views is 1
         | 
| 50 | 
            +
                    assert len(views) == 1, "AnyCalib only supports 1 input view."
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    # Get input shape of the images and batch size per view
         | 
| 53 | 
            +
                    _, _, height, width = views[0]["img"].shape
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    # Check the data norm type
         | 
| 56 | 
            +
                    # AnyCalib expects a normalized image but without the DINOv2 mean and std applied ("identity")
         | 
| 57 | 
            +
                    data_norm_type = views[0]["data_norm_type"][0]
         | 
| 58 | 
            +
                    assert data_norm_type == "identity", (
         | 
| 59 | 
            +
                        "AnyCalib expects a normalized image but without the DINOv2 mean and std applied"
         | 
| 60 | 
            +
                    )
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    # Run AnyCalib inference
         | 
| 63 | 
            +
                    # Corresponding batched output dictionary:
         | 
| 64 | 
            +
                    # {
         | 
| 65 | 
            +
                    #      "intrinsics": List[(D_i,) tensors] for each camera model "i" at the original input resolution,
         | 
| 66 | 
            +
                    #      "fov_field": (B, N, 2) tensor with the regressed FoV field by the network. N≈320^2 (resolution close to the one seen during training),
         | 
| 67 | 
            +
                    #      "tangent_coords": alias for "fov_field",
         | 
| 68 | 
            +
                    #      "rays": (B, N, 3) tensor with the corresponding (via the exponential map) ray directions in the camera frame (x right, y down, z forward),
         | 
| 69 | 
            +
                    #      "pred_size": (H, W) tuple with the image size used by the network. It can be used e.g. for resizing the FoV/ray fields to the original image size.
         | 
| 70 | 
            +
                    # }
         | 
| 71 | 
            +
                    # For "pinhole" camera model, the intrinsics are (fx, fy, cx, cy).
         | 
| 72 | 
            +
                    model_outputs = self.model.predict(views[0]["img"], cam_id="pinhole")
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    # Convert the list of intrinsics to a tensor
         | 
| 75 | 
            +
                    intrinsics = []
         | 
| 76 | 
            +
                    for intrinsics_per_sample in model_outputs["intrinsics"]:
         | 
| 77 | 
            +
                        pred_fx, pred_fy, pred_cx, pred_cy = intrinsics_per_sample
         | 
| 78 | 
            +
                        intrinsics_per_sample = torch.tensor(
         | 
| 79 | 
            +
                            [
         | 
| 80 | 
            +
                                [pred_fx, 0, pred_cx],
         | 
| 81 | 
            +
                                [0, pred_fy, pred_cy],
         | 
| 82 | 
            +
                                [0, 0, 1],
         | 
| 83 | 
            +
                            ],
         | 
| 84 | 
            +
                            device=views[0]["img"].device,
         | 
| 85 | 
            +
                        )
         | 
| 86 | 
            +
                        intrinsics.append(intrinsics_per_sample)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    # Convert the list of intrinsics to a tensor of size (batch_size_per_view, 3, 3)
         | 
| 89 | 
            +
                    intrinsics = torch.stack(intrinsics)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    # Get the ray directions
         | 
| 92 | 
            +
                    with torch.autocast("cuda", enabled=False):
         | 
| 93 | 
            +
                        _, ray_directions = get_rays_in_camera_frame(
         | 
| 94 | 
            +
                            intrinsics, height, width, normalize_to_unit_sphere=True
         | 
| 95 | 
            +
                        )
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    # Return the output in MapAnything format
         | 
| 98 | 
            +
                    res = [{"ray_directions": ray_directions, "intrinsics": intrinsics}]
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    return res
         | 
    	
        mapanything/models/external/dinov2/__init__.py
    ADDED
    
    | @@ -0,0 +1,6 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
__version__ = "0.0.1"  # version of this vendored DINOv2 package
         | 
    	
        mapanything/models/external/dinov2/hub/__init__.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
    	
        mapanything/models/external/dinov2/hub/backbones.py
    ADDED
    
    | @@ -0,0 +1,183 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from enum import Enum
         | 
| 7 | 
            +
            from typing import Union
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from mapanything.models.external.dinov2.hub.utils import (
         | 
| 12 | 
            +
                _DINOV2_BASE_URL,
         | 
| 13 | 
            +
                _make_dinov2_model_name,
         | 
| 14 | 
            +
            )
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
class Weights(Enum):
    """Named pretraining weight sets selectable for DINOv2 backbones."""

    LVD142M = "LVD142M"
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def _make_dinov2_model(
         | 
| 22 | 
            +
                *,
         | 
| 23 | 
            +
                arch_name: str = "vit_large",
         | 
| 24 | 
            +
                img_size: int = 518,
         | 
| 25 | 
            +
                patch_size: int = 14,
         | 
| 26 | 
            +
                init_values: float = 1.0,
         | 
| 27 | 
            +
                ffn_layer: str = "mlp",
         | 
| 28 | 
            +
                block_chunks: int = 0,
         | 
| 29 | 
            +
                num_register_tokens: int = 0,
         | 
| 30 | 
            +
                interpolate_antialias: bool = False,
         | 
| 31 | 
            +
                interpolate_offset: float = 0.1,
         | 
| 32 | 
            +
                pretrained: bool = True,
         | 
| 33 | 
            +
                weights: Union[Weights, str] = Weights.LVD142M,
         | 
| 34 | 
            +
                **kwargs,
         | 
| 35 | 
            +
            ):
         | 
| 36 | 
            +
                from ..models import vision_transformer as vits
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                if isinstance(weights, str):
         | 
| 39 | 
            +
                    try:
         | 
| 40 | 
            +
                        weights = Weights[weights]
         | 
| 41 | 
            +
                    except KeyError:
         | 
| 42 | 
            +
                        raise AssertionError(f"Unsupported weights: {weights}")
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                model_base_name = _make_dinov2_model_name(arch_name, patch_size)
         | 
| 45 | 
            +
                vit_kwargs = dict(
         | 
| 46 | 
            +
                    img_size=img_size,
         | 
| 47 | 
            +
                    patch_size=patch_size,
         | 
| 48 | 
            +
                    init_values=init_values,
         | 
| 49 | 
            +
                    ffn_layer=ffn_layer,
         | 
| 50 | 
            +
                    block_chunks=block_chunks,
         | 
| 51 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 52 | 
            +
                    interpolate_antialias=interpolate_antialias,
         | 
| 53 | 
            +
                    interpolate_offset=interpolate_offset,
         | 
| 54 | 
            +
                )
         | 
| 55 | 
            +
                vit_kwargs.update(**kwargs)
         | 
| 56 | 
            +
                model = vits.__dict__[arch_name](**vit_kwargs)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                if pretrained:
         | 
| 59 | 
            +
                    model_full_name = _make_dinov2_model_name(
         | 
| 60 | 
            +
                        arch_name, patch_size, num_register_tokens
         | 
| 61 | 
            +
                    )
         | 
| 62 | 
            +
                    url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
         | 
| 63 | 
            +
                    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
         | 
| 64 | 
            +
                    model.load_state_dict(state_dict, strict=True)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                return model
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def dinov2_vits14(
         | 
| 70 | 
            +
                *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
         | 
| 71 | 
            +
            ):
         | 
| 72 | 
            +
                """
         | 
| 73 | 
            +
                DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
         | 
| 74 | 
            +
                """
         | 
| 75 | 
            +
                return _make_dinov2_model(
         | 
| 76 | 
            +
                    arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs
         | 
| 77 | 
            +
                )
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            def dinov2_vitb14(
         | 
| 81 | 
            +
                *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
         | 
| 82 | 
            +
            ):
         | 
| 83 | 
            +
                """
         | 
| 84 | 
            +
                DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
         | 
| 85 | 
            +
                """
         | 
| 86 | 
            +
                return _make_dinov2_model(
         | 
| 87 | 
            +
                    arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs
         | 
| 88 | 
            +
                )
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            def dinov2_vitl14(
         | 
| 92 | 
            +
                *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
         | 
| 93 | 
            +
            ):
         | 
| 94 | 
            +
                """
         | 
| 95 | 
            +
                DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
                return _make_dinov2_model(
         | 
| 98 | 
            +
                    arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs
         | 
| 99 | 
            +
                )
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            def dinov2_vitg14(
         | 
| 103 | 
            +
                *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
         | 
| 104 | 
            +
            ):
         | 
| 105 | 
            +
                """
         | 
| 106 | 
            +
                DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
         | 
| 107 | 
            +
                """
         | 
| 108 | 
            +
                return _make_dinov2_model(
         | 
| 109 | 
            +
                    arch_name="vit_giant2",
         | 
| 110 | 
            +
                    ffn_layer="swiglufused",
         | 
| 111 | 
            +
                    weights=weights,
         | 
| 112 | 
            +
                    pretrained=pretrained,
         | 
| 113 | 
            +
                    **kwargs,
         | 
| 114 | 
            +
                )
         | 
| 115 | 
            +
             | 
| 116 | 
            +
             | 
| 117 | 
            +
            def dinov2_vits14_reg(
         | 
| 118 | 
            +
                *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
         | 
| 119 | 
            +
            ):
         | 
| 120 | 
            +
                """
         | 
| 121 | 
            +
                DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
         | 
| 122 | 
            +
                """
         | 
| 123 | 
            +
                return _make_dinov2_model(
         | 
| 124 | 
            +
                    arch_name="vit_small",
         | 
| 125 | 
            +
                    pretrained=pretrained,
         | 
| 126 | 
            +
                    weights=weights,
         | 
| 127 | 
            +
                    num_register_tokens=4,
         | 
| 128 | 
            +
                    interpolate_antialias=True,
         | 
| 129 | 
            +
                    interpolate_offset=0.0,
         | 
| 130 | 
            +
                    **kwargs,
         | 
| 131 | 
            +
                )
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            def dinov2_vitb14_reg(
         | 
| 135 | 
            +
                *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
         | 
| 136 | 
            +
            ):
         | 
| 137 | 
            +
                """
         | 
| 138 | 
            +
                DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
         | 
| 139 | 
            +
                """
         | 
| 140 | 
            +
                return _make_dinov2_model(
         | 
| 141 | 
            +
                    arch_name="vit_base",
         | 
| 142 | 
            +
                    pretrained=pretrained,
         | 
| 143 | 
            +
                    weights=weights,
         | 
| 144 | 
            +
                    num_register_tokens=4,
         | 
| 145 | 
            +
                    interpolate_antialias=True,
         | 
| 146 | 
            +
                    interpolate_offset=0.0,
         | 
| 147 | 
            +
                    **kwargs,
         | 
| 148 | 
            +
                )
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
            def dinov2_vitl14_reg(
         | 
| 152 | 
            +
                *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
         | 
| 153 | 
            +
            ):
         | 
| 154 | 
            +
                """
         | 
| 155 | 
            +
                DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
         | 
| 156 | 
            +
                """
         | 
| 157 | 
            +
                return _make_dinov2_model(
         | 
| 158 | 
            +
                    arch_name="vit_large",
         | 
| 159 | 
            +
                    pretrained=pretrained,
         | 
| 160 | 
            +
                    weights=weights,
         | 
| 161 | 
            +
                    num_register_tokens=4,
         | 
| 162 | 
            +
                    interpolate_antialias=True,
         | 
| 163 | 
            +
                    interpolate_offset=0.0,
         | 
| 164 | 
            +
                    **kwargs,
         | 
| 165 | 
            +
                )
         | 
| 166 | 
            +
             | 
| 167 | 
            +
             | 
| 168 | 
            +
            def dinov2_vitg14_reg(
         | 
| 169 | 
            +
                *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
         | 
| 170 | 
            +
            ):
         | 
| 171 | 
            +
                """
         | 
| 172 | 
            +
                DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
         | 
| 173 | 
            +
                """
         | 
| 174 | 
            +
                return _make_dinov2_model(
         | 
| 175 | 
            +
                    arch_name="vit_giant2",
         | 
| 176 | 
            +
                    ffn_layer="swiglufused",
         | 
| 177 | 
            +
                    weights=weights,
         | 
| 178 | 
            +
                    pretrained=pretrained,
         | 
| 179 | 
            +
                    num_register_tokens=4,
         | 
| 180 | 
            +
                    interpolate_antialias=True,
         | 
| 181 | 
            +
                    interpolate_offset=0.0,
         | 
| 182 | 
            +
                    **kwargs,
         | 
| 183 | 
            +
                )
         | 
    	
        mapanything/models/external/dinov2/hub/utils.py
    ADDED
    
    | @@ -0,0 +1,42 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import itertools
         | 
| 7 | 
            +
            import math
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torch.nn as nn
         | 
| 11 | 
            +
            import torch.nn.functional as F
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def _make_dinov2_model_name(
         | 
| 17 | 
            +
                arch_name: str, patch_size: int, num_register_tokens: int = 0
         | 
| 18 | 
            +
            ) -> str:
         | 
| 19 | 
            +
                compact_arch_name = arch_name.replace("_", "")[:4]
         | 
| 20 | 
            +
                registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
         | 
| 21 | 
            +
                return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class CenterPadding(nn.Module):
         | 
| 25 | 
            +
                def __init__(self, multiple):
         | 
| 26 | 
            +
                    super().__init__()
         | 
| 27 | 
            +
                    self.multiple = multiple
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def _get_pad(self, size):
         | 
| 30 | 
            +
                    new_size = math.ceil(size / self.multiple) * self.multiple
         | 
| 31 | 
            +
                    pad_size = new_size - size
         | 
| 32 | 
            +
                    pad_size_left = pad_size // 2
         | 
| 33 | 
            +
                    pad_size_right = pad_size - pad_size_left
         | 
| 34 | 
            +
                    return pad_size_left, pad_size_right
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                @torch.inference_mode()
         | 
| 37 | 
            +
                def forward(self, x):
         | 
| 38 | 
            +
                    pads = list(
         | 
| 39 | 
            +
                        itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])
         | 
| 40 | 
            +
                    )
         | 
| 41 | 
            +
                    output = F.pad(x, pads)
         | 
| 42 | 
            +
                    return output
         | 
    	
        mapanything/models/external/dinov2/layers/__init__.py
    ADDED
    
    | @@ -0,0 +1,14 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from mapanything.models.external.dinov2.layers.dino_head import DINOHead  # noqa
         | 
| 7 | 
            +
            from mapanything.models.external.dinov2.layers.mlp import Mlp  # noqa
         | 
| 8 | 
            +
            from mapanything.models.external.dinov2.layers.patch_embed import PatchEmbed  # noqa
         | 
| 9 | 
            +
            from mapanything.models.external.dinov2.layers.swiglu_ffn import (
         | 
| 10 | 
            +
                SwiGLUFFN,  # noqa
         | 
| 11 | 
            +
                SwiGLUFFNFused,  # noqa
         | 
| 12 | 
            +
            )
         | 
| 13 | 
            +
            from mapanything.models.external.dinov2.layers.block import NestedTensorBlock  # noqa
         | 
| 14 | 
            +
            from mapanything.models.external.dinov2.layers.attention import MemEffAttention  # noqa
         | 
    	
        mapanything/models/external/dinov2/layers/attention.py
    ADDED
    
    | @@ -0,0 +1,90 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import logging
         | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from torch import nn, Tensor
         | 
| 14 | 
            +
             | 
| 15 | 
            +
logger = logging.getLogger("dinov2")


# xFormers can be explicitly disabled by setting the XFORMERS_DISABLED env var.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Attention)")
    else:
        # warnings.warn("xFormers is disabled (Attention)")
        # Treat an explicit disable exactly like a missing install.
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Attention)")
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class Attention(nn.Module):
         | 
| 34 | 
            +
                def __init__(
         | 
| 35 | 
            +
                    self,
         | 
| 36 | 
            +
                    dim: int,
         | 
| 37 | 
            +
                    num_heads: int = 8,
         | 
| 38 | 
            +
                    qkv_bias: bool = False,
         | 
| 39 | 
            +
                    proj_bias: bool = True,
         | 
| 40 | 
            +
                    attn_drop: float = 0.0,
         | 
| 41 | 
            +
                    proj_drop: float = 0.0,
         | 
| 42 | 
            +
                ) -> None:
         | 
| 43 | 
            +
                    super().__init__()
         | 
| 44 | 
            +
                    self.num_heads = num_heads
         | 
| 45 | 
            +
                    head_dim = dim // num_heads
         | 
| 46 | 
            +
                    self.scale = head_dim**-0.5
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         | 
| 49 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 50 | 
            +
                    self.proj = nn.Linear(dim, dim, bias=proj_bias)
         | 
| 51 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def forward(self, x: Tensor, attn_bias=None) -> Tensor:
         | 
| 54 | 
            +
                    B, N, C = x.shape
         | 
| 55 | 
            +
                    qkv = (
         | 
| 56 | 
            +
                        self.qkv(x)
         | 
| 57 | 
            +
                        .reshape(B, N, 3, self.num_heads, C // self.num_heads)
         | 
| 58 | 
            +
                        .permute(2, 0, 3, 1, 4)
         | 
| 59 | 
            +
                    )
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
         | 
| 62 | 
            +
                    attn = q @ k.transpose(-2, -1)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    attn = attn.softmax(dim=-1)
         | 
| 65 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         | 
| 68 | 
            +
                    x = self.proj(x)
         | 
| 69 | 
            +
                    x = self.proj_drop(x)
         | 
| 70 | 
            +
                    return x
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            class MemEffAttention(Attention):
         | 
| 74 | 
            +
                def forward(self, x: Tensor, attn_bias=None) -> Tensor:
         | 
| 75 | 
            +
                    if not XFORMERS_AVAILABLE:
         | 
| 76 | 
            +
                        if attn_bias is not None:
         | 
| 77 | 
            +
                            raise AssertionError("xFormers is required for using nested tensors")
         | 
| 78 | 
            +
                        return super().forward(x)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    B, N, C = x.shape
         | 
| 81 | 
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    q, k, v = unbind(qkv, 2)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         | 
| 86 | 
            +
                    x = x.reshape([B, N, C])
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    x = self.proj(x)
         | 
| 89 | 
            +
                    x = self.proj_drop(x)
         | 
| 90 | 
            +
                    return x
         | 
    	
        mapanything/models/external/dinov2/layers/block.py
    ADDED
    
    | @@ -0,0 +1,290 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import logging
         | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
            from typing import Any, Callable, Dict, List, Tuple
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            from torch import nn, Tensor
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from mapanything.models.external.dinov2.layers.attention import (
         | 
| 18 | 
            +
                Attention,
         | 
| 19 | 
            +
                MemEffAttention,
         | 
| 20 | 
            +
            )
         | 
| 21 | 
            +
            from mapanything.models.external.dinov2.layers.drop_path import DropPath
         | 
| 22 | 
            +
            from mapanything.models.external.dinov2.layers.layer_scale import LayerScale
         | 
| 23 | 
            +
            from mapanything.models.external.dinov2.layers.mlp import Mlp
         | 
| 24 | 
            +
             | 
| 25 | 
            +
logger = logging.getLogger("dinov2")


# xFormers can be explicitly disabled by setting the XFORMERS_DISABLED env var.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, index_select_cat, scaled_index_add

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        # Treat an explicit disable exactly like a missing install.
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            class Block(nn.Module):
         | 
| 44 | 
            +
                def __init__(
         | 
| 45 | 
            +
                    self,
         | 
| 46 | 
            +
                    dim: int,
         | 
| 47 | 
            +
                    num_heads: int,
         | 
| 48 | 
            +
                    mlp_ratio: float = 4.0,
         | 
| 49 | 
            +
                    qkv_bias: bool = False,
         | 
| 50 | 
            +
                    proj_bias: bool = True,
         | 
| 51 | 
            +
                    ffn_bias: bool = True,
         | 
| 52 | 
            +
                    drop: float = 0.0,
         | 
| 53 | 
            +
                    attn_drop: float = 0.0,
         | 
| 54 | 
            +
                    init_values=None,
         | 
| 55 | 
            +
                    drop_path: float = 0.0,
         | 
| 56 | 
            +
                    act_layer: Callable[..., nn.Module] = nn.GELU,
         | 
| 57 | 
            +
                    norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
         | 
| 58 | 
            +
                    attn_class: Callable[..., nn.Module] = Attention,
         | 
| 59 | 
            +
                    ffn_layer: Callable[..., nn.Module] = Mlp,
         | 
| 60 | 
            +
                ) -> None:
         | 
| 61 | 
            +
                    super().__init__()
         | 
| 62 | 
            +
                    # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
         | 
| 63 | 
            +
                    self.norm1 = norm_layer(dim)
         | 
| 64 | 
            +
                    self.attn = attn_class(
         | 
| 65 | 
            +
                        dim,
         | 
| 66 | 
            +
                        num_heads=num_heads,
         | 
| 67 | 
            +
                        qkv_bias=qkv_bias,
         | 
| 68 | 
            +
                        proj_bias=proj_bias,
         | 
| 69 | 
            +
                        attn_drop=attn_drop,
         | 
| 70 | 
            +
                        proj_drop=drop,
         | 
| 71 | 
            +
                    )
         | 
| 72 | 
            +
                    self.ls1 = (
         | 
| 73 | 
            +
                        LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         | 
| 74 | 
            +
                    )
         | 
| 75 | 
            +
                    self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    self.norm2 = norm_layer(dim)
         | 
| 78 | 
            +
                    mlp_hidden_dim = int(dim * mlp_ratio)
         | 
| 79 | 
            +
                    self.mlp = ffn_layer(
         | 
| 80 | 
            +
                        in_features=dim,
         | 
| 81 | 
            +
                        hidden_features=mlp_hidden_dim,
         | 
| 82 | 
            +
                        act_layer=act_layer,
         | 
| 83 | 
            +
                        drop=drop,
         | 
| 84 | 
            +
                        bias=ffn_bias,
         | 
| 85 | 
            +
                    )
         | 
| 86 | 
            +
                    self.ls2 = (
         | 
| 87 | 
            +
                        LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         | 
| 88 | 
            +
                    )
         | 
| 89 | 
            +
                    self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    self.sample_drop_ratio = drop_path
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def forward(self, x: Tensor) -> Tensor:
         | 
| 94 | 
            +
                    def attn_residual_func(x: Tensor) -> Tensor:
         | 
| 95 | 
            +
                        return self.ls1(self.attn(self.norm1(x)))
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    def ffn_residual_func(x: Tensor) -> Tensor:
         | 
| 98 | 
            +
                        return self.ls2(self.mlp(self.norm2(x)))
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    if self.training and self.sample_drop_ratio > 0.1:
         | 
| 101 | 
            +
                        # the overhead is compensated only for a drop path rate larger than 0.1
         | 
| 102 | 
            +
                        x = drop_add_residual_stochastic_depth(
         | 
| 103 | 
            +
                            x,
         | 
| 104 | 
            +
                            residual_func=attn_residual_func,
         | 
| 105 | 
            +
                            sample_drop_ratio=self.sample_drop_ratio,
         | 
| 106 | 
            +
                        )
         | 
| 107 | 
            +
                        x = drop_add_residual_stochastic_depth(
         | 
| 108 | 
            +
                            x,
         | 
| 109 | 
            +
                            residual_func=ffn_residual_func,
         | 
| 110 | 
            +
                            sample_drop_ratio=self.sample_drop_ratio,
         | 
| 111 | 
            +
                        )
         | 
| 112 | 
            +
                    elif self.training and self.sample_drop_ratio > 0.0:
         | 
| 113 | 
            +
                        x = x + self.drop_path1(attn_residual_func(x))
         | 
| 114 | 
            +
                        x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        x = x + attn_residual_func(x)
         | 
| 117 | 
            +
                        x = x + ffn_residual_func(x)
         | 
| 118 | 
            +
                    return x
         | 
| 119 | 
            +
             | 
| 120 | 
            +
             | 
| 121 | 
            +
            def drop_add_residual_stochastic_depth(
         | 
| 122 | 
            +
                x: Tensor,
         | 
| 123 | 
            +
                residual_func: Callable[[Tensor], Tensor],
         | 
| 124 | 
            +
                sample_drop_ratio: float = 0.0,
         | 
| 125 | 
            +
            ) -> Tensor:
         | 
| 126 | 
            +
                # 1) extract subset using permutation
         | 
| 127 | 
            +
                b, n, d = x.shape
         | 
| 128 | 
            +
                sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
         | 
| 129 | 
            +
                brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
         | 
| 130 | 
            +
                x_subset = x[brange]
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                # 2) apply residual_func to get residual
         | 
| 133 | 
            +
                residual = residual_func(x_subset)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                x_flat = x.flatten(1)
         | 
| 136 | 
            +
                residual = residual.flatten(1)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                residual_scale_factor = b / sample_subset_size
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                # 3) add the residual
         | 
| 141 | 
            +
                x_plus_residual = torch.index_add(
         | 
| 142 | 
            +
                    x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
         | 
| 143 | 
            +
                )
         | 
| 144 | 
            +
                return x_plus_residual.view_as(x)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            def get_branges_scales(x, sample_drop_ratio=0.0):
         | 
| 148 | 
            +
                b, n, d = x.shape
         | 
| 149 | 
            +
                sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
         | 
| 150 | 
            +
                brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
         | 
| 151 | 
            +
                residual_scale_factor = b / sample_subset_size
         | 
| 152 | 
            +
                return brange, residual_scale_factor
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
            def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
         | 
| 156 | 
            +
                if scaling_vector is None:
         | 
| 157 | 
            +
                    x_flat = x.flatten(1)
         | 
| 158 | 
            +
                    residual = residual.flatten(1)
         | 
| 159 | 
            +
                    x_plus_residual = torch.index_add(
         | 
| 160 | 
            +
                        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
         | 
| 161 | 
            +
                    )
         | 
| 162 | 
            +
                else:
         | 
| 163 | 
            +
                    x_plus_residual = scaled_index_add(
         | 
| 164 | 
            +
                        x,
         | 
| 165 | 
            +
                        brange,
         | 
| 166 | 
            +
                        residual.to(dtype=x.dtype),
         | 
| 167 | 
            +
                        scaling=scaling_vector,
         | 
| 168 | 
            +
                        alpha=residual_scale_factor,
         | 
| 169 | 
            +
                    )
         | 
| 170 | 
            +
                return x_plus_residual
         | 
| 171 | 
            +
             | 
| 172 | 
            +
             | 
| 173 | 
            +
# Module-level cache of attention-bias masks, keyed by the tuple of
# (batch_size, seq_len) shapes of the tensors being nested together.
# NOTE(review): grows unboundedly with distinct shape combinations.
attn_bias_cache: Dict[Tuple, Any] = {}
         | 
| 174 | 
            +
             | 
| 175 | 
            +
             | 
| 176 | 
            +
def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # Effective batch size per tensor: the stochastic-depth subset sizes when
    # branges is given, otherwise the full batch sizes of the inputs.
    batch_sizes = (
        [b.shape[0] for b in branges]
        if branges is not None
        else [x.shape[0] for x in x_list]
    )
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # One block per sample so tokens from different samples/tensors
        # cannot attend to each other once concatenated.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        # Stash the composition so the mask can split outputs back later.
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused xFormers gather+concat of the selected rows into a single
        # (1, total_tokens, dim) tensor.
        # NOTE(review): assumes all tensors share the same last dim — confirm.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
            1, -1, x_list[0].shape[-1]
        )
    else:
        # Flatten each tensor to batch size 1 and concatenate along tokens.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
         | 
| 204 | 
            +
             | 
| 205 | 
            +
             | 
| 206 | 
            +
            def drop_add_residual_stochastic_depth_list(
         | 
| 207 | 
            +
                x_list: List[Tensor],
         | 
| 208 | 
            +
                residual_func: Callable[[Tensor, Any], Tensor],
         | 
| 209 | 
            +
                sample_drop_ratio: float = 0.0,
         | 
| 210 | 
            +
                scaling_vector=None,
         | 
| 211 | 
            +
            ) -> Tensor:
         | 
| 212 | 
            +
                # 1) generate random set of indices for dropping samples in the batch
         | 
| 213 | 
            +
                branges_scales = [
         | 
| 214 | 
            +
                    get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
         | 
| 215 | 
            +
                ]
         | 
| 216 | 
            +
                branges = [s[0] for s in branges_scales]
         | 
| 217 | 
            +
                residual_scale_factors = [s[1] for s in branges_scales]
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                # 2) get attention bias and index+concat the tensors
         | 
| 220 | 
            +
                attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                # 3) apply residual_func to get residual, and split the result
         | 
| 223 | 
            +
                residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                outputs = []
         | 
| 226 | 
            +
                for x, brange, residual, residual_scale_factor in zip(
         | 
| 227 | 
            +
                    x_list, branges, residual_list, residual_scale_factors
         | 
| 228 | 
            +
                ):
         | 
| 229 | 
            +
                    outputs.append(
         | 
| 230 | 
            +
                        add_residual(
         | 
| 231 | 
            +
                            x, brange, residual, residual_scale_factor, scaling_vector
         | 
| 232 | 
            +
                        ).view_as(x)
         | 
| 233 | 
            +
                    )
         | 
| 234 | 
            +
                return outputs
         | 
| 235 | 
            +
             | 
| 236 | 
            +
             | 
| 237 | 
            +
            class NestedTensorBlock(Block):
         | 
| 238 | 
            +
                def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
         | 
| 239 | 
            +
                    """
         | 
| 240 | 
            +
                    x_list contains a list of tensors to nest together and run
         | 
| 241 | 
            +
                    """
         | 
| 242 | 
            +
                    assert isinstance(self.attn, MemEffAttention)
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    if self.training and self.sample_drop_ratio > 0.0:
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                        def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
         | 
| 247 | 
            +
                            return self.attn(self.norm1(x), attn_bias=attn_bias)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                        def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
         | 
| 250 | 
            +
                            return self.mlp(self.norm2(x))
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                        x_list = drop_add_residual_stochastic_depth_list(
         | 
| 253 | 
            +
                            x_list,
         | 
| 254 | 
            +
                            residual_func=attn_residual_func,
         | 
| 255 | 
            +
                            sample_drop_ratio=self.sample_drop_ratio,
         | 
| 256 | 
            +
                            scaling_vector=self.ls1.gamma
         | 
| 257 | 
            +
                            if isinstance(self.ls1, LayerScale)
         | 
| 258 | 
            +
                            else None,
         | 
| 259 | 
            +
                        )
         | 
| 260 | 
            +
                        x_list = drop_add_residual_stochastic_depth_list(
         | 
| 261 | 
            +
                            x_list,
         | 
| 262 | 
            +
                            residual_func=ffn_residual_func,
         | 
| 263 | 
            +
                            sample_drop_ratio=self.sample_drop_ratio,
         | 
| 264 | 
            +
                            scaling_vector=self.ls2.gamma
         | 
| 265 | 
            +
                            if isinstance(self.ls1, LayerScale)
         | 
| 266 | 
            +
                            else None,
         | 
| 267 | 
            +
                        )
         | 
| 268 | 
            +
                        return x_list
         | 
| 269 | 
            +
                    else:
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                        def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
         | 
| 272 | 
            +
                            return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                        def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
         | 
| 275 | 
            +
                            return self.ls2(self.mlp(self.norm2(x)))
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                        attn_bias, x = get_attn_bias_and_cat(x_list)
         | 
| 278 | 
            +
                        x = x + attn_residual_func(x, attn_bias=attn_bias)
         | 
| 279 | 
            +
                        x = x + ffn_residual_func(x)
         | 
| 280 | 
            +
                        return attn_bias.split(x)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                def forward(self, x_or_x_list):
         | 
| 283 | 
            +
                    if isinstance(x_or_x_list, Tensor):
         | 
| 284 | 
            +
                        return super().forward(x_or_x_list)
         | 
| 285 | 
            +
                    elif isinstance(x_or_x_list, list):
         | 
| 286 | 
            +
                        if not XFORMERS_AVAILABLE:
         | 
| 287 | 
            +
                            raise AssertionError("xFormers is required for using nested tensors")
         | 
| 288 | 
            +
                        return self.forward_nested(x_or_x_list)
         | 
| 289 | 
            +
                    else:
         | 
| 290 | 
            +
                        raise AssertionError
         | 
    	
        mapanything/models/external/dinov2/layers/dino_head.py
    ADDED
    
    | @@ -0,0 +1,67 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import torch.nn as nn
         | 
| 8 | 
            +
            from torch.nn.init import trunc_normal_
         | 
| 9 | 
            +
            from torch.nn.utils import weight_norm
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
class DINOHead(nn.Module):
    """DINO projection head: an MLP down to a bottleneck, L2 normalization,
    then a weight-normalized linear layer producing the prototype logits.

    Args:
        in_dim: Dimension of the incoming features.
        out_dim: Number of output prototypes (logits).
        use_bn: Insert BatchNorm1d after each hidden Linear in the MLP.
        nlayers: Number of Linear layers in the MLP (clamped to >= 1).
        hidden_dim: Width of the MLP hidden layers.
        bottleneck_dim: Width of the normalized bottleneck.
        mlp_bias: Whether the MLP Linear layers carry a bias.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(
            nlayers,
            in_dim,
            bottleneck_dim,
            hidden_dim=hidden_dim,
            use_bn=use_bn,
            bias=mlp_bias,
        )
        # NOTE: _init_weights is applied BEFORE last_layer is created, so the
        # weight-normalized final layer keeps its own initialization and is
        # not trunc_normal-reinitialized.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        # Fix the weight-norm magnitude at 1 (direction remains learnable).
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal init for every Linear in the MLP; zero biases.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # Larger eps in half precision to avoid division underflow.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def _build_mlp(
         | 
| 52 | 
            +
                nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
         | 
| 53 | 
            +
            ):
         | 
| 54 | 
            +
                if nlayers == 1:
         | 
| 55 | 
            +
                    return nn.Linear(in_dim, bottleneck_dim, bias=bias)
         | 
| 56 | 
            +
                else:
         | 
| 57 | 
            +
                    layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
         | 
| 58 | 
            +
                    if use_bn:
         | 
| 59 | 
            +
                        layers.append(nn.BatchNorm1d(hidden_dim))
         | 
| 60 | 
            +
                    layers.append(nn.GELU())
         | 
| 61 | 
            +
                    for _ in range(nlayers - 2):
         | 
| 62 | 
            +
                        layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
         | 
| 63 | 
            +
                        if use_bn:
         | 
| 64 | 
            +
                            layers.append(nn.BatchNorm1d(hidden_dim))
         | 
| 65 | 
            +
                        layers.append(nn.GELU())
         | 
| 66 | 
            +
                    layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
         | 
| 67 | 
            +
                    return nn.Sequential(*layers)
         | 
    	
        mapanything/models/external/dinov2/layers/drop_path.py
    ADDED
    
    | @@ -0,0 +1,36 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            from torch import nn
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def drop_path(x, drop_prob: float = 0.0, training: bool = False):
         | 
| 15 | 
            +
                if drop_prob == 0.0 or not training:
         | 
| 16 | 
            +
                    return x
         | 
| 17 | 
            +
                keep_prob = 1 - drop_prob
         | 
| 18 | 
            +
                shape = (x.shape[0],) + (1,) * (
         | 
| 19 | 
            +
                    x.ndim - 1
         | 
| 20 | 
            +
                )  # work with diff dim tensors, not just 2D ConvNets
         | 
| 21 | 
            +
                random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
         | 
| 22 | 
            +
                if keep_prob > 0.0:
         | 
| 23 | 
            +
                    random_tensor.div_(keep_prob)
         | 
| 24 | 
            +
                output = x * random_tensor
         | 
| 25 | 
            +
                return output
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            class DropPath(nn.Module):
         | 
| 29 | 
            +
                """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                def __init__(self, drop_prob=None):
         | 
| 32 | 
            +
                    super(DropPath, self).__init__()
         | 
| 33 | 
            +
                    self.drop_prob = drop_prob
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                def forward(self, x):
         | 
| 36 | 
            +
                    return drop_path(x, self.drop_prob, self.training)
         | 
    	
        mapanything/models/external/dinov2/layers/layer_scale.py
    ADDED
    
    | @@ -0,0 +1,26 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from typing import Union
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from torch import nn, Tensor
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class LayerScale(nn.Module):
         | 
| 15 | 
            +
                def __init__(
         | 
| 16 | 
            +
                    self,
         | 
| 17 | 
            +
                    dim: int,
         | 
| 18 | 
            +
                    init_values: Union[float, Tensor] = 1e-5,
         | 
| 19 | 
            +
                    inplace: bool = False,
         | 
| 20 | 
            +
                ) -> None:
         | 
| 21 | 
            +
                    super().__init__()
         | 
| 22 | 
            +
                    self.inplace = inplace
         | 
| 23 | 
            +
                    self.gamma = nn.Parameter(init_values * torch.ones(dim))
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def forward(self, x: Tensor) -> Tensor:
         | 
| 26 | 
            +
                    return x.mul_(self.gamma) if self.inplace else x * self.gamma
         | 
    	
        mapanything/models/external/dinov2/layers/mlp.py
    ADDED
    
    | @@ -0,0 +1,40 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            from typing import Callable, Optional
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from torch import nn, Tensor
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            class Mlp(nn.Module):
         | 
| 17 | 
            +
                def __init__(
         | 
| 18 | 
            +
                    self,
         | 
| 19 | 
            +
                    in_features: int,
         | 
| 20 | 
            +
                    hidden_features: Optional[int] = None,
         | 
| 21 | 
            +
                    out_features: Optional[int] = None,
         | 
| 22 | 
            +
                    act_layer: Callable[..., nn.Module] = nn.GELU,
         | 
| 23 | 
            +
                    drop: float = 0.0,
         | 
| 24 | 
            +
                    bias: bool = True,
         | 
| 25 | 
            +
                ) -> None:
         | 
| 26 | 
            +
                    super().__init__()
         | 
| 27 | 
            +
                    out_features = out_features or in_features
         | 
| 28 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 29 | 
            +
                    self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
         | 
| 30 | 
            +
                    self.act = act_layer()
         | 
| 31 | 
            +
                    self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
         | 
| 32 | 
            +
                    self.drop = nn.Dropout(drop)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def forward(self, x: Tensor) -> Tensor:
         | 
| 35 | 
            +
                    x = self.fc1(x)
         | 
| 36 | 
            +
                    x = self.act(x)
         | 
| 37 | 
            +
                    x = self.drop(x)
         | 
| 38 | 
            +
                    x = self.fc2(x)
         | 
| 39 | 
            +
                    x = self.drop(x)
         | 
| 40 | 
            +
                    return x
         | 
    	
        mapanything/models/external/dinov2/layers/patch_embed.py
    ADDED
    
    | @@ -0,0 +1,100 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from typing import Callable, Optional, Tuple, Union
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch.nn as nn
         | 
| 13 | 
            +
            from torch import Tensor
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def make_2tuple(x):
         | 
| 17 | 
            +
                if isinstance(x, tuple):
         | 
| 18 | 
            +
                    assert len(x) == 2
         | 
| 19 | 
            +
                    return x
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                assert isinstance(x, int)
         | 
| 22 | 
            +
                return (x, x)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class PatchEmbed(nn.Module):
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                2D image to patch embedding: (B,C,H,W) -> (B,N,D)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                Args:
         | 
| 30 | 
            +
                    img_size: Image size.
         | 
| 31 | 
            +
                    patch_size: Patch token size.
         | 
| 32 | 
            +
                    in_chans: Number of input image channels.
         | 
| 33 | 
            +
                    embed_dim: Number of linear projection output channels.
         | 
| 34 | 
            +
                    norm_layer: Normalization layer.
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def __init__(
         | 
| 38 | 
            +
                    self,
         | 
| 39 | 
            +
                    img_size: Union[int, Tuple[int, int]] = 224,
         | 
| 40 | 
            +
                    patch_size: Union[int, Tuple[int, int]] = 16,
         | 
| 41 | 
            +
                    in_chans: int = 3,
         | 
| 42 | 
            +
                    embed_dim: int = 768,
         | 
| 43 | 
            +
                    norm_layer: Optional[Callable] = None,
         | 
| 44 | 
            +
                    flatten_embedding: bool = True,
         | 
| 45 | 
            +
                ) -> None:
         | 
| 46 | 
            +
                    super().__init__()
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    image_HW = make_2tuple(img_size)
         | 
| 49 | 
            +
                    patch_HW = make_2tuple(patch_size)
         | 
| 50 | 
            +
                    patch_grid_size = (
         | 
| 51 | 
            +
                        image_HW[0] // patch_HW[0],
         | 
| 52 | 
            +
                        image_HW[1] // patch_HW[1],
         | 
| 53 | 
            +
                    )
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    self.img_size = image_HW
         | 
| 56 | 
            +
                    self.patch_size = patch_HW
         | 
| 57 | 
            +
                    self.patches_resolution = patch_grid_size
         | 
| 58 | 
            +
                    self.num_patches = patch_grid_size[0] * patch_grid_size[1]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    self.in_chans = in_chans
         | 
| 61 | 
            +
                    self.embed_dim = embed_dim
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    self.flatten_embedding = flatten_embedding
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    self.proj = nn.Conv2d(
         | 
| 66 | 
            +
                        in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
         | 
| 67 | 
            +
                    )
         | 
| 68 | 
            +
                    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def forward(self, x: Tensor) -> Tensor:
         | 
| 71 | 
            +
                    _, _, H, W = x.shape
         | 
| 72 | 
            +
                    patch_H, patch_W = self.patch_size
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    assert H % patch_H == 0, (
         | 
| 75 | 
            +
                        f"Input image height {H} is not a multiple of patch height {patch_H}"
         | 
| 76 | 
            +
                    )
         | 
| 77 | 
            +
                    assert W % patch_W == 0, (
         | 
| 78 | 
            +
                        f"Input image width {W} is not a multiple of patch width: {patch_W}"
         | 
| 79 | 
            +
                    )
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    x = self.proj(x)  # B C H W
         | 
| 82 | 
            +
                    H, W = x.size(2), x.size(3)
         | 
| 83 | 
            +
                    x = x.flatten(2).transpose(1, 2)  # B HW C
         | 
| 84 | 
            +
                    x = self.norm(x)
         | 
| 85 | 
            +
                    if not self.flatten_embedding:
         | 
| 86 | 
            +
                        x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
         | 
| 87 | 
            +
                    return x
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def flops(self) -> float:
         | 
| 90 | 
            +
                    Ho, Wo = self.patches_resolution
         | 
| 91 | 
            +
                    flops = (
         | 
| 92 | 
            +
                        Ho
         | 
| 93 | 
            +
                        * Wo
         | 
| 94 | 
            +
                        * self.embed_dim
         | 
| 95 | 
            +
                        * self.in_chans
         | 
| 96 | 
            +
                        * (self.patch_size[0] * self.patch_size[1])
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
                    if self.norm is not None:
         | 
| 99 | 
            +
                        flops += Ho * Wo * self.embed_dim
         | 
| 100 | 
            +
                    return flops
         | 
    	
        mapanything/models/external/dinov2/layers/swiglu_ffn.py
    ADDED
    
    | @@ -0,0 +1,71 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            from typing import Callable, Optional
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch.nn.functional as F
         | 
| 10 | 
            +
            from torch import nn, Tensor
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class SwiGLUFFN(nn.Module):
         | 
| 14 | 
            +
                def __init__(
         | 
| 15 | 
            +
                    self,
         | 
| 16 | 
            +
                    in_features: int,
         | 
| 17 | 
            +
                    hidden_features: Optional[int] = None,
         | 
| 18 | 
            +
                    out_features: Optional[int] = None,
         | 
| 19 | 
            +
                    act_layer: Callable[..., nn.Module] = None,
         | 
| 20 | 
            +
                    drop: float = 0.0,
         | 
| 21 | 
            +
                    bias: bool = True,
         | 
| 22 | 
            +
                ) -> None:
         | 
| 23 | 
            +
                    super().__init__()
         | 
| 24 | 
            +
                    out_features = out_features or in_features
         | 
| 25 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 26 | 
            +
                    self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
         | 
| 27 | 
            +
                    self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def forward(self, x: Tensor) -> Tensor:
         | 
| 30 | 
            +
                    x12 = self.w12(x)
         | 
| 31 | 
            +
                    x1, x2 = x12.chunk(2, dim=-1)
         | 
| 32 | 
            +
                    hidden = F.silu(x1) * x2
         | 
| 33 | 
            +
                    return self.w3(hidden)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
# Prefer the fused xFormers SwiGLU kernel when the package is importable and
# not explicitly disabled via the XFORMERS_DISABLED environment variable;
# otherwise alias SwiGLU to the pure-PyTorch SwiGLUFFN defined above.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (SwiGLU)")
    else:
        # warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    # warnings.warn("xFormers is not available (SwiGLU)")
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            class SwiGLUFFNFused(SwiGLU):
         | 
| 54 | 
            +
                def __init__(
         | 
| 55 | 
            +
                    self,
         | 
| 56 | 
            +
                    in_features: int,
         | 
| 57 | 
            +
                    hidden_features: Optional[int] = None,
         | 
| 58 | 
            +
                    out_features: Optional[int] = None,
         | 
| 59 | 
            +
                    act_layer: Callable[..., nn.Module] = None,
         | 
| 60 | 
            +
                    drop: float = 0.0,
         | 
| 61 | 
            +
                    bias: bool = True,
         | 
| 62 | 
            +
                ) -> None:
         | 
| 63 | 
            +
                    out_features = out_features or in_features
         | 
| 64 | 
            +
                    hidden_features = hidden_features or in_features
         | 
| 65 | 
            +
                    hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
         | 
| 66 | 
            +
                    super().__init__(
         | 
| 67 | 
            +
                        in_features=in_features,
         | 
| 68 | 
            +
                        hidden_features=hidden_features,
         | 
| 69 | 
            +
                        out_features=out_features,
         | 
| 70 | 
            +
                        bias=bias,
         | 
| 71 | 
            +
                    )
         | 
    	
        mapanything/models/external/dinov2/models/__init__.py
    ADDED
    
    | @@ -0,0 +1,44 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import logging
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import mapanything.models.external.dinov2.models.vision_transformer as vits
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            logger = logging.getLogger("dinov2")
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def build_model(args, only_teacher=False, img_size=224):
         | 
| 14 | 
            +
                args.arch = args.arch.removesuffix("_memeff")
         | 
| 15 | 
            +
                if "vit" in args.arch:
         | 
| 16 | 
            +
                    vit_kwargs = dict(
         | 
| 17 | 
            +
                        img_size=img_size,
         | 
| 18 | 
            +
                        patch_size=args.patch_size,
         | 
| 19 | 
            +
                        init_values=args.layerscale,
         | 
| 20 | 
            +
                        ffn_layer=args.ffn_layer,
         | 
| 21 | 
            +
                        block_chunks=args.block_chunks,
         | 
| 22 | 
            +
                        qkv_bias=args.qkv_bias,
         | 
| 23 | 
            +
                        proj_bias=args.proj_bias,
         | 
| 24 | 
            +
                        ffn_bias=args.ffn_bias,
         | 
| 25 | 
            +
                        num_register_tokens=args.num_register_tokens,
         | 
| 26 | 
            +
                        interpolate_offset=args.interpolate_offset,
         | 
| 27 | 
            +
                        interpolate_antialias=args.interpolate_antialias,
         | 
| 28 | 
            +
                    )
         | 
| 29 | 
            +
                    teacher = vits.__dict__[args.arch](**vit_kwargs)
         | 
| 30 | 
            +
                    if only_teacher:
         | 
| 31 | 
            +
                        return teacher, teacher.embed_dim
         | 
| 32 | 
            +
                    student = vits.__dict__[args.arch](
         | 
| 33 | 
            +
                        **vit_kwargs,
         | 
| 34 | 
            +
                        drop_path_rate=args.drop_path_rate,
         | 
| 35 | 
            +
                        drop_path_uniform=args.drop_path_uniform,
         | 
| 36 | 
            +
                    )
         | 
| 37 | 
            +
                    embed_dim = student.embed_dim
         | 
| 38 | 
            +
                return student, teacher, embed_dim
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def build_model_from_cfg(cfg, only_teacher=False):
         | 
| 42 | 
            +
                return build_model(
         | 
| 43 | 
            +
                    cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size
         | 
| 44 | 
            +
                )
         | 
    	
        mapanything/models/external/dinov2/models/vision_transformer.py
    ADDED
    
    | @@ -0,0 +1,448 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # References:
         | 
| 7 | 
            +
            #   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
         | 
| 8 | 
            +
            #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import math
         | 
| 11 | 
            +
            from functools import partial
         | 
| 12 | 
            +
            from typing import Callable, Sequence, Tuple, Union
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            import torch.nn as nn
         | 
| 16 | 
            +
            from torch.nn.init import trunc_normal_
         | 
| 17 | 
            +
            from torch.utils.checkpoint import checkpoint
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from mapanything.models.external.dinov2.layers import (
         | 
| 20 | 
            +
                MemEffAttention,
         | 
| 21 | 
            +
                Mlp,
         | 
| 22 | 
            +
                NestedTensorBlock as Block,
         | 
| 23 | 
            +
                PatchEmbed,
         | 
| 24 | 
            +
                SwiGLUFFNFused,
         | 
| 25 | 
            +
            )
         | 
| 26 | 
            +
            from mapanything.models.external.pi3.layers.attention import FlashAttention
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            # logger = logging.getLogger("dinov2")
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def named_apply(
         | 
| 32 | 
            +
                fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
         | 
| 33 | 
            +
            ) -> nn.Module:
         | 
| 34 | 
            +
                if not depth_first and include_root:
         | 
| 35 | 
            +
                    fn(module=module, name=name)
         | 
| 36 | 
            +
                for child_name, child_module in module.named_children():
         | 
| 37 | 
            +
                    child_name = ".".join((name, child_name)) if name else child_name
         | 
| 38 | 
            +
                    named_apply(
         | 
| 39 | 
            +
                        fn=fn,
         | 
| 40 | 
            +
                        module=child_module,
         | 
| 41 | 
            +
                        name=child_name,
         | 
| 42 | 
            +
                        depth_first=depth_first,
         | 
| 43 | 
            +
                        include_root=True,
         | 
| 44 | 
            +
                    )
         | 
| 45 | 
            +
                if depth_first and include_root:
         | 
| 46 | 
            +
                    fn(module=module, name=name)
         | 
| 47 | 
            +
                return module
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            class BlockChunk(nn.ModuleList):
         | 
| 51 | 
            +
                def forward(self, x):
         | 
| 52 | 
            +
                    for b in self:
         | 
| 53 | 
            +
                        x = b(x)
         | 
| 54 | 
            +
                    return x
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            class DinoVisionTransformer(nn.Module):
         | 
| 58 | 
            +
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values (None or 0 disables layerscale)
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        # One CLS token is prepended to the patch sequence (registers counted separately).
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embedding covers CLS + all patch tokens; registers get none.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )
        assert num_register_tokens >= 0
        # Optional "register" tokens (DINOv2-with-registers); None when disabled.
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        # Per-block drop-path rates: either uniform or linearly increasing.
        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [
                x.item() for x in torch.linspace(0, drop_path_rate, depth)
            ]  # stochastic depth decay rule

        # Resolve the FFN implementation from its string name.
        if ffn_layer == "mlp":
            # logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            # logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            # logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        # NOTE: unlike upstream DINOv2, attention is forced to the Pi3
        # FlashAttention class (see module imports).
        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                attn_class=FlashAttention,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append(
                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
                )
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        # Backbone only: classification head is identity.
        self.head = nn.Identity()

        # Learned token substituted for masked patches in prepare_tokens_with_masks.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()
         | 
| 200 | 
            +
             | 
| 201 | 
            +
    def init_weights(self):
        """Initialize learned tokens/positional embeddings, then per-module weights."""
        # Truncated normal for positional embeddings (timm ViT convention).
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        # Apply the timm-style ViT init to every submodule (helper defined elsewhere in this file).
        named_apply(init_weights_vit_timm, self)
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                def interpolate_pos_encoding(self, x, w, h):
         | 
| 209 | 
            +
                    previous_dtype = x.dtype
         | 
| 210 | 
            +
                    npatch = x.shape[1] - 1
         | 
| 211 | 
            +
                    N = self.pos_embed.shape[1] - 1
         | 
| 212 | 
            +
                    if npatch == N and w == h:
         | 
| 213 | 
            +
                        return self.pos_embed
         | 
| 214 | 
            +
                    pos_embed = self.pos_embed.float()
         | 
| 215 | 
            +
                    class_pos_embed = pos_embed[:, 0]
         | 
| 216 | 
            +
                    patch_pos_embed = pos_embed[:, 1:]
         | 
| 217 | 
            +
                    dim = x.shape[-1]
         | 
| 218 | 
            +
                    w0 = w // self.patch_size
         | 
| 219 | 
            +
                    h0 = h // self.patch_size
         | 
| 220 | 
            +
                    M = int(math.sqrt(N))  # Recover the number of patches in each dimension
         | 
| 221 | 
            +
                    assert N == M * M
         | 
| 222 | 
            +
                    kwargs = {}
         | 
| 223 | 
            +
                    if self.interpolate_offset:
         | 
| 224 | 
            +
                        # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
         | 
| 225 | 
            +
                        # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
         | 
| 226 | 
            +
                        sx = float(w0 + self.interpolate_offset) / M
         | 
| 227 | 
            +
                        sy = float(h0 + self.interpolate_offset) / M
         | 
| 228 | 
            +
                        kwargs["scale_factor"] = (sx, sy)
         | 
| 229 | 
            +
                    else:
         | 
| 230 | 
            +
                        # Simply specify an output size instead of a scale factor
         | 
| 231 | 
            +
                        kwargs["size"] = (w0, h0)
         | 
| 232 | 
            +
                    patch_pos_embed = nn.functional.interpolate(
         | 
| 233 | 
            +
                        patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
         | 
| 234 | 
            +
                        mode="bicubic",
         | 
| 235 | 
            +
                        antialias=self.interpolate_antialias,
         | 
| 236 | 
            +
                        **kwargs,
         | 
| 237 | 
            +
                    )
         | 
| 238 | 
            +
                    assert (w0, h0) == patch_pos_embed.shape[-2:]
         | 
| 239 | 
            +
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         | 
| 240 | 
            +
                    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
         | 
| 241 | 
            +
                        previous_dtype
         | 
| 242 | 
            +
                    )
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                def prepare_tokens_with_masks(self, x, masks=None):
         | 
| 245 | 
            +
                    B, nc, w, h = x.shape
         | 
| 246 | 
            +
                    x = self.patch_embed(x)
         | 
| 247 | 
            +
                    if masks is not None:
         | 
| 248 | 
            +
                        x = torch.where(
         | 
| 249 | 
            +
                            masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
         | 
| 250 | 
            +
                        )
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         | 
| 253 | 
            +
                    x = x + self.interpolate_pos_encoding(x, w, h)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    if self.register_tokens is not None:
         | 
| 256 | 
            +
                        x = torch.cat(
         | 
| 257 | 
            +
                            (
         | 
| 258 | 
            +
                                x[:, :1],
         | 
| 259 | 
            +
                                self.register_tokens.expand(x.shape[0], -1, -1),
         | 
| 260 | 
            +
                                x[:, 1:],
         | 
| 261 | 
            +
                            ),
         | 
| 262 | 
            +
                            dim=1,
         | 
| 263 | 
            +
                        )
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    return x
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def forward_features_list(self, x_list, masks_list):
         | 
| 268 | 
            +
                    x = [
         | 
| 269 | 
            +
                        self.prepare_tokens_with_masks(x, masks)
         | 
| 270 | 
            +
                        for x, masks in zip(x_list, masks_list)
         | 
| 271 | 
            +
                    ]
         | 
| 272 | 
            +
                    for blk in self.blocks:
         | 
| 273 | 
            +
                        if self.training:
         | 
| 274 | 
            +
                            x = checkpoint(blk, x, use_reentrant=False)
         | 
| 275 | 
            +
                        else:
         | 
| 276 | 
            +
                            x = blk(x)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    all_x = x
         | 
| 279 | 
            +
                    output = []
         | 
| 280 | 
            +
                    for x, masks in zip(all_x, masks_list):
         | 
| 281 | 
            +
                        x_norm = self.norm(x)
         | 
| 282 | 
            +
                        output.append(
         | 
| 283 | 
            +
                            {
         | 
| 284 | 
            +
                                "x_norm_clstoken": x_norm[:, 0],
         | 
| 285 | 
            +
                                "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
         | 
| 286 | 
            +
                                "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
         | 
| 287 | 
            +
                                "x_prenorm": x,
         | 
| 288 | 
            +
                                "masks": masks,
         | 
| 289 | 
            +
                            }
         | 
| 290 | 
            +
                        )
         | 
| 291 | 
            +
                    return output
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                def forward_features(self, x, masks=None):
         | 
| 294 | 
            +
                    if isinstance(x, list):
         | 
| 295 | 
            +
                        return self.forward_features_list(x, masks)
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    x = self.prepare_tokens_with_masks(x, masks)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    for blk in self.blocks:
         | 
| 300 | 
            +
                        if self.training:
         | 
| 301 | 
            +
                            x = checkpoint(blk, x, use_reentrant=False)
         | 
| 302 | 
            +
                        else:
         | 
| 303 | 
            +
                            x = blk(x)
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                    x_norm = self.norm(x)
         | 
| 306 | 
            +
                    return {
         | 
| 307 | 
            +
                        "x_norm_clstoken": x_norm[:, 0],
         | 
| 308 | 
            +
                        "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
         | 
| 309 | 
            +
                        "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
         | 
| 310 | 
            +
                        "x_prenorm": x,
         | 
| 311 | 
            +
                        "masks": masks,
         | 
| 312 | 
            +
                    }
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                def _get_intermediate_layers_not_chunked(self, x, n=1):
         | 
| 315 | 
            +
                    x = self.prepare_tokens_with_masks(x)
         | 
| 316 | 
            +
                    # If n is an int, take the n last blocks. If it's a list, take them
         | 
| 317 | 
            +
                    output, total_block_len = [], len(self.blocks)
         | 
| 318 | 
            +
                    blocks_to_take = (
         | 
| 319 | 
            +
                        range(total_block_len - n, total_block_len) if isinstance(n, int) else n
         | 
| 320 | 
            +
                    )
         | 
| 321 | 
            +
                    for i, blk in enumerate(self.blocks):
         | 
| 322 | 
            +
                        x = blk(x)
         | 
| 323 | 
            +
                        if i in blocks_to_take:
         | 
| 324 | 
            +
                            output.append(x)
         | 
| 325 | 
            +
                    assert len(output) == len(blocks_to_take), (
         | 
| 326 | 
            +
                        f"only {len(output)} / {len(blocks_to_take)} blocks found"
         | 
| 327 | 
            +
                    )
         | 
| 328 | 
            +
                    return output
         | 
| 329 | 
            +
             | 
| 330 | 
            +
    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect outputs of selected transformer blocks when ``self.blocks``
        is stored as a list of chunks (chunked-blocks mode).

        Args:
            x: Input image batch.
            n: Number of last blocks to take (int) or an explicit iterable of
                block indices.

        Returns:
            List of token tensors, one per requested block, in index order.
        """
        x = self.prepare_tokens_with_masks(x)
        # Total block count is the length of the LAST chunk -- presumably each
        # chunk is padded with nn.Identity placeholders for earlier blocks so
        # global indices line up; confirm against the constructor.
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for block_chunk in self.blocks:
            # Slice from the running global index so the Identity placeholders
            # at the start of each chunk are skipped.
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), (
            f"only {len(output)} / {len(blocks_to_take)} blocks found"
        )
        return output
         | 
| 347 | 
            +
             | 
| 348 | 
            +
    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch-token features from selected transformer blocks.

        Args:
            x: Input image batch.
            n: Number of last blocks to take (int) or explicit block indices.
            reshape: If True, reshape each layer's patch tokens into a
                (B, C, grid, grid) feature map.
            return_class_token: If True, return (patch_tokens, class_token)
                pairs instead of patch tokens only.
            norm: If True, apply the final LayerNorm to each collected output.

        Returns:
            Tuple of per-layer tensors, or of (patch_tokens, class_token)
            pairs when ``return_class_token`` is set.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        # Token layout per row: index 0 is the class token, followed by
        # num_register_tokens register tokens, then the patch tokens.
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            # NOTE(review): names suggest (B, _, w, h); spatial dims are
            # assumed divisible by self.patch_size -- confirm at call sites.
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                def forward(self, *args, is_training=False, **kwargs):
         | 
| 377 | 
            +
                    ret = self.forward_features(*args, **kwargs)
         | 
| 378 | 
            +
                    if is_training:
         | 
| 379 | 
            +
                        return ret
         | 
| 380 | 
            +
                    else:
         | 
| 381 | 
            +
                        return self.head(ret["x_norm_clstoken"])
         | 
| 382 | 
            +
             | 
| 383 | 
            +
             | 
| 384 | 
            +
            def init_weights_vit_timm(module: nn.Module, name: str = ""):
         | 
| 385 | 
            +
                """ViT weight initialization, original timm impl (for reproducibility)"""
         | 
| 386 | 
            +
                if isinstance(module, nn.Linear):
         | 
| 387 | 
            +
                    trunc_normal_(module.weight, std=0.02)
         | 
| 388 | 
            +
                    if module.bias is not None:
         | 
| 389 | 
            +
                        nn.init.zeros_(module.bias)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
             | 
| 392 | 
            +
            def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
         | 
| 393 | 
            +
                model = DinoVisionTransformer(
         | 
| 394 | 
            +
                    patch_size=patch_size,
         | 
| 395 | 
            +
                    embed_dim=384,
         | 
| 396 | 
            +
                    depth=12,
         | 
| 397 | 
            +
                    num_heads=6,
         | 
| 398 | 
            +
                    mlp_ratio=4,
         | 
| 399 | 
            +
                    block_fn=partial(Block, attn_class=MemEffAttention),
         | 
| 400 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 401 | 
            +
                    **kwargs,
         | 
| 402 | 
            +
                )
         | 
| 403 | 
            +
                return model
         | 
| 404 | 
            +
             | 
| 405 | 
            +
             | 
| 406 | 
            +
            def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
         | 
| 407 | 
            +
                model = DinoVisionTransformer(
         | 
| 408 | 
            +
                    patch_size=patch_size,
         | 
| 409 | 
            +
                    embed_dim=768,
         | 
| 410 | 
            +
                    depth=12,
         | 
| 411 | 
            +
                    num_heads=12,
         | 
| 412 | 
            +
                    mlp_ratio=4,
         | 
| 413 | 
            +
                    block_fn=partial(Block, attn_class=MemEffAttention),
         | 
| 414 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 415 | 
            +
                    **kwargs,
         | 
| 416 | 
            +
                )
         | 
| 417 | 
            +
                return model
         | 
| 418 | 
            +
             | 
| 419 | 
            +
             | 
| 420 | 
            +
            def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
         | 
| 421 | 
            +
                model = DinoVisionTransformer(
         | 
| 422 | 
            +
                    patch_size=patch_size,
         | 
| 423 | 
            +
                    embed_dim=1024,
         | 
| 424 | 
            +
                    depth=24,
         | 
| 425 | 
            +
                    num_heads=16,
         | 
| 426 | 
            +
                    mlp_ratio=4,
         | 
| 427 | 
            +
                    block_fn=partial(Block, attn_class=MemEffAttention),
         | 
| 428 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 429 | 
            +
                    **kwargs,
         | 
| 430 | 
            +
                )
         | 
| 431 | 
            +
                return model
         | 
| 432 | 
            +
             | 
| 433 | 
            +
             | 
| 434 | 
            +
            def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
         | 
| 435 | 
            +
                """
         | 
| 436 | 
            +
                Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
         | 
| 437 | 
            +
                """
         | 
| 438 | 
            +
                model = DinoVisionTransformer(
         | 
| 439 | 
            +
                    patch_size=patch_size,
         | 
| 440 | 
            +
                    embed_dim=1536,
         | 
| 441 | 
            +
                    depth=40,
         | 
| 442 | 
            +
                    num_heads=24,
         | 
| 443 | 
            +
                    mlp_ratio=4,
         | 
| 444 | 
            +
                    block_fn=partial(Block, attn_class=MemEffAttention),
         | 
| 445 | 
            +
                    num_register_tokens=num_register_tokens,
         | 
| 446 | 
            +
                    **kwargs,
         | 
| 447 | 
            +
                )
         | 
| 448 | 
            +
                return model
         | 
    	
        mapanything/models/external/dinov2/utils/__init__.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
    	
        mapanything/models/external/dinov2/utils/cluster.py
    ADDED
    
    | @@ -0,0 +1,102 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            from enum import Enum
         | 
| 8 | 
            +
            from pathlib import Path
         | 
| 9 | 
            +
            from typing import Any, Dict, Optional
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
class ClusterType(Enum):
    """Compute-cluster flavors recognized by the helpers in this module."""

    # String values double as human-readable identifiers.
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def _guess_cluster_type() -> ClusterType:
         | 
| 19 | 
            +
                uname = os.uname()
         | 
| 20 | 
            +
                if uname.sysname == "Linux":
         | 
| 21 | 
            +
                    if uname.release.endswith("-aws"):
         | 
| 22 | 
            +
                        # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
         | 
| 23 | 
            +
                        return ClusterType.AWS
         | 
| 24 | 
            +
                    elif uname.nodename.startswith("rsc"):
         | 
| 25 | 
            +
                        # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
         | 
| 26 | 
            +
                        return ClusterType.RSC
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                return ClusterType.FAIR
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def get_cluster_type(
         | 
| 32 | 
            +
                cluster_type: Optional[ClusterType] = None,
         | 
| 33 | 
            +
            ) -> Optional[ClusterType]:
         | 
| 34 | 
            +
                if cluster_type is None:
         | 
| 35 | 
            +
                    return _guess_cluster_type()
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                return cluster_type
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
         | 
| 41 | 
            +
                cluster_type = get_cluster_type(cluster_type)
         | 
| 42 | 
            +
                if cluster_type is None:
         | 
| 43 | 
            +
                    return None
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                CHECKPOINT_DIRNAMES = {
         | 
| 46 | 
            +
                    ClusterType.AWS: "checkpoints",
         | 
| 47 | 
            +
                    ClusterType.FAIR: "checkpoint",
         | 
| 48 | 
            +
                    ClusterType.RSC: "checkpoint/dino",
         | 
| 49 | 
            +
                }
         | 
| 50 | 
            +
                return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def get_user_checkpoint_path(
         | 
| 54 | 
            +
                cluster_type: Optional[ClusterType] = None,
         | 
| 55 | 
            +
            ) -> Optional[Path]:
         | 
| 56 | 
            +
                checkpoint_path = get_checkpoint_path(cluster_type)
         | 
| 57 | 
            +
                if checkpoint_path is None:
         | 
| 58 | 
            +
                    return None
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                username = os.environ.get("USER")
         | 
| 61 | 
            +
                assert username is not None
         | 
| 62 | 
            +
                return checkpoint_path / username
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
         | 
| 66 | 
            +
                cluster_type = get_cluster_type(cluster_type)
         | 
| 67 | 
            +
                if cluster_type is None:
         | 
| 68 | 
            +
                    return None
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                SLURM_PARTITIONS = {
         | 
| 71 | 
            +
                    ClusterType.AWS: "learnlab",
         | 
| 72 | 
            +
                    ClusterType.FAIR: "learnlab",
         | 
| 73 | 
            +
                    ClusterType.RSC: "learn",
         | 
| 74 | 
            +
                }
         | 
| 75 | 
            +
                return SLURM_PARTITIONS[cluster_type]
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            def get_slurm_executor_parameters(
         | 
| 79 | 
            +
                nodes: int,
         | 
| 80 | 
            +
                num_gpus_per_node: int,
         | 
| 81 | 
            +
                cluster_type: Optional[ClusterType] = None,
         | 
| 82 | 
            +
                **kwargs,
         | 
| 83 | 
            +
            ) -> Dict[str, Any]:
         | 
| 84 | 
            +
                # create default parameters
         | 
| 85 | 
            +
                params = {
         | 
| 86 | 
            +
                    "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
         | 
| 87 | 
            +
                    "gpus_per_node": num_gpus_per_node,
         | 
| 88 | 
            +
                    "tasks_per_node": num_gpus_per_node,  # one task per GPU
         | 
| 89 | 
            +
                    "cpus_per_task": 10,
         | 
| 90 | 
            +
                    "nodes": nodes,
         | 
| 91 | 
            +
                    "slurm_partition": get_slurm_partition(cluster_type),
         | 
| 92 | 
            +
                }
         | 
| 93 | 
            +
                # apply cluster-specific adjustments
         | 
| 94 | 
            +
                cluster_type = get_cluster_type(cluster_type)
         | 
| 95 | 
            +
                if cluster_type == ClusterType.AWS:
         | 
| 96 | 
            +
                    params["cpus_per_task"] = 12
         | 
| 97 | 
            +
                    del params["mem_gb"]
         | 
| 98 | 
            +
                elif cluster_type == ClusterType.RSC:
         | 
| 99 | 
            +
                    params["cpus_per_task"] = 12
         | 
| 100 | 
            +
                # set additional parameters / apply overrides
         | 
| 101 | 
            +
                params.update(kwargs)
         | 
| 102 | 
            +
                return params
         | 
    	
        mapanything/models/external/dinov2/utils/config.py
    ADDED
    
    | @@ -0,0 +1,74 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import logging
         | 
| 7 | 
            +
            import math
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import dinov2.distributed as distributed
         | 
| 11 | 
            +
            from dinov2.configs import dinov2_default_config
         | 
| 12 | 
            +
            from dinov2.logging import setup_logging
         | 
| 13 | 
            +
            from dinov2.utils import utils
         | 
| 14 | 
            +
            from omegaconf import OmegaConf
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            logger = logging.getLogger("dinov2")
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def apply_scaling_rules_to_cfg(cfg):  # to fix
         | 
| 20 | 
            +
                if cfg.optim.scaling_rule == "sqrt_wrt_1024":
         | 
| 21 | 
            +
                    base_lr = cfg.optim.base_lr
         | 
| 22 | 
            +
                    cfg.optim.lr = base_lr
         | 
| 23 | 
            +
                    cfg.optim.lr *= math.sqrt(
         | 
| 24 | 
            +
                        cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0
         | 
| 25 | 
            +
                    )
         | 
| 26 | 
            +
                    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
         | 
| 27 | 
            +
                else:
         | 
| 28 | 
            +
                    raise NotImplementedError
         | 
| 29 | 
            +
                return cfg
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            def write_config(cfg, output_dir, name="config.yaml"):
         | 
| 33 | 
            +
                logger.info(OmegaConf.to_yaml(cfg))
         | 
| 34 | 
            +
                saved_cfg_path = os.path.join(output_dir, name)
         | 
| 35 | 
            +
                with open(saved_cfg_path, "w") as f:
         | 
| 36 | 
            +
                    OmegaConf.save(config=cfg, f=f)
         | 
| 37 | 
            +
                return saved_cfg_path
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            def get_cfg_from_args(args):
         | 
| 41 | 
            +
                args.output_dir = os.path.abspath(args.output_dir)
         | 
| 42 | 
            +
                args.opts += [f"train.output_dir={args.output_dir}"]
         | 
| 43 | 
            +
                default_cfg = OmegaConf.create(dinov2_default_config)
         | 
| 44 | 
            +
                cfg = OmegaConf.load(args.config_file)
         | 
| 45 | 
            +
                cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
         | 
| 46 | 
            +
                return cfg
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
def default_setup(args):
    """Enable distributed mode, set up logging, and seed RNGs for a run.

    Side effects: initializes the distributed backend, rebinds the
    module-level ``logger`` to the configured "dinov2" logger, and seeds
    global random state.
    """
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    # Rebind the module-level logger after file/console handlers are attached.
    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    # Seed is offset by rank so each process seeds its RNGs differently.
    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n  {}\n".format(utils.get_sha()))
    logger.info(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
    )
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
def setup(args):
    """
    Create configs and perform basic setups.

    Parses the config from ``args``, creates the output directory, runs the
    default distributed/logging setup, applies LR scaling rules, and writes
    the final config to disk before returning it.
    """
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
         | 
    	
        mapanything/models/external/dinov2/utils/dtype.py
    ADDED
    
    | @@ -0,0 +1,38 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            from typing import Dict, Union
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            TypeSpec = Union[str, np.dtype, torch.dtype]
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
         | 
| 16 | 
            +
                np.dtype("bool"): torch.bool,
         | 
| 17 | 
            +
                np.dtype("uint8"): torch.uint8,
         | 
| 18 | 
            +
                np.dtype("int8"): torch.int8,
         | 
| 19 | 
            +
                np.dtype("int16"): torch.int16,
         | 
| 20 | 
            +
                np.dtype("int32"): torch.int32,
         | 
| 21 | 
            +
                np.dtype("int64"): torch.int64,
         | 
| 22 | 
            +
                np.dtype("float16"): torch.float16,
         | 
| 23 | 
            +
                np.dtype("float32"): torch.float32,
         | 
| 24 | 
            +
                np.dtype("float64"): torch.float64,
         | 
| 25 | 
            +
                np.dtype("complex64"): torch.complex64,
         | 
| 26 | 
            +
                np.dtype("complex128"): torch.complex128,
         | 
| 27 | 
            +
            }
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
         | 
| 31 | 
            +
                if isinstance(dtype, torch.dtype):
         | 
| 32 | 
            +
                    return dtype
         | 
| 33 | 
            +
                if isinstance(dtype, str):
         | 
| 34 | 
            +
                    dtype = np.dtype(dtype)
         | 
| 35 | 
            +
                assert isinstance(dtype, np.dtype), (
         | 
| 36 | 
            +
                    f"Expected an instance of nunpy dtype, got {type(dtype)}"
         | 
| 37 | 
            +
                )
         | 
| 38 | 
            +
                return _NUMPY_TO_TORCH_DTYPE[dtype]
         | 
