Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,824 Bytes
9b33fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
"""3D-MOOD with Swin-B."""
from __future__ import annotations
from vis4d.config import class_config
from vis4d.config.typing import ExperimentConfig
from vis4d.data.io.hdf5 import HDF5Backend
from vis4d.zoo.base import get_default_cfg
from opendet3d.data.datasets.scannet import scannet200_det_map
from opendet3d.zoo.gdino3d.base.callback import get_callback_cfg
from opendet3d.zoo.gdino3d.base.connector import get_data_connector_cfg
from opendet3d.zoo.gdino3d.base.data import get_data_cfg
from opendet3d.zoo.gdino3d.base.dataset.omni3d import get_omni3d_train_cfg
from opendet3d.zoo.gdino3d.base.dataset.open import get_scannet_data_cfg
from opendet3d.zoo.gdino3d.base.loss import get_loss_cfg
from opendet3d.zoo.gdino3d.base.model import (
get_gdino3d_hyperparams_cfg,
get_gdino3d_swin_base_cfg,
)
from opendet3d.zoo.gdino3d.base.optim import get_optim_cfg
from opendet3d.zoo.gdino3d.base.pl import get_pl_cfg
def get_config() -> ExperimentConfig:
    """Build the experiment config for 3D-MOOD with a Swin-B backbone.

    The config is assembled from the shared ``gdino3d`` base builders:
    the Omni3D training dataset config is used for training and the
    ScanNet200 dataset config is used for testing.

    Returns:
        ExperimentConfig: The assembled config, as returned by
            ``value_mode()``.
    """
    # ---------------------------------------------------------------- #
    # General config
    # ---------------------------------------------------------------- #
    config = get_default_cfg(exp_name="gdino3d_swin-b_scannet200")
    config.use_checkpoint = True

    # Shared high-level hyper parameters, also stored on the config.
    hyperparams = get_gdino3d_hyperparams_cfg()
    config.params = hyperparams

    # ---------------------------------------------------------------- #
    # Datasets with augmentations
    # ---------------------------------------------------------------- #
    hdf5_backend = class_config(HDF5Backend)

    # Training split: Omni3D.
    train_cfg = get_omni3d_train_cfg(
        data_root="data/omni3d", data_backend=hdf5_backend
    )

    # Test splits: open datasets (ScanNet200 only).
    test_cfgs = [
        get_scannet_data_cfg(data_backend=hdf5_backend, scannet200=True),
    ]

    config.data = get_data_cfg(
        train_datasets=train_cfg,
        test_datasets=test_cfgs,
        samples_per_gpu=hyperparams.samples_per_gpu,
        workers_per_gpu=hyperparams.workers_per_gpu,
    )

    # ---------------------------------------------------------------- #
    # Model & loss
    # ---------------------------------------------------------------- #
    model_cfg, box_coder = get_gdino3d_swin_base_cfg(
        params=hyperparams,
        pretrained="mm_gdino_swin_base_all",
        chunked_size=20,
        cat_mapping=scannet200_det_map,
        use_checkpoint=config.use_checkpoint,
    )
    config.model = model_cfg

    # The loss reuses the box coder returned alongside the model config.
    config.loss = get_loss_cfg(hyperparams, box_coder, aux_depth_loss=True)

    # ---------------------------------------------------------------- #
    # Optimizers
    # ---------------------------------------------------------------- #
    config.optimizers = get_optim_cfg(hyperparams)

    # ---------------------------------------------------------------- #
    # Data connectors
    # ---------------------------------------------------------------- #
    train_connector, test_connector = get_data_connector_cfg()
    config.train_data_connector = train_connector
    config.test_data_connector = test_connector

    # ---------------------------------------------------------------- #
    # Callbacks
    # ---------------------------------------------------------------- #
    # Open Detect3D evaluator on the ScanNet200 validation set.
    config.callbacks = get_callback_cfg(
        output_dir=config.output_dir,
        open_test_datasets=["ScanNet200_val"],
    )

    # ---------------------------------------------------------------- #
    # PL CLI
    # ---------------------------------------------------------------- #
    config.pl_trainer = get_pl_cfg(config, hyperparams)

    return config.value_mode()
|