fix enforced bf16 data type on SM75 and lower devices
modeling_dots_vision.py (+7 −3)
@@ -489,9 +489,13 @@ class DotsVisionTransformer(PreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb

-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True) -> torch.Tensor:
-        if bf16:
-            hidden_states = hidden_states.bfloat16()
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=None) -> torch.Tensor:
+        # Try to fix the error on SM75 and earlier devices, which do not support BF16
+        # If bf16 is not specified explicitly, infer it from the weight dtype
+        if bf16 is None:
+            bf16 = (self.dtype == torch.bfloat16)
+        # Always align the input explicitly to this module's compute dtype to avoid an input/bias dtype mismatch
+        hidden_states = hidden_states.to(torch.bfloat16 if bf16 else self.dtype)
         hidden_states = self.patch_embed(hidden_states, grid_thw)

         rotary_pos_emb = self.rot_pos_emb(grid_thw)
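For context, a minimal usage sketch (not part of this commit): it assumes the checkpoint is loaded through transformers with trust_remote_code, and the repo id rednote-hilab/dots.ocr is used as a placeholder. On a pre-Ampere GPU such as a Turing (SM75) card, loading the weights in float16 makes self.dtype resolve to torch.float16, so the patched forward() infers bf16=False and no longer forces a bfloat16 cast.

import torch
from transformers import AutoModelForCausalLM

# Ampere (SM80) and newer GPUs support bfloat16; Turing (SM75) and older do not.
# torch.cuda.is_bf16_supported() encapsulates that capability check.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16

# Placeholder repo id / loader; adjust to the actual checkpoint.
# With float16 weights, self.dtype != torch.bfloat16 inside the vision tower,
# so the patched forward() keeps the inputs in float16 instead of casting to bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    "rednote-hilab/dots.ocr",
    torch_dtype=dtype,
    trust_remote_code=True,
)

Callers that previously passed bf16=True explicitly keep their old behavior; only the default changes, from unconditionally casting to bfloat16 to following whatever dtype the module was loaded with.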