# omnivinci / modeling_vila.py
# Committed by: leoye
# Commit: fd01e7c ("Initial commit")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import logging
import numpy as np
import os
import os.path
import os.path as osp
import shutil
import warnings
from abc import ABC
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import chain
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from einops import rearrange
from PIL import Image
from transformers import (
AutoConfig,
AutoModel,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
LogitsProcessor,
PretrainedConfig,
PreTrainedModel,
Qwen2Config,
Qwen2ForCausalLM,
Qwen2PreTrainedModel,
TextIteratorStreamer,
WhisperFeatureExtractor,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import ContextManagers, no_init_weights
from .auto_processor import VILAProcessor
from .base_projector import MultimodalProjector, MultimodalProjectorConfig
from .sound_base_projector import SoundMultimodalProjector, SoundMultimodalProjectorConfig
from .speech_base_projector import SpeechMultimodalProjector, SpeechMultimodalProjectorConfig
from .builder import build_llm_and_tokenizer
from .configuration_vila import VILAConfig
from .constants import *
from .conversation import SeparatorStyle, default_conversation
from .distributed import all_gather as vila_all_gather
from .media import extract_media
from .media_encoder import BasicImageEncoder, BasicVideoEncoder, TSPVideoEncoder, BasicSoundEncoder, CacheFeatures
from .mm_utils import process_image, process_images
from .model_utils_packing import set_seqlens_in_batch
from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
from .tokenizer_utils import tokenize_conversation
from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template
from .constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS_VILA, NUM_EXTRA_TOKENS_XVILA
from .qwen_audio_encoder import Qwen2AudioTower
import whisper
from .audio_encoder import AudioTower
def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Create the vision multimodal projector.

    Args:
        model_type_or_path: ``None`` disables the projector. When
            ``config.resume_path`` is truthy this must be an existing path holding a
            saved projector; otherwise it is handed to ``MultimodalProjectorConfig``
            to describe a freshly built projector.
        config: Top-level model config; ``resume_path`` selects the loading mode and
            the config is forwarded to the projector.

    Returns:
        The projector module, or ``None`` when ``model_type_or_path`` is ``None``.
    """
    if model_type_or_path is None:
        return None
    if not config.resume_path:
        # Fresh projector built from its config description.
        return MultimodalProjector(MultimodalProjectorConfig(model_type_or_path), config)
    # Resuming: the argument must point at saved projector weights on disk.
    assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
    return MultimodalProjector.from_pretrained(model_type_or_path, config)
def build_speech_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Create the speech multimodal projector in ``config.model_dtype``.

    Args:
        model_type_or_path: ``None`` disables the projector. When
            ``config.resume_path`` is truthy this must be an existing path holding a
            saved projector; otherwise it is handed to
            ``SpeechMultimodalProjectorConfig`` to describe a freshly built projector.
        config: Top-level model config. ``model_dtype`` may be a string expression
            such as ``"torch.bfloat16"`` or an already-resolved dtype object.

    Returns:
        The projector module, or ``None`` when ``model_type_or_path`` is ``None``.
    """
    if model_type_or_path is None:
        return None
    # ``from_pretrained`` copies the caller's ``torch_dtype`` (string or dtype object)
    # straight into ``config.model_dtype``. Only strings need evaluating; ``eval`` on a
    # dtype object raises TypeError. This mirrors ``build_sound_mm_projector``.
    model_dtype = config.model_dtype
    if isinstance(model_dtype, str):
        # NOTE(security): evaluates a config value; only load trusted configs.
        model_dtype = eval(model_dtype)
    if config.resume_path:
        assert os.path.exists(model_type_or_path), f"Resume speech mm projector path {model_type_or_path} does not exist!"
        return SpeechMultimodalProjector.from_pretrained(model_type_or_path, config, torch_dtype=model_dtype)
    speech_mm_projector_cfg = SpeechMultimodalProjectorConfig(model_type_or_path)
    return SpeechMultimodalProjector(speech_mm_projector_cfg, config).to(model_dtype)
def build_sound_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Create the sound multimodal projector in ``config.model_dtype``.

    Args:
        model_type_or_path: ``None`` disables the projector. When
            ``config.resume_path`` is truthy this must be an existing path holding a
            saved projector; otherwise it is handed to
            ``SoundMultimodalProjectorConfig`` to describe a freshly built projector.
        config: Top-level model config. ``model_dtype`` is either a string
            expression (evaluated) or an already-resolved dtype (used as is).

    Returns:
        The projector module, or ``None`` when ``model_type_or_path`` is ``None``.
    """
    if model_type_or_path is None:
        return None
    raw_dtype = config.model_dtype
    # NOTE(security): ``eval`` executes the config value; only load trusted configs.
    model_dtype = eval(raw_dtype) if type(raw_dtype) == str else raw_dtype
    if not config.resume_path:
        fresh_cfg = SoundMultimodalProjectorConfig(model_type_or_path)
        return SoundMultimodalProjector(fresh_cfg, config).to(model_dtype)
    assert os.path.exists(model_type_or_path), f"Resume sound mm projector path {model_type_or_path} does not exist!"
    return SoundMultimodalProjector.from_pretrained(model_type_or_path, config, torch_dtype=model_dtype)
def check_dot_in_model_path(model_path: str):
    """Report whether ``model_path`` contains a ``.``, which may affect model loading.

    For an existing directory the absolute path is inspected, so dots in any
    parent directory count. Otherwise the string is inspected exactly as given.

    Returns:
        ``True`` if a dot is present, else ``False``.
    """
    inspected = osp.abspath(model_path) if osp.isdir(model_path) else model_path
    return "." in inspected
def get_vila_version(model_path: str) -> str:
    """Infer the VILA family tag from a model path.

    The lower-cased path is searched for each known tag in a fixed priority
    order, and the first tag that occurs as a substring wins.

    Returns:
        One of ``"vila1.5"``, ``"vila-u"``, ``"longvila"``, ``"nvila"``,
        ``"vila-m3"``, or ``None`` when no tag matches.
    """
    lowered = model_path.lower()
    known_tags = ("vila1.5", "vila-u", "longvila", "nvila", "vila-m3")
    return next((tag for tag in known_tags if tag in lowered), None)
def generate_jinja_template(conv_mode: str) -> str:
    """Return a Jinja chat template for a legacy conversation mode.

    Supported modes are ``"vicuna_v1"``, ``"llama_3"`` and ``"hermes_2"``. The
    vicuna_v1 and llama_3 templates end with an assistant prompt when the last
    message comes from the user; the hermes_2 template does not.

    Raises:
        NotImplementedError: For any other ``conv_mode``.
    """
    vicuna_v1 = """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
{% set roles = ["user", "assistant"] %}
{% set sep = " " %}
{{ system_prompt }}
{% for message in messages %}
{% if message['role'] == roles[0] %}
{{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
{% else %}
{{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}
{% if messages[-1]['role'] == 'user' %}
{{ "ASSISTANT:" }}
{% endif %}
"""
    # The doubled backslashes make the template itself contain literal "\n" escapes.
    llama_3 = """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
{% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
{% set sep = "<|eot_id|>" %}
{{ system_prompt }}
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ roles[0] }}{{ message['content'] }}{{ sep }}
{% else %}
{{ roles[1] }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}
{% if messages[-1]['role'] == 'user' %}
{{ roles[1] }}
{% endif %}
"""
    # Single backslashes: these "\n" become real newlines in the template text.
    hermes_2 = """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
{% set sep = "<|im_end|>" %}
{{ system_prompt }}{{ sep }}
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ roles[0] }}{{ message['content'] }}{{ sep }}
{% else %}
{{ roles[1] }}{{ message['content'] }}{{ sep }}
{% endif %}
{% endfor %}"""

    for mode, template in (("vicuna_v1", vicuna_v1), ("llama_3", llama_3), ("hermes_2", hermes_2)):
        if conv_mode == mode:
            return template
    raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Create the SigLIP vision tower and record its feature width on ``config``.

    When resuming (``config.resume_path`` truthy) from a non-RADIO checkpoint, the
    tower type is read from the saved config's first architecture name. Otherwise
    the given name/path itself is matched. Only SigLIP towers are supported; the
    S2 / dynamic-S2 variants are chosen by ``config.s2`` / ``config.dynamic_s2``.

    Side effect:
        Sets ``config.mm_hidden_size``.

    Returns:
        The vision tower, or ``None`` when ``model_name_or_path`` is ``None``.

    Raises:
        NotImplementedError: If the resolved name does not contain ``"siglip"``.
    """
    if model_name_or_path is None:
        return None

    resolved_name = model_name_or_path
    if config.resume_path and "radio" not in model_name_or_path:
        assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
        saved_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        resolved_name = saved_cfg.architectures[0].lower()

    use_s2 = getattr(config, "s2", False)
    use_dynamic_s2 = getattr(config, "dynamic_s2", False)
    if "siglip" not in resolved_name:
        raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")

    if use_dynamic_s2:
        tower_cls = SiglipVisionTowerDynamicS2
    elif use_s2:
        tower_cls = SiglipVisionTowerS2
    else:
        tower_cls = SiglipVisionTower
    vision_tower = tower_cls(model_name_or_path, config)

    # S2 towers expose their (multi-scale) width directly; the plain tower via its config.
    config.mm_hidden_size = (
        vision_tower.hidden_size if (use_s2 or use_dynamic_s2) else vision_tower.config.hidden_size
    )
    return vision_tower
def build_audio_tower(model_name_or_path: str, config: PretrainedConfig, encoder_type: str) -> PreTrainedModel:
    """Create an audio tower for the sound or speech branch.

    Only the ``Qwen2AudioTower`` encoder is built. Its output width (1280) is
    recorded as ``config.sound_hidden_size`` or ``config.speech_hidden_size``,
    depending on ``encoder_type``.

    Returns:
        The audio tower, or ``None`` when ``model_name_or_path`` is ``None``.

    Raises:
        NotImplementedError: If ``encoder_type`` is neither ``"sound"`` nor
            ``"speech"`` and the leading assert has been stripped (``python -O``).
    """
    assert encoder_type in ["sound", "speech"]
    if model_name_or_path is None:
        return None

    audio_tower = Qwen2AudioTower(model_name_or_path, config)
    feature_dim = 1280
    if encoder_type == "speech":
        config.speech_hidden_size = feature_dim
    elif encoder_type == "sound":
        config.sound_hidden_size = feature_dim
    else:
        raise NotImplementedError(f"Not implemented for this encoder: {model_name_or_path}")
    return audio_tower
class VILAPretrainedModel(PreTrainedModel):
    """Base VILA model wiring an LLM to vision and optional speech/sound encoders.

    ``__init__`` builds each sub-module from the sub-configs returned by
    ``get_model_config``. ``save_pretrained`` writes each sub-module to its own
    sub-directory of the output directory, next to the top-level config.
    """

    config_class = VILAConfig
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]

    def __init__(self, config: VILAConfig, *args, **kwargs):
        """Build every sub-module described by ``config``.

        Args:
            config: The top-level VILA configuration.
            **kwargs: ``device_map`` (default ``"auto"``): for ``"auto"``/``"cuda"``
                the non-LLM modules are moved to CUDA, and the value is forwarded to
                the LLM builder. ``llm_only_need_embed``: when truthy, the LLM is
                deleted after construction (``self.llm_model_embed_tokens`` keeps the
                embedding module).

        Raises:
            ValueError: If ``get_model_config`` does not return seven sub-configs.
        """
        super().__init__(config)
        self.config = config
        cfgs = get_model_config(config)
        if len(cfgs) != 7:
            raise ValueError(
                "`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config."
            )
        (
            llm_cfg,
            vision_tower_cfg,
            speech_tower_cfg,
            sound_tower_cfg,
            mm_projector_cfg,
            speech_mm_projector_cfg,
            sound_mm_projector_cfg,
        ) = cfgs

        # loading on auto by default
        device_map = kwargs.get("device_map", "auto")
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        # Speech/sound branches are optional: their attributes are only created when configured.
        if speech_tower_cfg:
            self.speech_tower = build_audio_tower(speech_tower_cfg, config, encoder_type="speech")
            self.speech_mm_projector = build_speech_mm_projector(speech_mm_projector_cfg, config)
        if sound_tower_cfg:
            self.sound_tower = build_audio_tower(sound_tower_cfg, config, encoder_type="sound")
            self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)
        if device_map in ["auto", "cuda"]:
            self.mm_projector = self.mm_projector.cuda()
            self.vision_tower = self.vision_tower.cuda()
            # Absent optional components become explicit None attributes. A component that
            # exists but is None (e.g. a projector with no config) is left as None rather
            # than crashing on ``None.cuda()``.
            for name in ("speech_tower", "sound_tower", "speech_mm_projector", "sound_mm_projector"):
                module = getattr(self, name, None)
                setattr(self, name, module.cuda() if module is not None else None)

        # set device_map auto can automatically shard llm to different devices
        self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
        self.llm_model_embed_tokens = self.llm.model.embed_tokens
        self.tokenizer.padding_side = "left"
        self.vocab_size = len(self.tokenizer)
        self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer))

        self.encoders = {}
        for name in ["image", "video", "speech", "sound"]:
            encoder_config = getattr(self.config, f"{name}_encoder")
            if isinstance(encoder_config, str):
                encoder_config = json.loads(encoder_config)
            else:
                # Work on a copy. The entries are mutated below (``trope_dim`` may be added,
                # ``_target_`` is removed), and mutating ``self.config`` in place would leak
                # into the config written by ``save_pretrained``. A reload would then fail on
                # the missing ``_target_``.
                encoder_config = copy.copy(encoder_config)
            if encoder_config.get("embed_time", False) == "True":
                if "trope_dim" not in encoder_config and encoder_config.get("time_embed_type", "") in ["pixel", "lang"]:
                    encoder_config["trope_dim"] = self.config.hidden_size // 2
                    print(f"Warning: trope_dim not found in config, defaulting to hidden_size // 2: {encoder_config['trope_dim']}")
            # ``_target_`` is dropped before the remaining keys are passed as encoder kwargs.
            encoder_config.pop("_target_", None)
            if name == "video":
                self.encoders[name] = TSPVideoEncoder(parent=self, **encoder_config)
            elif name == "image":
                self.encoders[name] = BasicImageEncoder(self)
            else:
                self.encoders[name] = BasicSoundEncoder(parent=self, **encoder_config)

        self.post_config()
        self.is_loaded = True

        self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
        if self.llm_only_need_embed:
            print("We only need the embed_tokens in llm.")
            del self.llm
            self.llm = None
            torch.cuda.empty_cache()

        # Optional components may be missing entirely, so read them with getattr.
        assert (
            self.llm is not None
            or self.vision_tower is not None
            or getattr(self, "speech_tower", None) is not None
            or getattr(self, "sound_tower", None) is not None
            or self.mm_projector is not None
            or getattr(self, "speech_mm_projector", None) is not None
            or getattr(self, "sound_mm_projector", None) is not None
        ), "At least one of the components must be instantiated."

    @classmethod
    def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
        """Mirror every entry of ``model_path`` into ``output_dir``.

        Existing destination entries are removed first, whether they are files,
        directories or symlinks (including dangling ones). With ``copy=True``
        entries are copied (directories recursively); otherwise symlinks to the
        source entries are created.
        """
        os.makedirs(output_dir, exist_ok=True)
        for item in os.listdir(model_path):
            src_path = os.path.join(model_path, item)
            dst_path = os.path.join(output_dir, item)
            # lexists (unlike exists) is also true for a dangling symlink, which would
            # otherwise make os.symlink / shutil.copytree below fail with FileExistsError.
            if os.path.lexists(dst_path):
                if os.path.islink(dst_path):
                    os.unlink(dst_path)
                elif os.path.isdir(dst_path):
                    shutil.rmtree(dst_path)
                else:
                    os.remove(dst_path)
            if copy:
                if os.path.isdir(src_path):
                    shutil.copytree(src_path, dst_path)
                else:
                    shutil.copy2(src_path, dst_path)
                print(f"Copied {src_path} to {dst_path}")
            else:
                os.symlink(src_path, dst_path)
                print(f"Created symlink from {src_path} to {dst_path}")

    @classmethod
    def copy_remote_py_files(cls, output_dir, copy=True):
        """Copy (or symlink) this package's ``.py``/``.jinja`` files into ``output_dir``.

        If an ``INSTRUCTIONS.md`` sits next to this file, its text is prepended to
        ``output_dir/README.md``. With ``copy=False`` source files are symlinked
        instead, which eases development.
        """
        current_folder = os.path.dirname(os.path.abspath(__file__))
        for file_name in os.listdir(current_folder):
            if file_name == "INSTRUCTIONS.md":
                src_fname = os.path.join(current_folder, file_name)
                dst_fname = os.path.join(output_dir, "README.md")
                old_readme = ""
                if os.path.exists(dst_fname):
                    # Read with a context manager so the handle is closed before rewriting.
                    with open(dst_fname) as existing:
                        old_readme = existing.read()
                with open(src_fname) as src, open(dst_fname, "w") as dst:
                    dst.write(src.read())
                    dst.write(old_readme)
                print("[HF] README", src_fname, "to", dst_fname)
            if file_name.endswith((".py", ".jinja")):
                full_file_name = os.path.join(current_folder, file_name)
                if not os.path.isfile(full_file_name):
                    continue
                if copy:
                    shutil.copy(full_file_name, output_dir)
                    print("[HF] copying", full_file_name, "to", output_dir)
                else:
                    link_path = os.path.join(output_dir, file_name)
                    # lexists also catches a dangling symlink left by a previous run.
                    if os.path.lexists(link_path):
                        os.remove(link_path)
                    os.symlink(full_file_name, link_path)
                    print("[HF] linking", full_file_name, "to", output_dir)

    @staticmethod
    def _component_state_dict(state_dict, component, strip_prefix):
        """Return the entries of ``state_dict`` owned by the sub-module ``component``.

        A key is owned when ``component`` is one of its dot-separated path segments.
        Plain substring tests are not enough: ``"mm_projector" in key`` also matches
        ``speech_mm_projector.*`` / ``sound_mm_projector.*`` keys, whose stripped
        names can collide with and replace the vision projector's own entries.
        ``"llm" in key`` also matches the ``llm_model_embed_tokens`` alias registered
        in ``__init__``. Returned keys are the text after the last ``strip_prefix``.
        """
        return OrderedDict(
            (key.split(strip_prefix)[-1], value)
            for key, value in state_dict.items()
            if component in key.split(".")
        )

    def save_pretrained(self, output_dir, state_dict=None, **kwargs):
        """Save the tokenizer, every existing sub-module and the top-level config.

        Each sub-module goes to ``output_dir/<name>`` and its config is recorded in
        ``self.config.<name>_cfg``. Finally the package's remote-code files are copied.

        Args:
            output_dir: Destination directory.
            state_dict: Full model state dict to split per component; defaults to
                ``self.state_dict()`` (presumably a gathered DeepSpeed state dict
                when supplied by a trainer).
            **kwargs: Accepted for API compatibility; unused.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm() is not None:
            print(f"saving llm to {osp.join(output_dir, 'llm')}")
            self.llm.config._name_or_path = osp.join(output_dir, "llm")
            llm_state_dict = self._component_state_dict(state_dict, "llm", "llm.")
            self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
            self.config.llm_cfg = self.llm.config

        if self.get_vision_tower() is not None:
            print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
            self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
            vision_tower_state_dict = self._component_state_dict(
                state_dict, "vision_tower", "vision_tower.vision_tower."
            )
            self.vision_tower.vision_tower.save_pretrained(
                os.path.join(output_dir, "vision_tower"),
                state_dict=vision_tower_state_dict,
            )
            self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
            self.config.vision_tower_cfg = self.vision_tower.config
            if hasattr(self.config.vision_tower_cfg, "auto_map"):
                if "radio" not in self.get_vision_tower().__class__.__name__.lower():
                    delattr(self.config.vision_tower_cfg, "auto_map")

        if self.get_speech_tower() is not None:
            print(f"saving speech_tower to {osp.join(output_dir, 'speech_tower')}")
            self.speech_tower.config._name_or_path = osp.join(output_dir, "speech_tower").replace(
                "tmp-checkpoint", "checkpoint"
            )
            speech_tower_state_dict = self._component_state_dict(
                state_dict, "speech_tower", "speech_tower.audio_tower."
            )
            self.speech_tower.audio_tower.save_pretrained(
                os.path.join(output_dir, "speech_tower"),
                state_dict=speech_tower_state_dict,
            )
            self.config.speech_tower_cfg = self.speech_tower.config

        if self.get_sound_tower() is not None:
            print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}")
            self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower").replace(
                "tmp-checkpoint", "checkpoint"
            )
            sound_tower_state_dict = self._component_state_dict(
                state_dict, "sound_tower", "sound_tower.audio_tower."
            )
            self.sound_tower.audio_tower.save_pretrained(
                os.path.join(output_dir, "sound_tower"),
                state_dict=sound_tower_state_dict,
            )
            self.config.sound_tower_cfg = self.sound_tower.config

        if self.get_mm_projector() is not None:
            print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
            self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
            mm_projector_state_dict = self._component_state_dict(state_dict, "mm_projector", "mm_projector.")
            self.mm_projector.save_pretrained(
                os.path.join(output_dir, "mm_projector"),
                state_dict=mm_projector_state_dict,
            )
            self.config.mm_projector_cfg = self.mm_projector.config

        if self.get_speech_mm_projector() is not None:
            print(f"saving speech_mm_projector to {osp.join(output_dir, 'speech_mm_projector')}")
            self.speech_mm_projector.config._name_or_path = osp.join(output_dir, "speech_mm_projector").replace(
                "tmp-checkpoint", "checkpoint"
            )
            speech_mm_projector_state_dict = self._component_state_dict(
                state_dict, "speech_mm_projector", "speech_mm_projector."
            )
            self.speech_mm_projector.save_pretrained(
                os.path.join(output_dir, "speech_mm_projector"),
                state_dict=speech_mm_projector_state_dict,
            )
            self.config.speech_mm_projector_cfg = self.speech_mm_projector.config

        if self.get_sound_mm_projector() is not None:
            print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}")
            self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector").replace(
                "tmp-checkpoint", "checkpoint"
            )
            sound_mm_projector_state_dict = self._component_state_dict(
                state_dict, "sound_mm_projector", "sound_mm_projector."
            )
            self.sound_mm_projector.save_pretrained(
                os.path.join(output_dir, "sound_mm_projector"),
                state_dict=sound_mm_projector_state_dict,
            )
            self.config.sound_mm_projector_cfg = self.sound_mm_projector.config

        # update and save top-level config
        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

        # copy .py and README for next loading
        self.copy_remote_py_files(output_dir)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[str] = None,
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ):
        """Load a VILA model by re-reading its config and instantiating it via ``_from_config``.

        The config is always re-read from ``pretrained_model_name_or_path``. The
        ``config`` argument and the other Hugging Face loading options are accepted
        for signature compatibility but are not used here.

        If ``torch_dtype`` is given, it is stored as ``config.torch_dtype`` and
        ``config.model_dtype`` in the form it was passed. A string value such as
        ``"torch.bfloat16"`` is evaluated into a dtype before being forwarded.
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        torch_dtype = kwargs.get("torch_dtype", None)
        if torch_dtype is not None:
            config.torch_dtype = torch_dtype
            config.model_dtype = torch_dtype
            if isinstance(torch_dtype, str):
                # NOTE(security): evaluates a caller-supplied string; pass trusted values only.
                kwargs["torch_dtype"] = eval(torch_dtype)
        return cls._from_config(config, **kwargs)

    def init_llm(self, llm_config, config, *args, **kwargs):
        """Build the language model and tokenizer and set token bookkeeping attributes.

        Returns:
            The ``(llm, tokenizer)`` pair, also stored on ``self``.
        """
        self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
        # NOTE(review): the first two entries are token ids, but ``tokenize`` returns a
        # token *string*; confirm callers of pad_token_list expect this mixed tuple.
        self.pad_token_list = (
            self.tokenizer.pad_token_id,
            self.tokenizer.eos_token_id,
            self.tokenizer.tokenize("<|endoftext|>")[0],  # for Qwen
        )
        self.vocab_size = len(self.tokenizer)
        self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer))
        # XGrammar tokenizer and grammar compiler: lazily initialized only when JSON
        # output is requested during inference.
        self.grammar_compiler = None
        return self.llm, self.tokenizer

    def post_config(self):
        """Match the training flag to the LLM and back-fill missing sub-configs.

        For every sub-module that exists and is not None, its ``config`` is copied
        into ``self.config.<name>_cfg`` unless that entry is already set.
        """
        self.training = self.llm.training
        if self.training:
            self.train()
        else:
            self.eval()
        # Optional components may be absent or explicitly None (see __init__'s CUDA path),
        # so test the value itself rather than hasattr.
        for cfg_name, module_name in (
            ("llm_cfg", "llm"),
            ("vision_tower_cfg", "vision_tower"),
            ("mm_projector_cfg", "mm_projector"),
            ("speech_tower_cfg", "speech_tower"),
            ("sound_tower_cfg", "sound_tower"),
            ("speech_mm_projector_cfg", "speech_mm_projector"),
            ("sound_mm_projector_cfg", "sound_mm_projector"),
        ):
            module = getattr(self, module_name, None)
            if getattr(self.config, cfg_name, None) is None and module is not None:
                setattr(self.config, cfg_name, module.config)

    def _unwrap_component(self, name):
        """Return attribute ``name``, its first element if stored in a list, or None if absent."""
        component = getattr(self, name, None)
        if type(component) is list:
            component = component[0]
        return component

    def get_llm(self):
        return self._unwrap_component("llm")

    def get_lm_head(self):
        lm_head = getattr(self.get_llm(), "lm_head", None)
        return lm_head

    def get_vision_tower(self):
        return self._unwrap_component("vision_tower")

    def get_speech_tower(self):
        return self._unwrap_component("speech_tower")

    def get_sound_tower(self):
        return self._unwrap_component("sound_tower")

    def get_mm_projector(self):
        return self._unwrap_component("mm_projector")

    def get_sound_mm_projector(self):
        return self._unwrap_component("sound_mm_projector")

    def get_speech_mm_projector(self):
        return self._unwrap_component("speech_mm_projector")

    def freezed_module_patch(self):
        """
        Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
        """
        if not self.training:
            return
        # The LLM is not switched to eval mode here, even when it is not being tuned.
        frozen_candidates = (
            (self.get_vision_tower(), "tune_vision_tower"),
            (self.get_speech_tower(), "tune_speech_tower"),
            (self.get_sound_tower(), "tune_sound_tower"),
            (self.get_mm_projector(), "tune_mm_projector"),
            (self.get_speech_mm_projector(), "tune_speech_mm_projector"),
            (self.get_sound_mm_projector(), "tune_sound_mm_projector"),
        )
        for module, tune_flag in frozen_candidates:
            if module is not None and not getattr(self.config, tune_flag, False):
                module.eval()
