Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,506 Bytes
9b33fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
"""Grounding DINO 3D loss."""
from __future__ import annotations
from ml_collections import ConfigDict
from vis4d.config import class_config
from vis4d.engine.connectors import LossConnector
from vis4d.engine.loss_module import LossModule
from opendet3d.op.detect3d.grounding_dino_3d import (
GroundingDINO3DLoss,
)
from opendet3d.op.loss.silog_loss import SILogLoss
from opendet3d.zoo.gdino3d.base.connector import (
CONN_DEPTH_LOSS,
CONN_GDINO3D_LOSS,
)
def get_loss_cfg(
    params: ConfigDict,
    box_coder: ConfigDict,
    aux_depth_loss: bool = False,
) -> ConfigDict:
    """Build the ``LossModule`` config for Grounding DINO 3D training.

    Args:
        params: Hyperparameter config. This function reads
            ``loss_center_weight``, ``loss_depth_weight``,
            ``loss_dim_weight`` and ``loss_rot_weight``. When
            ``aux_depth_loss`` is truthy it also reads ``si_log_weight``.
        box_coder: Box coder config, passed through to
            ``GroundingDINO3DLoss``.
        aux_depth_loss: When truthy, append an auxiliary ``SILogLoss``
            term weighted by ``params.si_log_weight``.

    Returns:
        A ``class_config`` for ``LossModule`` whose ``losses`` list holds
        the 3D box loss term, optionally followed by the depth loss term.
    """

    def _make_term(loss_cfg, key_mapping, **extra):
        """Pair a loss config with a LossConnector, inserting any extras."""
        # Keys are inserted in a fixed order: "loss", then extras
        # (e.g. "weight"), then "connector".
        term = {"loss": loss_cfg}
        term.update(extra)
        term["connector"] = class_config(
            LossConnector, key_mapping=key_mapping
        )
        return term

    # The 3D box loss term carries no explicit "weight" entry.
    loss_terms = [
        _make_term(
            class_config(
                GroundingDINO3DLoss,
                box_coder=box_coder,
                loss_center_weight=params.loss_center_weight,
                loss_depth_weight=params.loss_depth_weight,
                loss_dim_weight=params.loss_dim_weight,
                loss_rot_weight=params.loss_rot_weight,
            ),
            CONN_GDINO3D_LOSS,
        )
    ]

    # Optional auxiliary depth supervision, appended after the box loss.
    if aux_depth_loss:
        loss_terms.append(
            _make_term(
                class_config(SILogLoss),
                CONN_DEPTH_LOSS,
                weight=params.si_log_weight,
            )
        )

    return class_config(LossModule, losses=loss_terms)
|