Spaces:
Running
on
Zero
Running
on
Zero
Update videomind/model/model.py
Browse files- videomind/model/model.py +5 -1
videomind/model/model.py
CHANGED
|
@@ -18,6 +18,10 @@ from .generator import PointGenerator
|
|
| 18 |
from .loss import BundleLoss
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
class AgentQwen2VLConfig(Qwen2VLConfig):
|
| 22 |
model_type = 'agent_qwen2_vl'
|
| 23 |
|
|
@@ -52,7 +56,7 @@ class AgentQwen2VLModel(Qwen2VLModel):
|
|
| 52 |
|
| 53 |
def __init__(self, config):
|
| 54 |
super().__init__(config)
|
| 55 |
-
self.norm.register_forward_pre_hook(
|
| 56 |
|
| 57 |
def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
|
| 58 |
# ensure gradient tracking (in case that embed_tokens has been frozen)
|
|
|
|
| 18 |
from .loss import BundleLoss
|
| 19 |
|
| 20 |
|
| 21 |
+
def cache_state_hook(module, args):
|
| 22 |
+
module.state = args[0]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
class AgentQwen2VLConfig(Qwen2VLConfig):
|
| 26 |
model_type = 'agent_qwen2_vl'
|
| 27 |
|
|
|
|
| 56 |
|
| 57 |
def __init__(self, config):
|
| 58 |
super().__init__(config)
|
| 59 |
+
self.norm.register_forward_pre_hook(cache_state_hook)
|
| 60 |
|
| 61 |
def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
|
| 62 |
# ensure gradient tracking (in case that embed_tokens has been frozen)
|