class VILAForCausalLM(VILAPretrainedModel):
def __init__(self, config: VILAConfig, *args, **kwargs):
    """Construct the causal-LM VILA model.

    All construction (sub-module building, device placement, encoders) is
    delegated unchanged to ``VILAPretrainedModel.__init__``.
    """
    super().__init__(config, *args, **kwargs)
def merge_features_for_dynamic_s2(self, image_features, block_sizes):
    """Regroup dynamic-S2 tile features per image and fuse all scales channel-wise.

    Tiles are consumed from ``image_features`` in order. For each entry of
    ``block_sizes``:

    * ``None``: one tile is reshaped from tokens to a square map and repeated
      once per scale along channels.
    * ``(rows, cols)``: every scale but the last contributes a square
      ``(scale // scales[0])``-way tiling, and the last scale uses
      ``rows x cols`` tiles. Each scale's merged map is area-resized to the
      spatial size of scale ``resize_output_to_scale_idx`` and concatenated
      along channels.

    Returns:
        ``(features_per_image, new_block_sizes)``. Each feature tensor is
        ``1 x C' x H x W``. Each block size is the tiling the fused map corresponds
        to: the input tiling when the last scale is the resize target, otherwise a
        square tiling of the target scale.
    """
    tower = self.get_vision_tower()
    scales = tower.scales
    out_scale_idx = tower.resize_output_to_scale_idx
    num_scales = len(scales)

    per_image = []
    tilings = []
    cursor = 0  # index of the next unconsumed tile in image_features

    for tiling in block_sizes:
        if tiling is None:
            single = image_features[cursor : cursor + 1]
            single = rearrange(single, "1 (h w) c -> 1 c h w", h=int(single.shape[1] ** 0.5))
            per_image.append(single.repeat(1, num_scales, 1, 1))
            tilings.append((1, 1))
            cursor += 1
            continue

        scale_maps = []
        for scale in scales[:-1]:
            splits = scale // scales[0]
            n_tiles = splits**2
            scale_maps.append(
                self.merge_chessboard(
                    image_features[cursor : cursor + n_tiles],
                    num_split_h=splits,
                    num_split_w=splits,
                )
            )  # 1 * C * H * W
            cursor += n_tiles

        rows, cols = tiling[0], tiling[1]
        n_last = rows * cols
        scale_maps.append(
            self.merge_chessboard(
                image_features[cursor : cursor + n_last],
                num_split_h=rows,
                num_split_w=cols,
            )
        )  # 1 * C * H * W
        cursor += n_last

        # Resize every scale to the target scale's spatial size (area interpolation is
        # done in float32, then cast back) and stack the scales along channels.
        target_hw = scale_maps[out_scale_idx].shape[-2:]
        resized = []
        for fmap in scale_maps:
            resized.append(F.interpolate(fmap.to(torch.float32), size=target_hw, mode="area").to(fmap.dtype))
        per_image.append(torch.cat(resized, dim=1))

        if out_scale_idx in (num_scales - 1, -1):
            tilings.append(tiling)
        else:
            ratio = scales[out_scale_idx] // scales[0]
            tilings.append((ratio, ratio))

    assert cursor == len(image_features)
    return per_image, tilings
