Commit 367cc7f
Parent(s): 305557f

Change the call interface and adjust the program execution logic

mlcd_seg.py (+4, -85)
@@ -38,6 +38,7 @@ from PIL import Image
 from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.generation.utils import GenerateOutput
+from transformers.utils import cached_file
 from safetensors.torch import load_file as safetensors_load
 from .vision_tower import build_vision_tower
 from .vision_resampler import build_vision_resampler
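For context: `transformers.utils.cached_file` resolves a filename inside a Hub repo (or a local directory) to a concrete local path, downloading and caching it on first use. A minimal sketch of the newly imported helper's behavior, using the same repo id that appears as the fallback below:

    from transformers.utils import cached_file

    # Resolve a file inside a Hub repo to a concrete local path;
    # the first call downloads it, later calls hit the local cache.
    index_path = cached_file("DeepGlint-AI/MLCD-Seg", "model.safetensors.index.json")
    print(index_path)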
@@ -141,10 +142,8 @@ class MLCDSegMetaModel:
 
     def dispatch_weight(self, config):
         safetensors_set = set()
-
-        index_file =
-        if not index_file.exists():
-            os.getenv("")
+        repo = getattr(config, "_name_or_path", "DeepGlint-AI/MLCD-Seg")
+        index_file = cached_file(repo, "model.safetensors.index.json")
         with open(index_file, "r") as safetensors_index:
             safetensors_map = json.loads(safetensors_index.read())
         for key, value in safetensors_map["weight_map"].items():
@@ -156,7 +155,7 @@ class MLCDSegMetaModel:
         projector_weight = {}
         text2sam_projection_weight = {}
         for safetensors_file in safetensors_set:
-            temp_load = safetensors_load(
+            temp_load = safetensors_load(cached_file(repo, safetensors_file))
             for key, value in temp_load.items():
                 if key.startswith("model.sam."):
                     sam_weight[key.replace("model.sam.", "")] = value
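Taken together, the two hunks change `dispatch_weight` so that the shard index and every shard it references are resolved through the Hub cache rather than a hard-coded local path. A self-contained sketch of that flow, assuming the standard sharded-safetensors layout; `collect_shards` is a hypothetical stand-in for the method's loading loop, not part of the actual file:

    import json

    from safetensors.torch import load_file as safetensors_load
    from transformers.utils import cached_file

    def collect_shards(repo: str) -> dict:
        """Hypothetical helper mirroring dispatch_weight's new loading flow."""
        # The index file maps every tensor name to the shard that stores it.
        index_path = cached_file(repo, "model.safetensors.index.json")
        with open(index_path, "r") as f:
            weight_map = json.load(f)["weight_map"]
        state_dict = {}
        for shard in sorted(set(weight_map.values())):
            # Each shard is likewise resolved to a cached local path first.
            state_dict.update(safetensors_load(cached_file(repo, shard)))
        return state_dict

    if __name__ == "__main__":
        weights = collect_shards("DeepGlint-AI/MLCD-Seg")
        # dispatch_weight then routes tensors by key prefix, e.g. the SAM branch:
        sam_weight = {k[len("model.sam."):]: v
                      for k, v in weights.items() if k.startswith("model.sam.")}

The `getattr(config, "_name_or_path", ...)` fallback works because `PretrainedConfig.from_pretrained` records the repo id or local path it was loaded from in `config._name_or_path`, so a model loaded from a local clone resolves its shards from the same place.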
@@ -174,86 +173,6 @@ class MLCDSegMetaModel:
         vision_tower = vision_tower[0]
         return vision_tower
 
-    # def initialize_vision_modules(self, model_args, fsdp=None):
-    #     vision_tower = model_args.vision_tower
-    #     mm_vision_select_layer = model_args.mm_vision_select_layer
-    #     mm_vision_select_feature = model_args.mm_vision_select_feature
-    #     pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
-    #     mm_patch_merge_type = model_args.mm_patch_merge_type
-
-    #     self.config.mm_vision_tower = vision_tower
-    #     self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
-
-    #     if self.get_vision_tower() is None:
-    #         vision_tower = build_vision_tower(model_args)
-    #         vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
-    #         for k, v in vision_resampler.config.items():
-    #             setattr(self.config, k, v)
-
-    #         if fsdp is not None and len(fsdp) > 0:
-    #             self.vision_tower = [vision_tower]
-    #             self.vision_resampler = [vision_resampler]
-    #         else:
-    #             self.vision_tower = vision_tower
-    #             self.vision_resampler = vision_resampler
-    #     else:
-    #         if fsdp is not None and len(fsdp) > 0:
-    #             vision_resampler = self.vision_resampler[0]
-    #             vision_tower = self.vision_tower[0]
-    #         else:
-    #             vision_resampler = self.vision_resampler
-    #             vision_tower = self.vision_tower
-    #         vision_tower.load_model()
-
-    #         # In case it is frozen by LoRA
-    #         for p in self.vision_resampler.parameters():
-    #             p.requires_grad = True
-
-    #     self.config.use_mm_proj = True
-    #     self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
-    #     self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
-    #     self.config.mm_vision_select_layer = mm_vision_select_layer
-    #     self.config.mm_vision_select_feature = mm_vision_select_feature
-    #     self.config.mm_patch_merge_type = mm_patch_merge_type
-
-    #     for key in vars(model_args):
-    #         if key.startswith('sam_'):
-    #             setattr(self.config, key, getattr(model_args, key))
-
-    #     if not hasattr(self.config, 'add_faster_video'):
-    #         if model_args.add_faster_video:
-    #             embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
-    #             self.faster_token = nn.Parameter(
-    #                 torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
-    #             )
-
-    #     if getattr(self, "mm_projector", None) is None:
-    #         self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
-
-    #         if "unpad" in mm_patch_merge_type:
-    #             embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
-    #             self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
-
-    #         if getattr(self.config, 'sam_path', None) is not None:
-    #             self.sam = build_sam_vit_h(self.config.sam_path)
-    #             self.text2sam_projection = text2sam_projection_layer(self.config)
-    #     else:
-    #         if getattr(self.config, 'sam_path', None) is not None and self.config.sam_path != "":
-    #             self.sam = build_sam_vit_h(self.config.sam_path)
-    #             self.text2sam_projection = text2sam_projection_layer(self.config)
-    #         # In case it is frozen by LoRA
-    #         for p in self.mm_projector.parameters():
-    #             p.requires_grad = True
-
-    #     if pretrain_mm_mlp_adapter is not None:
-    #         mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
-
-    #         def get_w(weights, keyword):
-    #             return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
-
-    #         incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
-    #         incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
-
 
 def unpad_image(tensor, original_size):
     """