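"""Vision tower wrappers for SigLIP-based multimodal models.

Defines a generic `VisionTower`, multi-scale (S2) variants built on
`s2wrapper`, and SigLIP-backed loaders, plus a helper that linearly
interpolates position embeddings so a pretrained backbone can accept a
higher input resolution.
"""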
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.hooks import add_hook_to_module
from einops import rearrange
from s2wrapper import forward as multiscale_forward
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.models.siglip import SiglipVisionModel


class VisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        self.cfg_only = None

    def feature_select(self, image_forward_outs):
        # Take the hidden states of the configured layer (default: second-to-last),
        # then keep patch tokens only ("patch") or all tokens ("cls_patch").
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

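    # Illustrative sizing note (numbers assumed, not from this file): a SigLIP
    # tower pretrained at 384px with patch size 14 has (384 // 14) ** 2 = 729
    # positions; resizing to 448px yields (448 // 14) ** 2 = 1024 positions,
    # each new embedding a linear blend of its two nearest old neighbours.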
    def _maybe_resize_pos_embeds(
        self,
        model: PreTrainedModel,
        image_processor: BaseImageProcessor,
        resolution: int = -1,
        interpolate_mode: str = "linear",
    ):
        # No-op when the requested resolution is unset or already matches.
        if resolution in [model.config.image_size, -1]:
            return
        print(
            f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
        )
        embeddings = model.vision_model.embeddings
        patch_size = embeddings.patch_size
        num_new_tokens = int((resolution // patch_size) ** 2)

        old_embeddings = embeddings.position_embedding
        match interpolate_mode:
            case "linear":
                # 1D linear interpolation over the flattened position table.
                # Under DeepSpeed ZeRO-3 the weight is sharded across ranks and
                # must be gathered before its shape or data can be read.
                if is_deepspeed_zero3_enabled():
                    try:
                        import deepspeed
                    except ImportError:
                        raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
                    with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
                        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                else:
                    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                new_embeddings = nn.Embedding(
                    num_new_tokens,
                    old_embedding_dim,
                    dtype=old_embeddings.weight.dtype,
                    device=old_embeddings.weight.device,
                )
                # Map each new index onto the old [0, old_num_tokens - 1] grid.
                mapped_indices = (
                    torch.arange(num_new_tokens).to(old_embeddings.weight.device)
                    / (num_new_tokens - 1)
                    * (old_num_tokens - 1)
                )
                floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
                ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
                # Blend the two nearest old rows: weight the ceil row by frac and
                # the floor row by (1 - frac), which also handles the degenerate
                # case floor == ceil (frac == 0) correctly.
                frac = (mapped_indices - floor_indices)[:, None]
                if is_deepspeed_zero3_enabled():
                    params = [old_embeddings.weight, new_embeddings.weight]
                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                        interpolated_embeds = (
                            frac * old_embeddings.weight.data[ceil_indices, :]
                            + (1 - frac) * old_embeddings.weight.data[floor_indices, :]
                        )
                else:
                    interpolated_embeds = (
                        frac * old_embeddings.weight.data[ceil_indices, :]
                        + (1 - frac) * old_embeddings.weight.data[floor_indices, :]
                    )
                new_embeddings.weight.data = interpolated_embeds
            case _:
                raise NotImplementedError

        if hasattr(old_embeddings, "_hf_hook"):
            # Carry any accelerate offloading hook over to the new module.
            hook = old_embeddings._hf_hook
            add_hook_to_module(new_embeddings, hook)
        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)

        # Keep the model config and the image processor in sync with the new
        # resolution.
        model.config.image_size = resolution
        if hasattr(image_processor, "crop_size"):
            # CLIP-style processors expose `crop_size`.
            image_processor.crop_size = resolution
        else:
            # SigLIP-style processors expose `size` instead.
            assert hasattr(image_processor, "size")
            image_processor.size = {"height": resolution, "width": resolution}

        # Swap in the resized table and refresh the cached metadata.
        embeddings.position_embedding = new_embeddings
        embeddings.image_size = resolution
        embeddings.num_patches = embeddings.num_positions = num_new_tokens
        embeddings.position_ids = (
            torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
        )

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


class VisionTowerS2(VisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        # `s2_scales` is a comma-separated string of resolutions, sorted ascending.
        self.scales = list(map(int, args.s2_scales.split(",")))
        self.scales.sort()
        self.max_split_size = args.s2_max_split_size
        self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)

    def forward_feature(self, images):
        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = multiscale_forward(
                    self.forward_feature,
                    image.unsqueeze(0),
                    img_sizes=self.scales,
                    max_split_size=self.max_split_size,
                    resize_output_to_idx=self.resize_output_to_scale_idx,
                )
                image_features.append(image_feature)
        else:
            image_features = multiscale_forward(
                self.forward_feature,
                images,
                img_sizes=self.scales,
                max_split_size=self.max_split_size,
                resize_output_to_idx=self.resize_output_to_scale_idx,
            )

        return image_features

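    # `multiscale_forward` (from s2wrapper, "Scaling on Scales") resizes each
    # input to every scale in `self.scales`, splits large scales into tiles of
    # at most `max_split_size`, runs `forward_feature` on each tile, merges the
    # tiles back, and concatenates the per-scale features along the channel
    # axis, with the output spatial size taken from `resize_output_to_idx`.
    # (High-level summary of s2wrapper's behaviour; exact merge details are an
    # assumption.) Hence the widened hidden size below.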
    @property
    def hidden_size(self):
        return self.config.hidden_size * len(self.scales)


class VisionTowerDynamicS2(VisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        self.scales = list(map(int, args.s2_scales.split(",")))
        self.scales.sort()
        self.max_split_size = args.s2_max_split_size
        self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)

    def forward_feature(self, images):
        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def forward(self, images):
        # Dynamic S2 expects a pre-tiled batch: list inputs are unsupported, and
        # the multi-scale splitting/merging is assumed to happen in the caller.
        assert type(images) is not list
        image_features = self.forward_feature(images)

        return image_features

    @property
    def hidden_size(self):
        return self.config.hidden_size * len(self.scales)


class SiglipVisionTower(VisionTower):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        # `config` doubles as the `args` namespace that VisionTower.__init__
        # reads its `mm_vision_*` options from.
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            # `model_dtype` is stored as a string such as "torch.float16";
            # eval() turns it back into the torch dtype object.
            torch_dtype=eval(config.model_dtype),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        self.is_loaded = True


class SiglipVisionTowerS2(VisionTowerS2):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            torch_dtype=eval(config.model_dtype),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)

        # Preprocess at the largest scale; smaller scales are obtained by
        # downsampling inside `multiscale_forward`.
        self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
        self.is_loaded = True


class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        # `model_dtype` may arrive as a string (e.g. "torch.bfloat16") or as an
        # actual torch.dtype, depending on how the config was built.
        if type(config.model_dtype) == str:
            model_dtype = eval(config.model_dtype)
        else:
            model_dtype = config.model_dtype

        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation="flash_attention_2",
            torch_dtype=model_dtype,
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)

        # Preprocess at the smallest scale; larger effective resolutions come
        # from the dynamic tiling performed upstream.
        self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
        self.is_loaded = True
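

if __name__ == "__main__":
    # Minimal smoke-test sketch, illustrative only. The checkpoint name, the
    # config fields below, and the test image URL are assumptions, not part of
    # this module's contract.
    import requests
    from PIL import Image
    from transformers import PretrainedConfig

    cfg = PretrainedConfig()
    cfg._attn_implementation = "sdpa"  # assumed attention backend
    cfg.model_dtype = "torch.float32"  # stored as a string; eval()'d in __init__
    cfg.mm_vision_select_feature = "cls_patch"  # SigLIP has no CLS token to drop

    tower = SiglipVisionTower("google/siglip-so400m-patch14-384", cfg)
    image = Image.open(
        requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
    )
    pixels = tower.image_processor(images=image, return_tensors="pt")["pixel_values"]
    features = tower(pixels)
    print(features.shape)  # expected: torch.Size([1, 729, 1152]) for so400m-patch14-384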