@staticmethod
def split_chessboard(x, num_split_h, num_split_w):
    """Cut a ``b x c x H x W`` tensor into a grid of tiles stacked along the batch dim.

    The grid has ``num_split_h`` rows and ``num_split_w`` columns. Tiles are
    emitted in row-major order, each tile contributing a contiguous group of
    ``b`` batch entries, so the result is
    ``(b * num_split_h * num_split_w) x c x (H / num_split_h) x (W / num_split_w)``.
    H and W must be divisible by the split counts (asserted).
    """
    _, _, height, width = x.shape
    assert height % num_split_h == 0 and width % num_split_w == 0
    tile_h = height // num_split_h
    tile_w = width // num_split_w
    tiles = []
    for row in range(num_split_h):
        top = row * tile_h
        for col in range(num_split_w):
            left = col * tile_w
            tiles.append(x[:, :, top : top + tile_h, left : left + tile_w])
    return torch.cat(tiles, dim=0)
@staticmethod
def merge_chessboard(x, num_split_h, num_split_w):
    """Inverse of ``split_chessboard``: stitch batch-stacked tiles back into whole maps.

    ``x`` is either ``B x C x h x w`` or ``B x N x C``. The 3-D form is first
    reshaped to square maps with side ``int(sqrt(N))``. ``B`` must be divisible
    by ``num_split_h * num_split_w`` (asserted). Tiles are read in row-major
    order, each owning a contiguous group of ``B / (num_split_h * num_split_w)``
    batch entries. The result is
    ``(B / (num_split_h * num_split_w)) x C x (h * num_split_h) x (w * num_split_w)``.
    """
    total = x.shape[0]
    if x.dim() == 3:
        side = int(x.shape[1] ** 0.5)
        x = rearrange(x, "b (h w) c -> b c h w", h=side, w=side)
    num_tiles = num_split_h * num_split_w
    assert total % num_tiles == 0
    per_tile = total // num_tiles

    stitched_rows = []
    for row in range(num_split_h):
        row_tiles = []
        for col in range(num_split_w):
            start = (row * num_split_w + col) * per_tile
            row_tiles.append(x[start : start + per_tile])
        stitched_rows.append(torch.cat(row_tiles, dim=-1))
    return torch.cat(stitched_rows, dim=-2)
