Spaces: Runtime error
Update llava/model/builder.py
Browse files
- llava/model/builder.py (+16 -14)
llava/model/builder.py
CHANGED
|
@@ -139,6 +139,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
| 139 |
if 'llava' in model_name.lower():
|
| 140 |
mm_use_x_start_end = getattr(model.config, "mm_use_x_start_end", False)
|
| 141 |
mm_use_x_patch_token = getattr(model.config, "mm_use_x_patch_token", True)
|
|
|
|
| 142 |
X = model.config.X
|
| 143 |
if mm_use_x_patch_token:
|
| 144 |
for x in X:
|
|
@@ -146,23 +147,24 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
| 146 |
if mm_use_x_start_end:
|
| 147 |
for x in X:
|
| 148 |
tokenizer.add_tokens([DEFAULT_X_START_TOKEN[x.upper()], DEFAULT_X_END_TOKEN[x.upper()]], special_tokens=True)
|
|
|
|
| 149 |
model.resize_token_embeddings(len(tokenizer))
|
| 150 |
print(X)
|
| 151 |
-
if 'Image' in X:
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
|
| 159 |
-
if 'Video' in X:
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
|
| 167 |
if hasattr(model.config, "max_sequence_length"):
|
| 168 |
context_len = model.config.max_sequence_length
|
|
|
|
| 139 |
if 'llava' in model_name.lower():
|
| 140 |
mm_use_x_start_end = getattr(model.config, "mm_use_x_start_end", False)
|
| 141 |
mm_use_x_patch_token = getattr(model.config, "mm_use_x_patch_token", True)
|
| 142 |
+
'''
|
| 143 |
X = model.config.X
|
| 144 |
if mm_use_x_patch_token:
|
| 145 |
for x in X:
|
|
|
|
| 147 |
if mm_use_x_start_end:
|
| 148 |
for x in X:
|
| 149 |
tokenizer.add_tokens([DEFAULT_X_START_TOKEN[x.upper()], DEFAULT_X_END_TOKEN[x.upper()]], special_tokens=True)
|
| 150 |
+
'''
|
| 151 |
model.resize_token_embeddings(len(tokenizer))
|
| 152 |
print(X)
|
| 153 |
+
#if 'Image' in X:
|
| 154 |
+
image_tower = model.get_image_tower()
|
| 155 |
+
if not image_tower.is_loaded:
|
| 156 |
+
image_tower.load_model()
|
| 157 |
+
image_tower.to(device=device, dtype=torch.float16)
|
| 158 |
+
image_processor = image_tower.image_processor
|
| 159 |
+
processor['image'] = image_processor
|
| 160 |
|
| 161 |
+
#if 'Video' in X:
|
| 162 |
+
video_tower = model.get_video_tower()
|
| 163 |
+
if not video_tower.is_loaded:
|
| 164 |
+
video_tower.load_model()
|
| 165 |
+
video_tower.to(device=device, dtype=torch.float16)
|
| 166 |
+
video_processor = video_tower.video_processor
|
| 167 |
+
processor['video'] = video_processor
|
| 168 |
|
| 169 |
if hasattr(model.config, "max_sequence_length"):
|
| 170 |
context_len = model.config.max_sequence_length
|