def encode_video(self, inp, block_sizes: Optional[Optional[Tuple[int, ...]]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
    """Encode a batch of videos into projected visual features.

    Each entry of ``inp`` is either a frame tensor (frames along dim 0), which is
    run through the vision tower, or a ``CacheFeatures`` object, whose
    precomputed ``value['features']`` are used as is. When cached entries exist,
    fresh and cached features are re-interleaved in input order and concatenated
    along dim 0 before the multimodal projector runs.

    On the ``config.dynamic_s2`` path, ``block_sizes`` must be ``None``
    (``ValueError`` otherwise). The result is a stacked tensor when all
    per-entry token counts agree, else a list. ``mm_info`` and ``num_frames``
    are not used here.
    """
    batch_size = len(inp)
    original_block_sizes = block_sizes

    cached_items = []
    cached_positions = []
    for position, item in enumerate(inp):
        if type(item) == CacheFeatures:
            cached_items.append(item)
            cached_positions.append(position)
    fresh_videos = [item for item in inp if type(item) != CacheFeatures]
    frames_per_video = [video.shape[0] for video in fresh_videos]
    frames = torch.cat(fresh_videos, dim=0) if fresh_videos else []

    if block_sizes is None:
        block_sizes = [None] * len(frames)

    def _interleave_cached(features):
        # No cached entries: the vision-tower output is already complete.
        if not cached_items:
            return features
        if len(features) > 0:
            features = torch.split(features, frames_per_video)
        merged = []
        next_cached = 0
        next_fresh = 0
        for position in range(batch_size):
            if position in cached_positions:
                merged.append(cached_items[next_cached].value['features'].to(self.device, self.dtype))
                next_cached += 1
            else:
                merged.append(features[next_fresh])
                next_fresh += 1
        assert len(merged) == batch_size
        return torch.cat(merged, dim=0)

    if not getattr(self.config, "dynamic_s2", False):
        features = self.get_vision_tower()(frames) if len(frames) > 0 else []
        features = _interleave_cached(features)
        return self.get_mm_projector()(features)

    if len(frames) > 0:
        features = self.get_vision_tower()(frames)
        features, tiled_sizes = self.merge_features_for_dynamic_s2(features, block_sizes)
        features = [self.split_chessboard(fmap, size[0], size[1]) for fmap, size in zip(features, tiled_sizes)]
        features = torch.cat([rearrange(fmap, "b c h w -> b (h w) c") for fmap in features], dim=0)  # B * N * C
    else:
        features = []
    features = _interleave_cached(features)

    if original_block_sizes is not None:
        raise ValueError(f"inp_block_sizes is not None: {original_block_sizes}")
    tiled_sizes = [(1, 1)] * len(features)

    features = features.to(self.device, self.dtype)
    features = self.get_mm_projector()(features)
    chunks = list(features.split([size[0] * size[1] for size in tiled_sizes], dim=0))
    maps = [self.merge_chessboard(chunk, size[0], size[1]) for chunk, size in zip(chunks, tiled_sizes)]
    flattened = [rearrange(fmap, "1 c h w -> (h w) c") for fmap in maps]  # list of N * C tensors
    if all(feat.shape[0] == flattened[0].shape[0] for feat in flattened):
        return torch.stack(flattened, dim=0)
    return flattened
def encode_images(self, images, block_sizes: Optional[Optional[Tuple[int, ...]]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
    """Encode a batch of images into projected visual features.

    Without ``config.dynamic_s2`` the images pass through the vision tower and
    then the multimodal projector. With dynamic S2, the tile features are
    regrouped per image (``merge_features_for_dynamic_s2``), re-tiled,
    projected, stitched back per image and flattened to ``tokens x channels``.
    The per-image results are stacked into one tensor only when every image
    has the same token count; otherwise the list is returned.

    ``mm_info`` and ``num_frames`` are accepted for interface compatibility and
    are not used here.
    """
    if block_sizes is None:
        block_sizes = [None] * len(images)

    if not getattr(self.config, "dynamic_s2", False):
        tower_out = self.get_vision_tower()(images)
        return self.get_mm_projector()(tower_out)

    tower_out = self.get_vision_tower()(images)
    per_image, tilings = self.merge_features_for_dynamic_s2(tower_out, block_sizes)
    tiles = [self.split_chessboard(fmap, t[0], t[1]) for fmap, t in zip(per_image, tilings)]  # B * C * H * W each
    tokens = torch.cat([rearrange(tile, "b c h w -> b (h w) c") for tile in tiles], dim=0)  # B * N * C
    projected = self.get_mm_projector()(tokens)
    chunks = projected.split([t[0] * t[1] for t in tilings], dim=0)
    stitched = [self.merge_chessboard(chunk, t[0], t[1]) for chunk, t in zip(chunks, tilings)]  # 1 * C * H * W each
    flattened = [rearrange(fmap, "1 c h w -> (h w) c") for fmap in stitched]  # N * C each
    if all(feat.shape[0] == flattened[0].shape[0] for feat in flattened):
        return torch.stack(flattened, dim=0)
    return flattened
def encode_sound(self, sounds, mm_info: Optional[dict] = None):
audio_features, audio_output_lengths = self.get_sound_tower()(sounds)
use_fea_downsample = False
if getattr(self.config, "sound_mm_projector", "") != "":
if "mlp_downsample" in getattr(self.config, "sound_mm_projector", ""):
use_fea_downsample = True
else:
sound_mm_projector_cfg = getattr(self.config, "sound_mm_projector_cfg", None)
if sound_mm_projector_cfg is not None:
if type(sound_mm_projector_cfg) == dict:
if "mlp_downsample" in sound_mm_projector_cfg["sound_mm_projector_type"]:
use_fea_downsample = True
elif type(sound_mm_projector_cfg) == SoundMultimodalProjectorConfig:
if "mlp_downsample" in sound_mm_projector_cfg.sound_mm_projector_type:
use_fea_downsample = True
if not use_fea_downsample:
audio_features = self.get_sound_mm_projector()(audio_features)
if audio_output_lengths is not None:
# split the batch
new_audio_features = []
start = 0
for length in audio_output_lengths:
new_audio_features.append(audio_features[start : start + length])
start += length
audio_features = new_audio_features
if use_fea_downsample:
audio_features = torch.stack(audio_features, dim=0)
audio_features = self.get_sound_mm_projector()(audio_features)
return audio_features
def train(self, mode: bool = True):
    """Delegate train/eval switching to the parent class and return ``self`` for chaining."""
    parent_train = super().train
    parent_train(mode)
    return self
def _embed(
    self,
    input_ids: torch.Tensor,
    media: Dict[str, List[torch.Tensor]],
    media_config: Dict[str, Dict[str, Any]],
    labels: Optional[torch.Tensor],
    attention_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the LLM input embeddings by splicing media embeddings into the text embeddings.

    Each media placeholder token in ``input_ids`` is replaced by the next embedding of that
    modality, and its label positions are filled with ``IGNORE_INDEX``. When
    ``config.load_audio_in_video`` and ``config.interleaved_vis_aud_in_video`` are set,
    video embeddings that have audio are first rebuilt as interleaved visual/audio segments
    separated by the embedding of ``"\\n"``.

    Args:
        input_ids: token ids, indexed per row as ``input_ids[k][pos]``.
        media: modality name -> list of inputs. Optional ``"video_info"`` / ``"audio_info"``
            entries carry per-sample meta-info and are removed before encoding. ``None``
            means no media.
        media_config: per-modality encoder configuration; ``None`` is treated as empty.
        labels: LM labels; defaults to all ``IGNORE_INDEX``.
        attention_mask: padding mask; defaults to all true. Integer masks are converted
            to bool.

    Returns:
        ``(inputs_embeds, labels, attention_mask)`` as built by ``__batchify_sequence``.

    Raises:
        ValueError: if sound embeddings run out while interleaving video audio, or if any
            media embeddings are left unconsumed.
    """
    # Work on copies: meta-info keys are popped from `media` below, and the encoders may
    # append dummy entries to it.
    media = copy.deepcopy(media)
    media_config = copy.deepcopy(media_config) if media_config is not None else defaultdict(dict)
    labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
    attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
    # An integer (0/1) mask used as an index would *gather* rows 0 and 1 instead of
    # selecting the unpadded positions, so always index with a boolean mask.
    attention_mask = attention_mask.bool()

    # Sequence parallelism is not wired up in this file: the manager is always None, so the
    # branches guarded by it are inactive.
    PROCESS_GROUP_MANAGER = None
    if PROCESS_GROUP_MANAGER is not None:
        for name in media:
            self.encoders[name].end_tokens = None

    # Extract text and media embeddings
    text_embeds = self.llm_model_embed_tokens(input_ids)

    mm_info = {}
    video_info = None
    audio_info = None
    if media is not None:
        # Pull the meta-info out of `media` so it is not encoded as a modality.
        if "video_info" in media:
            video_info = media.pop("video_info")
            mm_info["video_info"] = video_info
        if "audio_info" in media:
            audio_info = media.pop("audio_info")
            mm_info["audio_info"] = audio_info
        media_embeds = self.__embed_media_tokens(media, media_config, mm_info)
    else:
        # no media was provided, so there is nothing to encode
        media_embeds = {}

    if PROCESS_GROUP_MANAGER is not None:
        media_embeds_video = []
        for i, images in enumerate(media_embeds["video"]):
            num_video_frame = media["video"][i].shape[0]
            media_embeds_video += torch.unbind(images.reshape(num_video_frame, -1, images.shape[-1]))
        media_embeds["video"] = deque(media_embeds_video)

    # This is a workaround to make sure the dummy embeddings are consumed
    # (they contribute zero to the values but keep the encoders in the graph).
    while media_embeds.get("dummy"):
        dummy_embed = media_embeds["dummy"].popleft()
        text_embeds += torch.sum(dummy_embed) * 0

    # Based on segment_aud_indices_list and segment_vis_indices_list, get interleaved vis-aud embeddings for video
    video_sound_embeds_idx = 0
    sep_embed = self.encoders["video"].embed_tokens("\n")
    text_embeds = text_embeds.to(self.dtype)
    sep_embed = sep_embed.to(text_embeds.dtype)
    if video_info is not None and self.config.load_audio_in_video and self.config.interleaved_vis_aud_in_video:
        assert self.encoders["video"].end_tokens is None, "end_tokens must be None for interleaved vis-aud in video"
        new_video_embeds = deque()
        video_embeds_idx = 0
        for k in range(len(video_info)):
            if video_info[k] is None:
                continue
            for i in range(len(video_info[k])):
                has_audio = video_info[k][i]["has_audio"]
                if not has_audio:
                    new_video_embeds.append(media_embeds["video"][video_embeds_idx])
                    video_embeds_idx += 1
                    continue

                # Check bounds for sound embeddings
                if video_sound_embeds_idx >= len(media_embeds["sound"]):
                    raise ValueError(f"Sound embeddings index {video_sound_embeds_idx} out of bounds for video_info[{k}][{i}]")

                vis_feats = media_embeds["video"][video_embeds_idx]
                aud_feats = media_embeds["sound"][video_sound_embeds_idx]
                segment_aud_indices_list = video_info[k][i]["segment_aud_indices_list"]
                segment_vis_indices_list = video_info[k][i]["segment_vis_indices_list"]
                vis_fea_len_per_frame = vis_feats.shape[0] / video_info[k][i]["expected_frame_count"]
                aud_fea_len_per_stft_frame = aud_feats.shape[0] / audio_info[k][i]["new_audio_n_stft_frames"]
                # Last video of the last sample in the batch (`k` indexes samples and `i` indexes
                # videos within a sample; the previous check had them swapped).
                # NOTE(review): only this video gets its trailing features snapped below.
                is_last_video = k == len(video_info) - 1 and i == len(video_info[k]) - 1

                vis_end = 0
                aud_end = 0
                _new_video_embed = []
                for j in range(len(segment_vis_indices_list)):
                    _vis_aud_fea = []
                    if len(segment_vis_indices_list[j]) > 0:
                        _new_frames = [int(np.ceil((_frame + 1) * vis_fea_len_per_frame)) for _frame in segment_vis_indices_list[j]]
                        _vis_fea_end = _new_frames[-1]
                        # Ensure we don't exceed the available features
                        _vis_fea_end = min(_vis_fea_end, vis_feats.shape[0])
                        if j == len(segment_vis_indices_list) - 1 and is_last_video and _vis_fea_end != vis_feats.shape[0]:
                            print(f"Warning: The number of last interleaved video features does not match the video feature length. Expected: {media_embeds['video'][video_embeds_idx].shape[0]}, Got: {_vis_fea_end}")
                            _vis_fea_end = vis_feats.shape[0]
                        _vis_fea = vis_feats[vis_end:_vis_fea_end]
                        vis_end = _vis_fea_end
                        _vis_aud_fea.append(_vis_fea)
                        _vis_aud_fea.append(sep_embed)
                    if len(segment_aud_indices_list[j]) > 0:
                        _new_audio_indices = [int(np.ceil(_fea * aud_fea_len_per_stft_frame)) for _fea in segment_aud_indices_list[j]]
                        _aud_fea_end = _new_audio_indices[-1]
                        # Ensure we don't exceed the available features
                        _aud_fea_end = min(_aud_fea_end, aud_feats.shape[0])
                        _aud_fea = aud_feats[aud_end:_aud_fea_end]
                        _vis_aud_fea.append(_aud_fea)
                        aud_end = _aud_fea_end
                        _vis_aud_fea.append(sep_embed)
                    _new_video_embed.append(torch.cat(_vis_aud_fea, dim=0))
                video_sound_embeds_idx += 1
                new_video_embeds.append(torch.cat(_new_video_embed, dim=0))
                video_embeds_idx += 1

        assert len(new_video_embeds) == len(media_embeds["video"]), "The number of new video embeddings does not match the number of original video embeddings."
        media_embeds["video"] = new_video_embeds

    # Remove padding. `input_ids` must be filtered too, so that position `pos` refers to the
    # same token in `input_ids`, `text_embeds` and `labels` (this matters for left padding).
    batch_size = labels.shape[0]
    text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
    labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
    input_ids = [input_ids[k][attention_mask[k]] for k in range(batch_size)]

    # Build inverse mapping from token ID to media name
    media_tokens = {token_id: name for name, token_id in self.tokenizer.media_token_ids.items()}
    sound_token_id = self.tokenizer.media_token_ids.get("sound")

    # Fuse text and media embeddings
    inputs_m, labels_m = [], []
    sound_embeds_idx = 0
    for k in range(batch_size):
        inputs_mk, labels_mk = [], []
        pos = 0
        while pos < len(labels[k]):
            token_id = input_ids[k][pos].item()
            if token_id in media_tokens:
                name = media_tokens[token_id] if PROCESS_GROUP_MANAGER is None else "video"
                if token_id == sound_token_id:
                    if self.config.interleaved_vis_aud_in_video:
                        if sound_embeds_idx < video_sound_embeds_idx:
                            # This sound embedding was already interleaved into a video above,
                            # so it is discarded and its placeholder token dropped.
                            # (Assumes video-audio sound tokens precede other sound tokens — confirm.)
                            media_embeds[name].popleft()
                            sound_embeds_idx += 1
                            pos += 1
                            continue
                    sound_embeds_idx += 1
                end = pos + 1
                chunk = media_embeds[name].popleft()
                chunk_labels = torch.full([chunk.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
            else:
                # Consume the whole run of text tokens up to the next media token.
                end = pos
                while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
                    end += 1
                chunk = text_embeds[k][pos:end]
                chunk_labels = labels[k][pos:end]
            inputs_mk.append(chunk)
            labels_mk.append(chunk_labels)
            pos = end
        inputs_m.append(torch.cat(inputs_mk, dim=0))
        labels_m.append(torch.cat(labels_mk, dim=0))
    inputs, labels = inputs_m, labels_m
    # Zero-valued use of sep_embed, presumably to keep it in the autograd graph even when no
    # interleaving happened.
    inputs[0] += sep_embed.mean() * 0  # dummy embedding

    # Check if all media embeddings are consumed
    for name in media_embeds:
        if media_embeds[name]:
            raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")

    # Truncate sequences to `model_max_length` as media embeddings are inserted
    inputs, labels = self.__truncate_sequence(inputs, labels)

    # Pad sequences to the longest one in the batch
    return self.__batchify_sequence(inputs, labels)
def __embed_media_tokens(
    self,
    media: Dict[str, List[torch.Tensor]],
    media_config: Dict[str, Dict[str, Any]],
    mm_info,
) -> Dict[str, List[torch.Tensor]]:
    """Run each modality's encoder on its media and collect the per-item embeddings.

    In training, meta-information is all-gathered across ranks so that every rank calls the
    encoders even when it holds no media of a type. Audio batches are padded with dummy
    copies to the largest per-rank batch size. Outputs produced for dummy inputs are
    stored under ``"dummy"`` so the caller can consume them.

    Args:
        media: modality name -> list of inputs (meta-info keys already removed).
        media_config: per-modality encoder configuration.
        mm_info: meta-info dict; ``mm_info["audio_info"]`` holds one list (or ``None``) per
            batch sample. It is mutated in place with the new audio lengths and dummy markers.

    Returns:
        ``defaultdict(deque)`` mapping modality name (and ``"dummy"``) to embeddings.
    """
    embeds = defaultdict(deque)

    if self.config.unified_audio_encoder:
        assert len(media["speech"]) == 0

    for name in media:
        _encoder = self.encoders[name]
        if name in ["speech", "sound"] and self.config.unified_audio_encoder:
            _encoder = self.encoders["sound"]

        if self.training:
            if name in ["speech", "sound"]:
                # Gather metainfo of media objects from all ranks
                info = []
                if type(media.get(name, {})) is dict:
                    # NOTE(review): iterating a dict yields its keys, so `_dict.items()` looks
                    # like it would fail here; confirm whether dict-typed media is ever used.
                    for _dict in media.get(name, {}):
                        info.append({k: {"shape": v.shape, "dtype": v.dtype} for k, v in _dict.items()})
                elif type(media.get(name, {})) is list:
                    info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
                else:
                    raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}")

                infos_list = vila_all_gather(info)
                infos = list(chain(*infos_list))

                # The entire batch does not contain any media objects of this type.
                if not infos:
                    continue

                # For audio encoding every rank must use the same batch size, so pad this
                # rank's batch with dummy copies up to the largest per-rank batch size.
                max_batch_size = max(len(_info) for _info in infos_list)
                missing_batch_size = max_batch_size - len(info)

                _media = media.get(name, [])
                # Flatten the per-rank lists so `_medias` holds individual audio samples; both the
                # dummy template and the max-length computation below need samples, not rank lists.
                _medias = list(chain(*vila_all_gather(_media)))
                if missing_batch_size > 0:
                    for _ in range(missing_batch_size):
                        # use one of the media tensors to create a dummy tensor
                        if type(media.get(name, {})) is dict:
                            _dummy = {k: v.clone().to(device=self.device) for k, v in _medias[0].items()}
                        elif type(media.get(name, {})) is list:
                            if type(_medias[0]) is torch.Tensor:
                                _dummy = _medias[0].clone().to(device=self.device)
                            elif type(_medias[0]) is np.ndarray:
                                _dummy = _medias[0].copy()
                            else:
                                raise ValueError(f"Unsupported media type: {type(_medias[0])}")
                        else:
                            raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}")
                        _media.append(_dummy)
                        mm_info["audio_info"].append(["dummy"])

                # Align every audio sample in the batch to one length: the longest sample rounded
                # up to a multiple of 30 s, capped at `audio_chunk_length` seconds.
                cur_batch_max_audio_samples = max(len(_audio) for _audio in _medias)
                cur_batch_max_audio_samples = int(np.ceil(cur_batch_max_audio_samples / (self.config.audio_sampling_rate * 30)) * (self.config.audio_sampling_rate * 30))
                cur_batch_max_audio_samples = min(cur_batch_max_audio_samples, self.config.audio_chunk_length * self.config.audio_sampling_rate)
                cur_batch_max_audio_duration = cur_batch_max_audio_samples // self.config.audio_sampling_rate

                _media = self._audio_to_whisper_features(
                    _media, mm_info["audio_info"], cur_batch_max_audio_samples, cur_batch_max_audio_duration
                )
                _fea = _encoder(_media, media_config[name], mm_info)

                # Outputs past len(info) come from the dummy padding; they are consumed later.
                _dummy_fea = _fea[len(info) :]
                embeds["dummy"].extend(_dummy_fea)
                _real_fea = _fea[: len(info)]
                if len(info) > 0:
                    embeds[name] = deque(_real_fea)
            else:
                # Gather metainfo of media objects from all ranks (flattened across ranks).
                info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
                infos = list(chain(*vila_all_gather(info)))

                # The entire batch does not contain any media objects of this type.
                if not infos:
                    continue

                # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
                if media.get(name) is None or len(media[name]) == 0:
                    dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
                    embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
                    continue
                embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
        else:
            if name == "sound" and len(media[name]) > 0:
                # One chunk length per audio item; `audio_info` is a list per batch sample.
                all_audio_chunk_lengths = [
                    item["new_audio_chunk_length"]
                    for sample_audio_info in mm_info["audio_info"]
                    if sample_audio_info is not None
                    for item in sample_audio_info
                ]
                cur_batch_max_audio_duration = max(all_audio_chunk_lengths)
                cur_batch_max_audio_samples = cur_batch_max_audio_duration * self.config.audio_sampling_rate
                assert len(all_audio_chunk_lengths) == len(media[name]), "The number of audio chunk lengths does not match the number of audio samples."
                media[name] = self._audio_to_whisper_features(
                    media.get(name, []), mm_info["audio_info"], cur_batch_max_audio_samples, cur_batch_max_audio_duration
                )
            if len(media[name]) > 0:
                embeds[name] = deque(_encoder(media[name], media_config[name], mm_info))
    return embeds


def _audio_to_whisper_features(self, audios, audio_info_list, max_audio_samples, max_audio_duration):
    """Pad/trim raw audio to ``max_audio_samples`` and run the Whisper feature extractor on it.

    Args:
        audios: flat list of audio samples, in the same order as the entries of
            ``audio_info_list``.
        audio_info_list: one list (or ``None``) per batch sample. Each non-``"dummy"`` entry is
            updated in place with the new chunk length, sample count, end time and number of
            feature frames.
        max_audio_samples: target number of samples per audio.
        max_audio_duration: chunk length (seconds) passed to the feature extractor.

    Returns:
        One extracted feature batch per audio sample, moved to the sample's device/dtype
        (or the model's, for non-tensor inputs).
    """
    whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(
        self.config._name_or_path,
        chunk_length=max_audio_duration,
        sampling_rate=self.config.audio_sampling_rate,
        hop_length=self.config.audio_hop_length,
    )
    new_media = []
    aud_idx = 0
    for _audio_info in audio_info_list:
        if _audio_info is None:
            continue
        for _mm_idx in range(len(_audio_info)):
            _audio = audios[aud_idx]
            if type(_audio) is torch.Tensor:
                # Remember where the audio came from so the features go back there.
                device = _audio.device
                dtype = _audio.dtype
                _audio = _audio.cpu().float()
            else:
                device = self.device
                dtype = self.dtype
            _audio = whisper.pad_or_trim(_audio, length=max_audio_samples)
            aud_idx += 1
            stft_features = whisper_feature_extractor(
                _audio,
                sampling_rate=self.config.audio_sampling_rate,
                return_attention_mask=True,
                padding="max_length",
                return_tensors="pt",
            ).to(device, dtype)
            new_media.append(stft_features)
            if _audio_info[_mm_idx] != "dummy":
                _audio_info[_mm_idx]["new_audio_chunk_length"] = max_audio_duration
                _audio_info[_mm_idx]["new_audio_n_samples"] = max_audio_samples
                _audio_info[_mm_idx]["audio_end_sample_sec"] = _audio_info[_mm_idx]["audio_start_sec"] + max_audio_duration
                _audio_info[_mm_idx]["new_audio_n_stft_frames"] = stft_features["input_features"].shape[-1]
    assert aud_idx == len(audios), "The number of audio info does not match the number of audio samples."
    return new_media
def __truncate_sequence(
    self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Clip every sequence to ``tokenizer.model_max_length`` while training.

    Nothing is clipped in eval mode, or when no input exceeds the limit. A warning is
    emitted whenever clipping happens.

    Returns:
        The ``(inputs, labels)`` lists, clipped if needed.
    """
    if not self.training:
        return inputs, labels
    limit = self.tokenizer.model_max_length
    if not any(len(seq) > limit for seq in inputs):
        return inputs, labels
    warnings.warn(f"Truncating sequences to `model_max_length` ({limit}).")
    clipped_inputs = [seq[:limit] for seq in inputs]
    clipped_labels = [lab[:limit] for lab in labels]
    return clipped_inputs, clipped_labels
def __batchify_sequence(
    self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pad variable-length embedding/label sequences into one batch.

    Padding goes on the side given by ``tokenizer.padding_side`` (``"right"``, otherwise
    left). Padded embedding rows are zeros, padded labels are ``IGNORE_INDEX``, and padded
    positions are false in the returned boolean attention mask.

    Args:
        inputs: per-sample embeddings of shape ``(seq_len_k, hidden)``.
        labels: per-sample label vectors of length ``seq_len_k``.

    Returns:
        ``(inputs, labels, attention_mask)`` with shapes ``(B, L, hidden)``, ``(B, L)``
        and ``(B, L)``, where ``L`` is the longest sequence length.
    """
    batch_size = len(inputs)
    device = inputs[0].device
    hidden_size = inputs[0].shape[1]
    max_length = max(inputs[k].shape[0] for k in range(batch_size))

    attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
    inputs_p, labels_p = [], []
    for k in range(batch_size):
        seq_len = inputs[k].shape[0]
        size_pk = max_length - seq_len
        inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
        labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
        # Move labels to the batch device on both padding sides (previously left only).
        label_k = labels[k].to(device)
        if self.tokenizer.padding_side == "right":
            attention_mask[k, seq_len:] = False
            inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
            labels_pk = torch.cat([label_k, labels_pk], dim=0)
        else:
            # Mask by the pad length: `[:-seq_len]` is `[:0]` (a no-op) when seq_len == 0,
            # which would leave an all-padding row marked as valid.
            attention_mask[k, :size_pk] = False
            inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
            labels_pk = torch.cat([labels_pk, label_k], dim=0)
        inputs_p.append(inputs_pk)
        labels_p.append(labels_pk)

    inputs = torch.stack(inputs_p, dim=0)
    labels = torch.stack(labels_p, dim=0)
    return inputs, labels, attention_mask
def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
    """Pack a padded batch into a single concatenated row for packed training.

    The unpadded tokens of every row (where ``attention_mask[k]`` is true) are concatenated
    into one sequence with batch size 1:
    - position ids restart at 0 for each original row;
    - the first label of each row is set to ``IGNORE_INDEX``;
    - one all-zero dummy token with mask 0 is appended at the end.

    A sequence-parallel re-sharding path is present, but ``PROCESS_GROUP_MANAGER`` is
    hard-coded to ``None`` below, so that branch never runs in this file.

    Args:
        inputs_embeds: padded input embeddings, indexed as ``inputs_embeds[k][attention_mask[k]]``
            (presumably ``(batch, seq, hidden)`` — TODO confirm).
        attention_mask: padding mask, used as a per-row boolean index.
        position_ids: only read by the (inactive) sequence-parallel path; the packing path
            builds fresh position ids.
        labels: padded labels, masked the same way as the embeddings.

    Returns:
        ``(inputs_embeds, attention_mask, position_ids, labels)`` for the packed row. When the
        model has a ``pad_to_multiple_of`` attribute, the row is right-padded to a multiple of it.
    """
    # Handle sequence parallelism
    PROCESS_GROUP_MANAGER = None

    # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
    if PROCESS_GROUP_MANAGER is not None:
        sp_degree = PROCESS_GROUP_MANAGER.sp_degree
        sp_rank = PROCESS_GROUP_MANAGER.sp_rank
        sp_group = PROCESS_GROUP_MANAGER.sp_pg
        ring_degree = PROCESS_GROUP_MANAGER.ring_degree
        ring_rank = PROCESS_GROUP_MANAGER.ring_rank
        ring_type = PROCESS_GROUP_MANAGER.ring_type
        ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
        ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank

        # Collect every rank's local (shard) sequence length.
        bs, shard_seqlen = position_ids.shape
        sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
        dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
        sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)

        # Offset of this rank's shard within the full sequence.
        if sp_rank == 0:
            original_start_id = 0
        else:
            original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
        original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()

        # Gather attention_mask, position_ids, labels and input_embeds
        # Each rank writes its shard into a zero tensor at its offset; the all_reduce then
        # combines the shards into the full-length embeddings on every rank.
        all_inputs_embeds = torch.zeros(
            bs,
            torch.sum(sp_seq_len_cat),
            inputs_embeds.shape[-1],
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        ).contiguous()
        all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
        dist.barrier(group=sp_group)
        dist.all_reduce(all_inputs_embeds, group=sp_group)
        dist.barrier(group=sp_group)

        attention_mask_list = [
            torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
            for i in range(sp_degree)
        ]
        position_ids_list = [
            torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
            for i in range(sp_degree)
        ]
        labels_list = [
            torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
        ]

        dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
        dist.all_gather(position_ids_list, position_ids, group=sp_group)
        dist.all_gather(labels_list, labels, group=sp_group)

        # Number of real (unmasked) tokens per (sample, shard).
        effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
        effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
        effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)

        # Rebuild each sample's full sequence by concatenating the first `eff_len` entries of
        # every shard (i.e. dropping each shard's trailing padding).
        global_attention_mask_list = []
        global_position_ids_list = []
        global_labels_list = []
        global_inputs_embeds_list = []
        for i in range(bs):
            global_attention_mask_batch_list = []
            global_position_ids_batch_list = []
            global_labels_batch_list = []
            global_inputs_embeds_batch_list = []
            for j in range(sp_degree):
                eff_len = effective_seqlen_batch_list[i][j]
                prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0

                global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
                global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
                global_labels_batch_list.append(labels_list[j][i, :eff_len])
                global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
            global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
            global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
            global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
            global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))

        # Re-pad the rebuilt sequences across the batch.
        global_attention_mask = torch.nn.utils.rnn.pad_sequence(
            global_attention_mask_list, batch_first=True, padding_value=False
        )
        global_position_ids = torch.nn.utils.rnn.pad_sequence(
            global_position_ids_list, batch_first=True, padding_value=-1
        )
        global_labels = torch.nn.utils.rnn.pad_sequence(
            global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
        )
        global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            global_inputs_embeds_list, batch_first=True, padding_value=0
        )

        # Re-shard the inputs
        if ring_degree > 1:
            # Every rank receives total_effective_seqlen / sp_degree tokens per sample.
            total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
            new_seqlen_per_rank = total_effective_seqlen // sp_degree
            assert torch.all(
                total_effective_seqlen % sp_degree == 0
            ), "total_effective_seqlen must be divisible by sp_degree"

            max_new_seqlen = torch.max(new_seqlen_per_rank).item()
            new_attention_mask = torch.zeros(
                (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
            )
            new_position_ids = torch.zeros(
                (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
            )
            new_labels = torch.full(
                (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
            )
            new_inputs_embeds = torch.zeros(
                (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
                dtype=global_inputs_embeds.dtype,
                device=global_inputs_embeds.device,
            )

            if ring_type == "ring_varlen":
                # One contiguous slice per rank.
                for i in range(bs):
                    start_idx = new_seqlen_per_rank[i] * sp_rank
                    end_idx = start_idx + new_seqlen_per_rank[i]
                    new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
                    new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
                    new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
                    new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
                        i, start_idx:end_idx, :
                    ]
            elif ring_type == "zigzag_ring_varlen":
                # The sequence is cut into 2 * sp_degree chunks; each rank takes one "forward"
                # chunk and one mirrored "backward" chunk.
                chunk_size = total_effective_seqlen // (2 * sp_degree)
                for i in range(bs):
                    # Zigzag pattern indices
                    if sp_degree == ring_degree:
                        forward_rank_idx = sp_rank
                        backward_rank_idx = 2 * sp_degree - sp_rank - 1
                    else:
                        ulysses_offset = ulysses_rank * ring_degree * 2
                        forward_rank_idx = ring_rank + ulysses_offset
                        backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset

                    # Calculate start and end indices for the forward and backward zigzag
                    start_idx_fwd = forward_rank_idx * chunk_size[i]
                    end_idx_fwd = start_idx_fwd + chunk_size[i]

                    start_idx_bwd = backward_rank_idx * chunk_size[i]
                    end_idx_bwd = start_idx_bwd + chunk_size[i]

                    # Fill new tensors with zigzag data
                    new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
                    new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
                        i, start_idx_bwd:end_idx_bwd
                    ]

                    new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
                    new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
                        i, start_idx_bwd:end_idx_bwd
                    ]

                    new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
                    new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]

                    new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
                    new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
                        i, start_idx_bwd:end_idx_bwd, :
                    ]
            else:
                raise ValueError(f"Invalid ring_type: {ring_type}")
        else:
            # Evenly slice the padded global sequence; the last rank also takes the remainder.
            global_seq_len = global_attention_mask.shape[-1]
            seq_len_sharded = global_seq_len // sp_degree
            start_idx_reshard = seq_len_sharded * sp_rank
            end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len

            new_attention_mask = torch.narrow(
                global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
            )
            new_position_ids = torch.narrow(
                global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
            )
            new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
            new_inputs_embeds = torch.narrow(
                global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
            )

        return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels

    device = inputs_embeds.device
    batch_size = inputs_embeds.shape[0]
    # Number of real (unmasked) tokens in each row.
    seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]

    # Pack all sequences together
    inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
    attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
    # Position ids restart at 0 for every packed row.
    position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
    labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]

    # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
    inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
    attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
    position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
    # NOTE(review): this dummy label is int32 while `labels` may be int64; the torch.cat below
    # relies on dtype promotion — confirm.
    labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))

    # Mask the first token of each sequence to avoid contamination
    # NOTE(review): `label[0]` raises IndexError if a row has no unmasked tokens.
    for label in labels_p:
        label[0] = IGNORE_INDEX

    # Batch the data
    inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
    attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
    position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
    labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)

    if hasattr(
        self, "pad_to_multiple_of"
    ):  # related to quantization, please refer to ModelArguments for more information.
        assert len(labels_p.shape) == 2
        batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
        hidden_size = inputs_embeds_p.shape[-1]

        if max_length % self.pad_to_multiple_of != 0:
            # Right-pad: embeddings filled with `self.llm.pad_token_id`, labels with IGNORE_INDEX,
            # attention mask with 0 and position ids with -1.
            max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
            difference = max_length - cur_length

            inputs_embeds_p = torch.cat(
                (
                    inputs_embeds_p,
                    torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
                ),
                dim=1,
            )
            labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
            attention_mask_p = torch.cat(
                (
                    attention_mask_p,
                    torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
                ),
                dim=1,
            )
            position_ids_p = torch.cat(
                (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
            )

    return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
def forward(
    self,
    input_ids: torch.LongTensor = None,
    media: Optional[Dict[str, List[torch.Tensor]]] = None,
    images: Optional[torch.FloatTensor] = None,
    media_config: Optional[List] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    packing: bool = True,
    force_packing: bool = False,
    seqlens_in_batch: Optional[torch.LongTensor] = None,
    dpo_forward: bool = False,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the multimodal language model.

    Args:
        input_ids: token ids containing media placeholder tokens.
        media: modality name -> list of media inputs.
        images: deprecated alias for ``media={"image": images}``; cannot be combined with
            ``media``.
        media_config: per-modality encoder configuration (defaults to empty).
        inputs_embeds: precomputed embeddings. When given, media embedding is skipped.
        packing: pack the batch into one row; applied only in training and not for DPO.
        force_packing: always pack, regardless of mode.
        seqlens_in_batch: per-row lengths for packing; defaults to the attention-mask sums.
        dpo_forward: return ``(logits, labels)`` instead of the model output.
        pixel_values: unused here.
        **kwargs: forwarded to ``self.llm``.

    Returns:
        The LLM output. In training, when ``config.time_token_ids`` is non-empty, its loss
        is replaced by a soft cross-entropy. When ``dpo_forward`` is set, returns
        ``(logits, labels)`` instead.

    Raises:
        ValueError: if both ``media`` and ``images`` are given.
    """
    self.freezed_module_patch()

    if images is not None:
        if media is not None:
            raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
        # Report the deprecation through `warnings` (filterable, attributed to the caller)
        # instead of printing to stdout on every call.
        warnings.warn("The 'images' argument is deprecated. Please use 'media' instead.", FutureWarning, stacklevel=2)
        media = {"image": images}

    if media_config is None:
        media_config = defaultdict(dict)

    if inputs_embeds is None:
        inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)

    if force_packing or (packing and self.training and not dpo_forward):
        if seqlens_in_batch is None:
            seqlens_in_batch = torch.sum(attention_mask, dim=1)
        set_seqlens_in_batch(seqlens_in_batch)

        (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
            inputs_embeds, attention_mask, position_ids, labels
        )

    outputs = self.llm(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        labels=labels,
        **kwargs,
    )

    if self.training and getattr(self.config, "time_token_ids", []):
        outputs.loss = soft_cross_entropy(
            outputs.logits,
            labels,
            soft_tokens=self.config.time_token_ids,
            std=self.config.soft_ce_std,
        )

    if dpo_forward:
        return outputs.logits, labels

    return outputs
@torch.inference_mode()
def generate(
    self,
    input_ids: Optional[torch.FloatTensor] = None,
    media: Optional[Dict[str, List[torch.Tensor]]] = None,
    media_config: Dict[str, Dict[str, Any]] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    return_output_ids_only: bool = True,
    **generation_kwargs,
) -> torch.LongTensor:
    """Generate text for a (multi-modal) prompt.

    Media placeholder tokens are replaced by their media embeddings before generation::

        input_tokens: <image> describe the image
        media: [Tensor(1, 3, 384, 384), ]
        ----------->
        input_tokens: 36000 001 002 003 004
        input_emds: <media emd> 001 002 003 004

    Args:
        input_ids: prompt token ids.
        media: modality name -> list of media inputs.
        media_config: per-modality encoder configuration (defaults to empty).
        attention_mask: prompt attention mask.
        return_output_ids_only: if true (the default), return only the generated ids.
        **generation_kwargs: forwarded to ``self.llm.generate``.

    Returns:
        The generated ids. When ``return_output_ids_only`` is false, returns the prompt ids
        (each row repeated ``num_return_sequences`` times) concatenated with the generated
        ids along the last dimension, as community VLMs such as Qwen do.
    """
    if media_config is None:
        media_config = defaultdict(dict)
    inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
    output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)

    if return_output_ids_only:
        return output_ids

    # Each prompt row yields `num_return_sequences` output rows, so repeat the prompt rows to
    # match. An explicit keyword argument takes precedence over `generation_config`, as in
    # HF `generate`; previously a keyword-only value caused a batch-size mismatch in torch.cat.
    num_generations = generation_kwargs.get("num_return_sequences")
    if num_generations is None:
        generation_config = generation_kwargs.get("generation_config", None)
        if generation_config is not None:
            num_generations = generation_config.num_return_sequences
    num_generations = num_generations or 1

    prompt_ids = input_ids
    if num_generations > 1:
        prompt_ids = input_ids.repeat_interleave(num_generations, dim=0)
    return torch.cat([prompt_ids, output_ids], dim=-1)
@torch.inference_mode()
def generate_content(
    self,
    prompt: Union[str, List],
    generation_config: Optional[GenerationConfig] = None,
    response_format=None,
) -> str:
    """Generate a text response for a single human turn that may contain media.

    Media is extracted from the prompt with ``extract_media``, preprocessed per
    modality according to ``self.config``, and passed to ``self.generate`` together
    with the tokenized conversation.

    Args:
        prompt: the human turn, as a string or a list. Presumably this mixes text
            and media objects; TODO confirm against ``extract_media``.
        generation_config: generation settings; defaults to
            ``self.default_generation_config``. This method never mutates it.
        response_format: structured output is not implemented. A warning is
            logged and the argument is ignored.

    Returns:
        The first returned sequence, decoded without special tokens and stripped.

    Raises:
        ValueError: for an unsupported media type, or when generation raises
            ``ValueError`` and sampling was already disabled.
    """
    conversation = [{"from": "human", "value": prompt}]

    # No logits processor is ever built from ``response_format``. Warn instead of
    # silently ignoring the request.
    xgr_logits_processor = None
    if response_format is not None:
        logging.warning("response_format is not supported by generate_content and will be ignored.")

    # Extract media from the conversation
    media = extract_media(conversation, self.config)

    # Preprocess media per modality. Keys are only reassigned, never added or
    # removed, so mutating ``media`` while iterating it is safe.
    media_config = defaultdict(dict)
    for name in media:
        if name == "image":
            if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
                # Presumably consumed by process_image in the dynamic modes; TODO confirm.
                self.config.image_processor = self.vision_tower.image_processor
                if self.config.image_aspect_ratio == "dynamic":
                    images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
                    # One image token per produced tile.
                    conversation[0]["value"] = conversation[0]["value"].replace(
                        DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
                    )
                else:
                    if isinstance(self.config.s2_scales, str):
                        self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                    images, block_sizes = process_image(
                        media["image"][0], self.config, None, enable_dynamic_s2=True
                    )
                    images = images.half()
                    media_config[name]["block_sizes"] = [block_sizes]
            else:
                images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
            media[name] = list(images)
        elif name == "video":
            if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
                media[name] = [
                    process_images(
                        images,
                        self.vision_tower.image_processor,
                        self.config,
                        enable_dynamic_res=True,
                        max_tiles=self.config.video_max_tiles,
                    ).half()
                    for images in media[name]
                ]
            elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
                self.config.image_processor = self.vision_tower.image_processor
                if isinstance(self.config.s2_scales, str):
                    self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                media[name] = [
                    torch.cat(
                        [
                            process_image(
                                image,
                                self.config,
                                None,
                                enable_dynamic_s2=True,
                                max_tiles=self.config.video_max_tiles,
                            )[0].half()
                            for image in images
                        ]
                    )
                    for images in media[name]
                ]
            else:
                # NOTE(review): unlike the other branches, no .half() here; verify that
                # the embedding path casts video frames as needed.
                media[name] = [
                    process_images(images, self.vision_tower.image_processor, self.config)
                    for images in media[name]
                ]
        elif name == "speech":
            media[name] = list(media["speech"])
        elif name == "sound":
            sounds = media["sound"]
            # Feature dicts are cast to half precision in place; other items pass through.
            for sound in sounds:
                if type(sound) is dict:
                    for k, v in sound.items():
                        sound[k] = v.half()
            media[name] = list(sounds)
        elif name == "video_info":
            media[name] = [media["video_info"]]
        elif name == "audio_info":
            media[name] = [media["audio_info"]]
        else:
            raise ValueError(f"Unsupported media type: {name}")

    # Tokenize the conversation
    input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()

    generation_config = generation_config or self.default_generation_config
    generate_kwargs = dict(
        input_ids=input_ids,
        media=media,
        media_config=media_config,
        logits_processor=xgr_logits_processor,  # structured generation
    )
    try:
        output_ids = self.generate(generation_config=generation_config, **generate_kwargs)
    except ValueError:
        if not generation_config.do_sample:
            raise
        logging.warning("Generation failed with sampling, retrying with greedy decoding.")
        # Copy before switching to greedy so a caller-supplied config is left untouched.
        generation_config = copy.deepcopy(generation_config)
        generation_config.do_sample = False
        output_ids = self.generate(generation_config=generation_config, **generate_kwargs)

    # Decode the response
    response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return response
@property
def default_generation_config(self) -> GenerationConfig:
    """Return a fresh generation config completed with tokenizer defaults.

    Starts from a deep copy of ``self.generation_config``, or a default
    ``GenerationConfig`` when that is unset, so callers may modify the result
    freely. Unset fields are filled from ``self.tokenizer``:

    * ``max_length``: ``tokenizer.model_max_length``, when still at the library default;
    * ``pad_token_id``: the tokenizer pad id, falling back to the EOS id;
    * ``bos_token_id``: the tokenizer BOS id, falling back to the EOS id;
    * ``eos_token_id``: the tokenizer EOS id.

    Raises:
        ValueError: if the tokenizer has no EOS token.
    """
    tokenizer = self.tokenizer
    if tokenizer.eos_token_id is None:
        raise ValueError("Tokenizer must have an EOS token")
    eos_id = tokenizer.eos_token_id

    library_defaults = GenerationConfig()
    generation_config = copy.deepcopy(self.generation_config or library_defaults)

    if generation_config.max_length == library_defaults.max_length:
        generation_config.max_length = tokenizer.model_max_length
    # Compare against None explicitly: token id 0 is a valid id and must not be
    # replaced by the EOS id (``x or eos`` would do exactly that).
    if generation_config.pad_token_id is None:
        pad_id = tokenizer.pad_token_id
        generation_config.pad_token_id = pad_id if pad_id is not None else eos_id
    if generation_config.bos_token_id is None:
        bos_id = tokenizer.bos_token_id
        generation_config.bos_token_id = bos_id if bos_id is not None else eos_id
    if generation_config.eos_token_id is None:
        generation_config.eos_token_id = eos_id
    return generation_config